Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions broker/patron_request/api/api-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ func (a *PatronRequestApiHandler) GetPatronRequests(w http.ResponseWriter, r *ht
return
}
cqlStr := cql.String()
prs, count, err := a.prRepo.ListPatronRequestsSearchView(ctx, pr_db.ListPatronRequestsParams{Limit: limit, Offset: offset}, &cqlStr)
pgcql, err := pr_db.ParsePatronRequestsCql(cqlStr)
if err != nil {
api.AddBadRequestError(ctx, w, err)
return
}

prs, count, err := a.prRepo.ListPatronRequestsSearchView(ctx, pr_db.ListPatronRequestsParams{Limit: limit, Offset: offset}, pgcql)
if err != nil && !errors.Is(err, pgx.ErrNoRows) { //DB error
api.AddInternalError(ctx, w, err)
return
Expand All @@ -188,7 +194,7 @@ func (a *PatronRequestApiHandler) GetPatronRequests(w http.ResponseWriter, r *ht
resp.About = toProAboutWithFacets(api.CollectAboutData(count, offset, limit, r))
var facets []pr_db.Facet
if params.Facets != nil {
facets, err = a.prRepo.GetPatronRequestsFacets(ctx, *params.Facets, cqlStr)
facets, err = a.prRepo.GetPatronRequestsFacets(ctx, *params.Facets, pgcql)
}
if err != nil {
if errors.Is(err, pr_db.ErrUnsupportedFacet) {
Expand Down
39 changes: 20 additions & 19 deletions broker/patron_request/api/api-handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/indexdata/cql-go/pgcql"
"github.com/indexdata/crosslink/broker/common"
"github.com/indexdata/crosslink/broker/events"
"github.com/indexdata/crosslink/broker/handler"
Expand Down Expand Up @@ -205,10 +206,10 @@ func TestGetPatronRequestsNoSymbol(t *testing.T) {
}
handler.GetPatronRequests(rr, req, params)
assert.Equal(t, http.StatusOK, rr.Code)
if assert.NotNil(t, repo.cql) {
assert.Contains(t, *repo.cql, "side = lending")
assert.NotContains(t, *repo.cql, "supplier_symbol =")
assert.NotContains(t, *repo.cql, "requester_symbol =")
if assert.NotNil(t, repo.pgcql) {
assert.Contains(t, repo.pgcql.GetWhereClause(), "side =")
assert.NotContains(t, repo.pgcql.GetWhereClause(), "supplier_symbol =")
assert.NotContains(t, repo.pgcql.GetWhereClause(), "requester_symbol =")
}
}

Expand Down Expand Up @@ -244,10 +245,10 @@ func TestGetPatronRequestsWithRequesterReqId(t *testing.T) {
}
handler.GetPatronRequests(rr, req, params)
assert.Equal(t, http.StatusOK, rr.Code)
if assert.NotNil(t, repo.cql) {
assert.Contains(t, *repo.cql, "requester_req_id_exact = req-123")
assert.Contains(t, *repo.cql, "side = lending")
assert.Contains(t, *repo.cql, "supplier_symbol_exact = ISIL:REQ")
if assert.NotNil(t, repo.pgcql) {
assert.Contains(t, repo.pgcql.GetWhereClause(), "requester_req_id =")
assert.Contains(t, repo.pgcql.GetWhereClause(), "side =")
assert.Contains(t, repo.pgcql.GetWhereClause(), "supplier_symbol =")
}
}

Expand All @@ -265,8 +266,8 @@ func TestGetPatronRequestsWithSymbolNoSideGroupsOwnerRestriction(t *testing.T) {
handler.GetPatronRequests(rr, req, params)

assert.Equal(t, http.StatusOK, rr.Code)
if assert.NotNil(t, repo.cql) {
assert.Equal(t, "id = pr-1 and (side = lending and supplier_symbol_exact = ISIL:REQ or (side = borrowing and requester_symbol_exact = ISIL:REQ))", *repo.cql)
if assert.NotNil(t, repo.pgcql) {
assert.Equal(t, "id = $3 AND ((side = $4 AND supplier_symbol = $5) OR (side = $6 AND requester_symbol = $7))", repo.pgcql.GetWhereClause())
}
}

Expand Down Expand Up @@ -940,7 +941,7 @@ func (r *PrRepoOkapiOwner) GetPatronRequestSearchView(ctx common.ExtendedContext

type PrRepoCapture struct {
PrRepoError
cql *string
pgcql pgcql.Query
}

type PrRepoNotificationsCapture struct {
Expand All @@ -950,13 +951,13 @@ type PrRepoNotificationsCapture struct {
fullCount int64
}

func (r *PrRepoCapture) ListPatronRequests(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, cql *string) ([]pr_db.PatronRequest, int64, error) {
r.cql = cql
func (r *PrRepoCapture) ListPatronRequests(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, pgcql pgcql.Query) ([]pr_db.PatronRequest, int64, error) {
r.pgcql = pgcql
return []pr_db.PatronRequest{}, 0, nil
}

func (r *PrRepoCapture) ListPatronRequestsSearchView(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, cql *string) ([]pr_db.PatronRequestSearchView, int64, error) {
r.cql = cql
func (r *PrRepoCapture) ListPatronRequestsSearchView(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, pgcql pgcql.Query) ([]pr_db.PatronRequestSearchView, int64, error) {
r.pgcql = pgcql
return []pr_db.PatronRequestSearchView{}, 0, nil
}

Expand Down Expand Up @@ -986,11 +987,11 @@ func (r *PrRepoError) GetPatronRequestSearchView(ctx common.ExtendedContext, id
return patronRequestSearchViewFromPatronRequest(pr, false), err
}

func (r *PrRepoError) ListPatronRequests(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, cql *string) ([]pr_db.PatronRequest, int64, error) {
func (r *PrRepoError) ListPatronRequests(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, pgcql pgcql.Query) ([]pr_db.PatronRequest, int64, error) {
return []pr_db.PatronRequest{}, 0, errors.New("DB error")
}

func (r *PrRepoError) ListPatronRequestsSearchView(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, cql *string) ([]pr_db.PatronRequestSearchView, int64, error) {
func (r *PrRepoError) ListPatronRequestsSearchView(ctx common.ExtendedContext, args pr_db.ListPatronRequestsParams, pgcql pgcql.Query) ([]pr_db.PatronRequestSearchView, int64, error) {
return []pr_db.PatronRequestSearchView{}, 0, errors.New("DB error")
}

Expand Down Expand Up @@ -1044,15 +1045,15 @@ func (r *PrRepoError) GetNotificationById(ctx common.ExtendedContext, id string)
}
}

func (r *PrRepoError) GetPatronRequestsFacets(_ common.ExtendedContext, _ []string, _ string) ([]pr_db.Facet, error) {
func (r *PrRepoError) GetPatronRequestsFacets(_ common.ExtendedContext, _ []string, _ pgcql.Query) ([]pr_db.Facet, error) {
return nil, errors.New("DB error")
}

type PrRepoFacetsUnsupported struct {
PrRepoCapture
}

func (r *PrRepoFacetsUnsupported) GetPatronRequestsFacets(_ common.ExtendedContext, _ []string, _ string) ([]pr_db.Facet, error) {
func (r *PrRepoFacetsUnsupported) GetPatronRequestsFacets(_ common.ExtendedContext, _ []string, _ pgcql.Query) ([]pr_db.Facet, error) {
return nil, fmt.Errorf("%w: nosuch", pr_db.ErrUnsupportedFacet)
}

Expand Down
63 changes: 32 additions & 31 deletions broker/patron_request/db/prcql.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (

var LANGUAGE = utils.GetEnv("LANGUAGE", "english")

const NumberBaseArgs = 2 // SQLC base query has two args: $1=limit, $2=offset

type FieldAllRecords struct{}

func (f *FieldAllRecords) GetColumn() string { return "" }
Expand Down Expand Up @@ -74,7 +76,10 @@ func (f *FieldTextArrayContains) Generate(sc cql.SearchClause, queryArgumentInde
}
}

func handlePatronRequestsQuery(cqlString string, noBaseArgs int) (pgcql.Query, error) {
// ParsePatronRequestsCql parses cqlString into a pgcql.Query whose placeholder
// numbering starts at $3, matching the two base SQL arguments (limit and offset)
// used by both ListPatronRequestsCql and GetPatronRequestsFacetsCql.
func ParsePatronRequestsCql(cqlString string) (pgcql.Query, error) {
def := pgcql.NewPgDefinition()

fa := &FieldAllRecords{}
Expand Down Expand Up @@ -167,37 +172,39 @@ func handlePatronRequestsQuery(cqlString string, noBaseArgs int) (pgcql.Query, e
if err != nil {
return nil, err
}
return def.Parse(query, noBaseArgs+1)
return def.Parse(query, NumberBaseArgs+1)
}

// facetFieldPlaceholder is the column name used in the facetsPatronRequests SQL template.
// FacetsPatronRequestsCql substitutes it with the validated facet field at runtime.
// facetFieldPlaceholder is the column name used in the getPatronRequestsFacets SQL template.
// GetPatronRequestsFacetsCql substitutes it with the validated facet field at runtime.
const facetFieldPlaceholder = "requester_symbol"

func (q *Queries) FacetsPatronRequestsCql(ctx context.Context, db DBTX, facetField string, cqlString string) ([]FacetsPatronRequestsRow, error) {
func (q *Queries) GetPatronRequestsFacetsCql(ctx context.Context, db DBTX, facetField string, pgcql pgcql.Query) ([]GetPatronRequestsFacetsRow, error) {
if pgcql == nil {
return nil, fmt.Errorf("pgcql.Query must not be nil; use cql.allRecords=1 for no filter")
}
// facetField is validated against an allowlist by the caller (GetPatronRequestsFacets),
// so it is safe to substitute directly as a column name.
sql := strings.Replace(facetsPatronRequests, facetFieldPlaceholder, facetField, 1)
sql := strings.Replace(getPatronRequestsFacets, facetFieldPlaceholder, facetField, 1)

idx := strings.Index(sql, "GROUP BY")
if idx == -1 {
return nil, fmt.Errorf("base SQL query missing GROUP BY clause")
}
res, err := handlePatronRequestsQuery(cqlString, 0)
if err != nil {
return nil, fmt.Errorf("failed to handle CQL query: %w", err)
}
if res.GetWhereClause() != "" {
sql = sql[:idx] + "AND (" + res.GetWhereClause() + ") " + sql[idx:]
if pgcql.GetWhereClause() != "" {
sql = sql[:idx] + "AND (" + pgcql.GetWhereClause() + ") " + sql[idx:]
}
rows, err := db.Query(ctx, sql, res.GetQueryArguments()...)
sqlArguments := make([]interface{}, 0, NumberBaseArgs+len(pgcql.GetQueryArguments()))
sqlArguments = append(sqlArguments, int64(100), int64(0)) // 100 facet values should be more than enough; offset is always 0 for facets
sqlArguments = append(sqlArguments, pgcql.GetQueryArguments()...)
rows, err := db.Query(ctx, sql, sqlArguments...)
if err != nil {
return nil, fmt.Errorf("failed to execute facets query: %w", err)
}
defer rows.Close()
var items []FacetsPatronRequestsRow
var items []GetPatronRequestsFacetsRow
for rows.Next() {
var i FacetsPatronRequestsRow
var i GetPatronRequestsFacetsRow
if err := rows.Scan(&i.Value, &i.Count); err != nil {
return nil, err
}
Expand All @@ -210,15 +217,9 @@ func (q *Queries) FacetsPatronRequestsCql(ctx context.Context, db DBTX, facetFie
}

func (q *Queries) ListPatronRequestsCql(ctx context.Context, db DBTX, arg ListPatronRequestsParams,
cqlString *string, explainAnalyze bool) ([]ListPatronRequestsRow, []string, error) {
if cqlString == nil {
rows, err := q.ListPatronRequests(ctx, db, arg)
return rows, nil, err
}
noBaseArgs := 2 // we have two base arguments: limit and offset
res, err := handlePatronRequestsQuery(*cqlString, noBaseArgs)
if err != nil {
return nil, nil, err
pgcql pgcql.Query, explainAnalyze bool) ([]ListPatronRequestsRow, []string, error) {
if pgcql == nil {
return nil, nil, fmt.Errorf("pgcql.Query must not be nil; use cql.allRecords=1 for no filter")
}
orgSql := listPatronRequests
pos := strings.Index(orgSql, "ORDER BY")
Expand All @@ -230,21 +231,21 @@ func (q *Queries) ListPatronRequestsCql(ctx context.Context, db DBTX, arg ListPa
return nil, nil, fmt.Errorf("base query missing LIMIT")
}
orderBy := orgSql[pos:limitPos]
if res.GetOrderByClause() != "" {
orderBy = res.GetOrderByClause() + " "
if pgcql.GetOrderByClause() != "" {
orderBy = pgcql.GetOrderByClause() + " "
}
sqlPrefix := orgSql[:pos]
if res.GetWhereClause() != "" {
if pgcql.GetWhereClause() != "" {
if strings.Contains(strings.ToUpper(sqlPrefix), "WHERE ") {
sqlPrefix += "AND " + res.GetWhereClause() + " "
sqlPrefix += "AND " + pgcql.GetWhereClause() + " "
} else {
sqlPrefix += "WHERE " + res.GetWhereClause() + " "
sqlPrefix += "WHERE " + pgcql.GetWhereClause() + " "
}
Comment thread
jakub-id marked this conversation as resolved.
}
sql := sqlPrefix + orderBy + orgSql[limitPos:]
sqlArguments := make([]interface{}, 0, noBaseArgs+len(res.GetQueryArguments()))
sqlArguments := make([]interface{}, 0, NumberBaseArgs+len(pgcql.GetQueryArguments()))
sqlArguments = append(sqlArguments, arg.Limit, arg.Offset)
sqlArguments = append(sqlArguments, res.GetQueryArguments()...)
sqlArguments = append(sqlArguments, pgcql.GetQueryArguments()...)
explainResult := []string{}
if explainAnalyze {
explainRows, err := db.Query(ctx, "EXPLAIN ANALYZE "+sql, sqlArguments...)
Expand Down
17 changes: 6 additions & 11 deletions broker/patron_request/db/prcql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import (
"testing"

"github.com/indexdata/cql-go/cql"
"github.com/stretchr/testify/assert"
)

func TestHandlePatronRequestsQueryKeepsOwnerRestrictionGrouped(t *testing.T) {
cql := "cql.allRecords = 1 and (side = lending and supplier_symbol_exact = ISIL:REQ or (side = borrowing and requester_symbol_exact = ISIL:REQ))"

query, err := handlePatronRequestsQuery(cql, 2)
if err != nil {
t.Fatalf("handlePatronRequestsQuery() error = %v", err)
}
query, err := ParsePatronRequestsCql(cql)
assert.NoError(t, err, "ParsePatronRequestsCQL() error = %v", err)

want := "TRUE AND ((side = $3 AND supplier_symbol = $4) OR (side = $5 AND requester_symbol = $6))"
if got := query.GetWhereClause(); got != want {
Expand Down Expand Up @@ -87,15 +86,11 @@ func TestFieldTextArrayContainsGenerate(t *testing.T) {
func TestHandlePatronRequestsQueryIsbnUsesNormIsxn(t *testing.T) {
cql := `isbn = "978-3-16-148410-0"`

query, err := handlePatronRequestsQuery(cql, 2)
if err != nil {
t.Fatalf("handlePatronRequestsQuery() error = %v", err)
}
query, err := ParsePatronRequestsCql(cql)
assert.NoError(t, err, "ParsePatronRequestsCQL() error = %v", err)

wantWhere := "bibliographic_item_identifiers(ill_request, 'ISBN') @> ARRAY[norm_isxn($3)]::text[]"
if got := query.GetWhereClause(); got != wantWhere {
t.Fatalf("where clause = %q, want %q", got, wantWhere)
}
assert.Equal(t, wantWhere, query.GetWhereClause(), "where clause = %q, want %q", query.GetWhereClause(), wantWhere)
}

func searchClauseForTest(term, relation string) cql.SearchClause {
Expand Down
25 changes: 13 additions & 12 deletions broker/patron_request/db/prrepo.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"

"github.com/indexdata/cql-go/pgcql"
"github.com/indexdata/crosslink/broker/common"
"github.com/indexdata/crosslink/broker/repo"
"github.com/jackc/pgx/v5"
Expand All @@ -18,9 +19,9 @@ type PrRepo interface {
GetPatronRequestSearchView(ctx common.ExtendedContext, id string) (PatronRequestSearchView, error)
GetPatronRequestByIdForUpdate(ctx common.ExtendedContext, id string) (PatronRequest, error)
GetPatronRequestByIdAndSide(ctx common.ExtendedContext, id string, side PatronRequestSide) (PatronRequest, error)
ListPatronRequests(ctx common.ExtendedContext, args ListPatronRequestsParams, cql *string) ([]PatronRequest, int64, error)
ListPatronRequestsSearchView(ctx common.ExtendedContext, args ListPatronRequestsParams, cql *string) ([]PatronRequestSearchView, int64, error)
GetPatronRequestsFacets(ctx common.ExtendedContext, facetFields []string, cql string) ([]Facet, error)
ListPatronRequests(ctx common.ExtendedContext, args ListPatronRequestsParams, pgcql pgcql.Query) ([]PatronRequest, int64, error)
ListPatronRequestsSearchView(ctx common.ExtendedContext, args ListPatronRequestsParams, pgcql pgcql.Query) ([]PatronRequestSearchView, int64, error)
GetPatronRequestsFacets(ctx common.ExtendedContext, facetFields []string, pgcql pgcql.Query) ([]Facet, error)
UpdatePatronRequest(ctx common.ExtendedContext, params UpdatePatronRequestParams) (PatronRequest, error)
UpdatePatronRequestInternalNote(ctx common.ExtendedContext, id string, internalNote pgtype.Text) error
CreatePatronRequest(ctx common.ExtendedContext, params CreatePatronRequestParams) (PatronRequest, error)
Expand Down Expand Up @@ -105,8 +106,8 @@ func (r *PgPrRepo) GetPatronRequestByIdAndSide(ctx common.ExtendedContext, id st
return pr, nil
}

func (r *PgPrRepo) ListPatronRequests(ctx common.ExtendedContext, params ListPatronRequestsParams, cql *string) ([]PatronRequest, int64, error) {
rows, fullCount, err := r.listPatronRequestRows(ctx, params, cql)
func (r *PgPrRepo) ListPatronRequests(ctx common.ExtendedContext, params ListPatronRequestsParams, pgcql pgcql.Query) ([]PatronRequest, int64, error) {
rows, fullCount, err := r.listPatronRequestRows(ctx, params, pgcql)
if err != nil {
return nil, fullCount, err
}
Expand All @@ -117,12 +118,12 @@ func (r *PgPrRepo) ListPatronRequests(ctx common.ExtendedContext, params ListPat
return list, fullCount, nil
}

func (r *PgPrRepo) GetPatronRequestsFacets(ctx common.ExtendedContext, facetFields []string, cql string) ([]Facet, error) {
func (r *PgPrRepo) GetPatronRequestsFacets(ctx common.ExtendedContext, facetFields []string, pgcql pgcql.Query) ([]Facet, error) {
var facets []Facet
for _, field := range facetFields {
switch field {
case "requester_symbol", "supplier_symbol":
rows, err := r.queries.FacetsPatronRequestsCql(ctx, r.GetConnOrTx(), field, cql)
rows, err := r.queries.GetPatronRequestsFacetsCql(ctx, r.GetConnOrTx(), field, pgcql)
if err != nil {
return nil, err
}
Expand All @@ -146,8 +147,8 @@ func (r *PgPrRepo) GetPatronRequestsFacets(ctx common.ExtendedContext, facetFiel
return facets, nil
}

func (r *PgPrRepo) ListPatronRequestsSearchView(ctx common.ExtendedContext, params ListPatronRequestsParams, cql *string) ([]PatronRequestSearchView, int64, error) {
rows, fullCount, err := r.listPatronRequestRows(ctx, params, cql)
func (r *PgPrRepo) ListPatronRequestsSearchView(ctx common.ExtendedContext, params ListPatronRequestsParams, pgcql pgcql.Query) ([]PatronRequestSearchView, int64, error) {
rows, fullCount, err := r.listPatronRequestRows(ctx, params, pgcql)
if err != nil {
return nil, fullCount, err
}
Expand All @@ -158,8 +159,8 @@ func (r *PgPrRepo) ListPatronRequestsSearchView(ctx common.ExtendedContext, para
return list, fullCount, nil
}

func (r *PgPrRepo) listPatronRequestRows(ctx common.ExtendedContext, params ListPatronRequestsParams, cql *string) ([]ListPatronRequestsRow, int64, error) {
rows, explainResult, err := r.queries.ListPatronRequestsCql(ctx, r.GetConnOrTx(), params, cql, r.explainAnalyze)
func (r *PgPrRepo) listPatronRequestRows(ctx common.ExtendedContext, params ListPatronRequestsParams, pgcql pgcql.Query) ([]ListPatronRequestsRow, int64, error) {
rows, explainResult, err := r.queries.ListPatronRequestsCql(ctx, r.GetConnOrTx(), params, pgcql, r.explainAnalyze)
var fullCount int64
if err == nil {
for _, line := range explainResult {
Expand All @@ -170,7 +171,7 @@ func (r *PgPrRepo) listPatronRequestRows(ctx common.ExtendedContext, params List
} else {
params.Limit = 1
params.Offset = 0
countRows, _, countErr := r.queries.ListPatronRequestsCql(ctx, r.GetConnOrTx(), params, cql, false)
countRows, _, countErr := r.queries.ListPatronRequestsCql(ctx, r.GetConnOrTx(), params, pgcql, false)
err = countErr
if err == nil && len(countRows) > 0 {
fullCount = countRows[0].FullCount
Expand Down
10 changes: 9 additions & 1 deletion broker/pullslip/api/api_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package psapi
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -215,7 +216,14 @@ func (p PullSlipApiHandler) getPullSlip(ctx common.ExtendedContext, w http.Respo
}

func (p PullSlipApiHandler) getPdfByte(ctx common.ExtendedContext, w http.ResponseWriter, cql string) ([]byte, error) {
prs, _, err := p.prRepo.ListPatronRequests(ctx, pr_db.ListPatronRequestsParams{Limit: MAX_RECORDS_PER_PDF, Offset: 0}, &cql)
pgcql, err := pr_db.ParsePatronRequestsCql(cql)
if err != nil {
wrappedErr := fmt.Errorf("invalid CQL query: %w", err)
api.AddBadRequestError(ctx, w, wrappedErr)
return []byte{}, wrappedErr
}

prs, _, err := p.prRepo.ListPatronRequests(ctx, pr_db.ListPatronRequestsParams{Limit: MAX_RECORDS_PER_PDF, Offset: 0}, pgcql)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
api.AddNotFoundError(w)
Expand Down
Loading
Loading