diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 903211e..87d7767 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go 1.x uses: actions/setup-go@v5 with: - go-version: "1.24.x" + go-version: "1.26.x" - name: Check out code uses: actions/checkout@v4 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2dadb8d..f38af2d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,7 +14,7 @@ jobs: postgres-test: strategy: matrix: - go: [1.24.x, 1.23.x] # when updating versions, update it below too. + go: [1.26.x, 1.25.x] # when updating versions, update it below too. runs-on: ubuntu-latest services: postgres: @@ -42,7 +42,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.24.x' + go-version: '1.26.x' - name: Run unit tests run: | go test -v -race -count 1 -covermode atomic -coverprofile=profile.cov ./... @@ -51,7 +51,7 @@ jobs: working-directory: integration run: go test -v - name: Code coverage - if: ${{ github.event_name != 'pull_request' && matrix.go == '1.24.x' }} + if: ${{ github.event_name != 'pull_request' && matrix.go == '1.26.x' }} uses: shogo82148/actions-goveralls@v1 with: path-to-profile: profile.cov diff --git a/delete.go b/delete.go index 4310c4f..cba1470 100644 --- a/delete.go +++ b/delete.go @@ -182,7 +182,6 @@ func (b DeleteBuilder) Returning(columns ...string) DeleteBuilder { // ReturningSelect adds a RETURNING expressions to the query similar to Using, but takes a Select statement. func (b DeleteBuilder) ReturningSelect(from SelectBuilder, alias string) DeleteBuilder { - from.placeholder = questionPlaceholder b.returning = append(b.returning, Alias{Expr: from, As: alias}) return b } diff --git a/delete_test.go b/delete_test.go index 0e9ceea..0f29476 100644 --- a/delete_test.go +++ b/delete_test.go @@ -73,6 +73,14 @@ func TestDeleteBuilderSQL(t *testing.T) { wantSQL: "DELETE FROM films USING (SELECT id FROM producers WHERE name = $1) AS p", wantArgs: []any{"foo"}, }, + { + name: "delete_using_select_params", + b: Delete("films"). + UsingSelect(Select("id").From("producers").Where("name = ?", "foo"), "p"). + Where("status = ?", "active"), + wantSQL: "DELETE FROM films USING (SELECT id FROM producers WHERE name = $1) AS p WHERE status = $2", + wantArgs: []any{"foo", "active"}, + }, { name: "delete_with_cte", b: Delete("orders"). diff --git a/expr.go b/expr.go index 95db5c7..5e29500 100644 --- a/expr.go +++ b/expr.go @@ -63,7 +63,7 @@ func (e expr) SQL() (sql string, args []any, err error) { if as, ok := ap[0].(SQLizer); ok { // sqlizer argument; expand it and append the result - isql, iargs, err = as.SQL() + isql, iargs, err = nestedSQL(as) buf.WriteString(sp[:i]) buf.WriteString(isql) args = append(args, iargs...) @@ -95,7 +95,7 @@ func ConcatSQL(ce ...any) (sql string, args []any, err error) { case string: sql += p case SQLizer: - pSQL, pArgs, err := p.SQL() + pSQL, pArgs, err := nestedSQL(p) if err != nil { return "", nil, err } @@ -120,7 +120,7 @@ type Alias struct { // AliasExprSQL returns a SQL query based on the alias. func (a Alias) SQL() (sql string, args []any, err error) { - sql, args, err = a.Expr.SQL() + sql, args, err = nestedSQL(a.Expr) if err == nil { sql = fmt.Sprintf("(%s) AS %s", sql, a.As) } @@ -162,6 +162,17 @@ func (eq Eq) toSQL(useNotOpr bool) (sql string, args []any, err error) { if val, err = v.Value(); err != nil { return } + case SQLizer: + var vsql string + var vargs []any + vsql, vargs, err = nestedSQL(v) + if err != nil { + return + } + expr = fmt.Sprintf("%s %s (%s)", key, equalOpr, vsql) + args = append(args, vargs...) + exprs = append(exprs, expr) + continue } r := reflect.ValueOf(val) @@ -215,7 +226,7 @@ func (neq NotEq) SQL() (sql string, args []any, err error) { // Like is syntactic sugar for use with LIKE conditions. // Ex: // -// .Where(Like{"name": "%irrel"}) +// .Where(Like{"name": "%elephant"}) type Like map[string]any func (lk Like) toSQL(opr string) (sql string, args []any, err error) { @@ -228,6 +239,17 @@ func (lk Like) toSQL(opr string) (sql string, args []any, err error) { if val, err = v.Value(); err != nil { return } + case SQLizer: + var vsql string + var vargs []any + vsql, vargs, err = nestedSQL(v) + if err != nil { + return + } + expr = fmt.Sprintf("%s %s (%s)", key, opr, vsql) + args = append(args, vargs...) + exprs = append(exprs, expr) + continue } if val == nil { @@ -255,7 +277,7 @@ func (lk Like) SQL() (sql string, args []any, err error) { // NotLike is syntactic sugar for use with LIKE conditions. // Ex: // -// .Where(NotLike{"name": "%irrel"}) +// .Where(NotLike{"name": "%elephant"}) type NotLike Like func (nlk NotLike) SQL() (sql string, args []any, err error) { @@ -312,6 +334,17 @@ func (lt Lt) toSQL(opposite, orEq bool) (sql string, args []any, err error) { if val, err = v.Value(); err != nil { return } + case SQLizer: + var vsql string + var vargs []any + vsql, vargs, err = nestedSQL(v) + if err != nil { + return + } + expr = fmt.Sprintf("%s %s (%s)", key, opr, vsql) + args = append(args, vargs...) + exprs = append(exprs, expr) + continue } if val == nil { diff --git a/expr_test.go b/expr_test.go index 00b980c..fbe764d 100644 --- a/expr_test.go +++ b/expr_test.go @@ -108,6 +108,44 @@ func TestNotEqSQL(t *testing.T) { } } +func TestEqSubquerySQL(t *testing.T) { + t.Parallel() + b := Eq{"id": Select("id").From("other").Where("x = ?", 1)} + sql, args, err := b.SQL() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + want := "id = (SELECT id FROM other WHERE x = ?)" + if want != sql { + t.Errorf("expected SQL to be %q, got %q instead", want, sql) + } + + expectedArgs := []any{1} + if !reflect.DeepEqual(expectedArgs, args) { + t.Errorf("wanted %v, got %v instead", expectedArgs, args) + } +} + +func TestNotEqSubquerySQL(t *testing.T) { + t.Parallel() + b := NotEq{"id": Select("id").From("other").Where("x = ?", 1)} + sql, args, err := b.SQL() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + want := "id <> (SELECT id FROM other WHERE x = ?)" + if want != sql { + t.Errorf("expected SQL to be %q, got %q instead", want, sql) + } + + expectedArgs := []any{1} + if !reflect.DeepEqual(expectedArgs, args) { + t.Errorf("wanted %v, got %v instead", expectedArgs, args) + } +} + func TestEqNotInSQL(t *testing.T) { t.Parallel() b := NotEq{"id": []int{1, 2, 3}} @@ -203,6 +241,43 @@ func TestLtSQL(t *testing.T) { } } +func TestLtSubquerySQL(t *testing.T) { + t.Parallel() + b := Lt{"score": Select("avg(score)").From("results").Where("active = ?", true)} + sql, args, err := b.SQL() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + want := "score < (SELECT avg(score) FROM results WHERE active = ?)" + if want != sql { + t.Errorf("expected SQL to be %q, got %q instead", want, sql) + } + + expectedArgs := []any{true} + if !reflect.DeepEqual(expectedArgs, args) { + t.Errorf("wanted %v, got %v instead", expectedArgs, args) + } +} + +func TestGtSubquerySQL(t *testing.T) { + t.Parallel() + b := Gt{"score": Select("avg(score)").From("results")} + sql, args, err := b.SQL() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + want := "score > (SELECT avg(score) FROM results)" + if want != sql { + t.Errorf("expected SQL to be %q, got %q instead", want, sql) + } + + if len(args) != 0 { + t.Errorf("wanted 0 arguments, got %d instead", len(args)) + } +} + func TestLtOrEqSQL(t *testing.T) { t.Parallel() b := LtOrEq{"id": 1} diff --git a/insert.go b/insert.go index 809c3e1..2104ac8 100644 --- a/insert.go +++ b/insert.go @@ -123,10 +123,13 @@ func (b InsertBuilder) appendValuesToSQL(w io.Writer, args []any) ([]any, error) valueStrings := make([]string, len(row)) for v, val := range row { if vs, ok := val.(SQLizer); ok { - vsql, vargs, err := vs.SQL() + vsql, vargs, err := nestedSQL(vs) if err != nil { return nil, err } + if _, ok := vs.(rawSQLizer); ok { + vsql = fmt.Sprintf("(%s)", vsql) + } valueStrings[v] = vsql args = append(args, vargs...) } else { @@ -147,7 +150,7 @@ func (b InsertBuilder) appendSelectToSQL(w io.Writer, args []any) ([]any, error) return args, errors.New("select clause for insert statements are not set") } - selectClause, sArgs, err := b.selectBuilder.SQL() + selectClause, sArgs, err := b.selectBuilder.unfinalizedSQL() if err != nil { return args, err } @@ -221,7 +224,6 @@ func (b InsertBuilder) Returning(columns ...string) InsertBuilder { // ReturningSelect adds a RETURNING expressions to the query similar to Using, but takes a Select statement. func (b InsertBuilder) ReturningSelect(from SelectBuilder, alias string) InsertBuilder { - from.placeholder = questionPlaceholder b.returning = append(b.returning, Alias{Expr: from, As: alias}) return b } diff --git a/insert_test.go b/insert_test.go index 1aac860..d1b2acd 100644 --- a/insert_test.go +++ b/insert_test.go @@ -108,6 +108,45 @@ func TestInsertBuilderSQL(t *testing.T) { "SELECT s.id, s.data FROM source s JOIN tree ON tree.id = s.parent_id) " + "INSERT INTO archive (id,data) SELECT id, data FROM tree", }, + { + name: "insert_select_params", + b: Insert("films"). + Columns("id", "title"). + Select(Select("id", "title").From("producers").Where("name = ?", "foo")), + wantSQL: "INSERT INTO films (id,title) SELECT id, title FROM producers WHERE name = $1", + wantArgs: []any{"foo"}, + }, + { + name: "insert_values_select_params", + b: Insert("films"). + Columns("id", "title"). + Values(1, Select("title").From("other").Where("id = ?", 2)). + Suffix("RETURNING id, ?", 3), + wantSQL: "INSERT INTO films (id,title) VALUES ($1,(SELECT title FROM other WHERE id = $2)) RETURNING id, $3", + wantArgs: []any{1, 2, 3}, + }, + { + name: "insert_values_union_params", + b: Insert("films"). + Columns("id", "title"). + Values(1, Union( + Select("title").From("other").Where("id = ?", 2), + Select("title").From("another").Where("id = ?", 3), + )), + wantSQL: "INSERT INTO films (id,title) VALUES ($1,(SELECT title FROM other WHERE id = $2 UNION SELECT title FROM another WHERE id = $3))", + wantArgs: []any{1, 2, 3}, + }, + { + name: "insert_values_union", + b: Insert("t"). + Columns("id"). + Values(Union( + Select("id").From("a").Where("x = ?", 1), + Select("id").From("b").Where("x = ?", 2), + )), + wantSQL: "INSERT INTO t (id) VALUES ((SELECT id FROM a WHERE x = $1 UNION SELECT id FROM b WHERE x = $2))", + wantArgs: []any{1, 2}, + }, } for _, tc := range testCases { tc := tc @@ -242,27 +281,11 @@ func TestInsertBuilderSelect(t *testing.T) { } } -func TestInsertBuilderReplace(t *testing.T) { - t.Parallel() - b := Replace("table").Values(1) - - want := "REPLACE INTO table VALUES ($1)" - - sql, _, err := b.SQL() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if want != sql { - t.Errorf("expected SQL to be %q, got %q instead", want, sql) - } -} - func TestInsertBuilderVerb(t *testing.T) { t.Parallel() - b := Insert("table").Verb("REPLACE").Values(1) + b := Insert("table").Verb("UPSERT").Values(1) - want := "REPLACE INTO table VALUES ($1)" + want := "UPSERT INTO table VALUES ($1)" sql, _, err := b.SQL() if err != nil { diff --git a/pgq.go b/pgq.go index bda67e5..78c4c90 100644 --- a/pgq.go +++ b/pgq.go @@ -34,7 +34,7 @@ type rawSQLizer interface { // not try very hard to ensure it. Additionally, executing the output of this // function with any untrusted user input is certainly insecure. func Debug(s SQLizer) string { - sql, args, err := s.SQL() + sql, args, err := nestedSQL(s) if err != nil { return fmt.Sprintf("[SQL error: %s]", err) } diff --git a/pgq_test.go b/pgq_test.go index ae05c52..c5982dc 100644 --- a/pgq_test.go +++ b/pgq_test.go @@ -13,6 +13,15 @@ func TestDebug(t *testing.T) { } } +func TestDebugSelect(t *testing.T) { + t.Parallel() + sqlizer := Select("id", "name").From("users").Where("id = ?", 42).Where("active = ?", true) + want := "SELECT id, name FROM users WHERE id = '42' AND active = 'true'" + if got := Debug(sqlizer); got != want { + t.Errorf("expected %q, got %q instead", want, got) + } +} + func TestDebugSQLizerErrors(t *testing.T) { t.Parallel() var errorMessages = []struct { diff --git a/placeholder.go b/placeholder.go index 6034484..adb1528 100644 --- a/placeholder.go +++ b/placeholder.go @@ -2,19 +2,10 @@ package pgq import ( "bytes" - "fmt" + "strconv" "strings" ) -// placeholder takes a SQL statement and replaces each question mark -// placeholder with a (possibly different) SQL placeholder. -type placeholder func(sql string) (string, error) - -// questionPlaceholder just leaves question marks ("?") as placeholders. -func questionPlaceholder(sql string) (string, error) { - return sql, nil -} - // Placeholders returns a string with count ? placeholders joined with commas. func Placeholders(count int) string { if count < 1 { @@ -29,6 +20,7 @@ func Placeholders(count int) string { func dollarPlaceholder(sql string) (string, error) { buf := &bytes.Buffer{} i := 0 + var itob [20]byte for { p := strings.Index(sql, "?") if p == -1 { @@ -45,7 +37,8 @@ func dollarPlaceholder(sql string) (string, error) { } else { i++ buf.WriteString(sql[:p]) - fmt.Fprintf(buf, "$%d", i) + buf.WriteByte('$') + buf.Write(strconv.AppendInt(itob[:0], int64(i), 10)) sql = sql[p+1:] } } diff --git a/select.go b/select.go index 54c4960..33f7503 100644 --- a/select.go +++ b/select.go @@ -8,7 +8,6 @@ import ( // SelectBuilder builds SQL SELECT statements. type SelectBuilder struct { - placeholder placeholder ctes []cte prefixes []SQLizer options []string @@ -31,11 +30,7 @@ func (b SelectBuilder) SQL() (sqlStr string, args []any, err error) { return } - f := b.placeholder - if f == nil { - f = dollarPlaceholder - } - sqlStr, err = f(sqlStr) + sqlStr, err = dollarPlaceholder(sqlStr) return } @@ -219,7 +214,7 @@ func (b SelectBuilder) RemoveColumns() SelectBuilder { // Unlike Columns, Column accepts args which will be bound to placeholders in // the columns string, for example: // -// Column("IF(col IN ("+pgq.Placeholders(3)+"), 1, 0) as col", 1, 2, 3) +// Column("CASE WHEN col IN ("+pgq.Placeholders(3)+") THEN 1 ELSE 0 END as col", 1, 2, 3) func (b SelectBuilder) Column(column any, args ...any) SelectBuilder { b.columns = append(b.columns, newPart(column, args...)) return b @@ -233,9 +228,6 @@ func (b SelectBuilder) From(from string) SelectBuilder { // FromSelect sets a subquery into the FROM clause of the query. func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilder { - // Prevent misnumbered parameters in nested selects - // See https://github.com/Masterminds/squirrel/issues/183 - from.placeholder = questionPlaceholder b.from = Alias{ Expr: from, As: alias, diff --git a/select_test.go b/select_test.go index 8d76109..1b01e60 100644 --- a/select_test.go +++ b/select_test.go @@ -453,6 +453,29 @@ func TestRemoveColumns(t *testing.T) { } } +func TestSelectBuilderColumnAliasSubqueryParams(t *testing.T) { + t.Parallel() + subQ := Select("name").From("producers").Where("active = ?", true) + b := Select("id"). + Column(Alias{Expr: subQ, As: "producer_name"}). + From("films"). + Where("id = ?", 42) + sql, args, err := b.SQL() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + want := "SELECT id, (SELECT name FROM producers WHERE active = $1) AS producer_name FROM films WHERE id = $2" + if sql != want { + t.Errorf("expected SQL to be %q, got %q instead", want, sql) + } + + expectedArgs := []any{true, 42} + if !reflect.DeepEqual(args, expectedArgs) { + t.Errorf("wanted %v, got %v instead", expectedArgs, args) + } +} + func TestSelectBuilder_PrefixExpr_NestedUpdateDollar(t *testing.T) { t.Parallel() nestedBuilder := Update("foo").Prefix("WITH updated AS ("). diff --git a/statement.go b/statement.go index 4626306..548aa2c 100644 --- a/statement.go +++ b/statement.go @@ -53,16 +53,6 @@ func Insert(into string) InsertBuilder { return InsertBuilder{into: into} } -// Replace returns a new InsertBuilder with the statement keyword set to -// "REPLACE" and with the given table name. -// -// See InsertBuilder.Into. -func Replace(into string) InsertBuilder { - builder := InsertBuilder{} - builder.verb = "REPLACE" - return builder.Into(into) -} - // Update returns a new UpdateBuilder with the given table name. // // See UpdateBuilder.Table. diff --git a/update.go b/update.go index 0c1741e..f9f4f4b 100644 --- a/update.go +++ b/update.go @@ -71,11 +71,11 @@ func (b UpdateBuilder) unfinalizedSQL() (sqlStr string, args []any, err error) { for i, setClause := range b.setClauses { var valSQL string if vs, ok := setClause.value.(SQLizer); ok { - vsql, vargs, err := vs.SQL() + vsql, vargs, err := nestedSQL(vs) if err != nil { return "", nil, err } - if _, ok := vs.(SelectBuilder); ok { + if _, ok := vs.(rawSQLizer); ok { valSQL = fmt.Sprintf("(%s)", vsql) } else { valSQL = vsql @@ -238,7 +238,6 @@ func (b UpdateBuilder) Returning(columns ...string) UpdateBuilder { // ReturningSelect adds a RETURNING expressions to the query similar to Using, but takes a Select statement. func (b UpdateBuilder) ReturningSelect(from SelectBuilder, alias string) UpdateBuilder { - from.placeholder = questionPlaceholder b.returning = append(b.returning, Alias{Expr: from, As: alias}) return b } diff --git a/update_test.go b/update_test.go index 9e72fc9..1edca51 100644 --- a/update_test.go +++ b/update_test.go @@ -125,6 +125,18 @@ func TestUpdateBuilderSQL(t *testing.T) { "AS acc WHERE acc.name = $1 AND employees.id = acc.sales_person", wantArgs: []any{"Acme Corporation"}, }, + { + name: "from_select_params", + b: Update("employees").Set("sales_count", Expr("sales_count + 1")).FromSelect( + Select("name").From("accounts").Where("status = ?", "active"), "acc", + ). + Where("acc.name = ?", "Acme Corporation"). + Where("employees.id = acc.sales_person"), + wantSQL: "UPDATE employees SET sales_count = sales_count + 1 " + + "FROM (SELECT name FROM accounts WHERE status = $1) " + + "AS acc WHERE acc.name = $2 AND employees.id = acc.sales_person", + wantArgs: []any{"active", "Acme Corporation"}, + }, { name: "with_cte", b: Update("employees"). @@ -154,6 +166,25 @@ func TestUpdateBuilderSQL(t *testing.T) { "UPDATE employees SET is_manager = $1 WHERE id IN (SELECT id FROM mgrs)", wantArgs: []any{true}, }, + { + name: "update_set_select_params", + b: Update("films"). + Set("producer_id", Select("id").From("producers").Where("name = ?", "foo")). + Where("id = ?", 123), + wantSQL: "UPDATE films SET producer_id = (SELECT id FROM producers WHERE name = $1) WHERE id = $2", + wantArgs: []any{"foo", 123}, + }, + { + name: "update_set_union", + b: Update("t"). + Set("val", Union( + Select("v").From("a").Where("x = ?", 1), + Select("v").From("b").Where("x = ?", 2), + )). + Where("id = ?", 3), + wantSQL: "UPDATE t SET val = (SELECT v FROM a WHERE x = $1 UNION SELECT v FROM b WHERE x = $2) WHERE id = $3", + wantArgs: []any{1, 2, 3}, + }, } for _, tc := range testCases { diff --git a/where.go b/where.go index 1f96397..a1d8851 100644 --- a/where.go +++ b/where.go @@ -14,10 +14,8 @@ func (p wherePart) SQL() (sql string, args []any, err error) { switch pred := p.pred.(type) { case nil: // no-op - case rawSQLizer: - return pred.unfinalizedSQL() case SQLizer: - return pred.SQL() + return nestedSQL(pred) case map[string]any: return Eq(pred).SQL() case string: