Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/go-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ jobs:

- name: Test
run: go test -v ./...

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: emicklei/pgtalk
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# pgtalk

[![Go](https://github.com/emicklei/pgtalk/actions/workflows/go-test.yml/badge.svg)](https://github.com/emicklei/pgtalk/actions/workflows/go-test.yml)
[![Go Report Card](https://goreportcard.com/badge/github.com/emicklei/pgtalk)](https://goreportcard.com/report/github.com/emicklei/pgtalk)
[![GoDoc](https://pkg.go.dev/badge/github.com/emicklei/pgtalk)](https://pkg.go.dev/github.com/emicklei/pgtalk)
[![codecov](https://codecov.io/gh/emicklei/pgtalk/branch/master/graph/badge.svg)](https://codecov.io/gh/emicklei/pgtalk)

More type safe SQL query building and execution using Go code generated (pgtalk-gen) from PostgreSQL table definitions.
After code generation, you get a Go type for each table or view with functions to create a QuerySet or MutationSet value.
Expand Down
20 changes: 7 additions & 13 deletions iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import (
)

type resultIterator[T any] struct {
queryError error
commandTag pgconn.CommandTag
rows pgx.Rows
selectors []ColumnAccessor
params []any
queryError error
commandTag pgconn.CommandTag
rows pgx.Rows
orderedSelectors []ColumnAccessor
params []any
}

// Close closes the rows, making the connection ready for use again. It is safe
Expand Down Expand Up @@ -55,15 +55,9 @@ func (i *resultIterator[T]) HasNext() bool {

func (i *resultIterator[T]) Next() (*T, error) {
entity := new(T)
list := i.rows.FieldDescriptions()
// order of list is not the same as selectors?
toScan := []any{}
for _, each := range list {
for _, other := range i.selectors {
if other.Column().columnName == each.Name {
toScan = append(toScan, other.FieldValueToScan(entity))
}
}
for _, each := range i.orderedSelectors {
toScan = append(toScan, each.FieldValueToScan(entity))
}
if err := i.rows.Scan(toScan...); err != nil {
return nil, err
Expand Down
134 changes: 134 additions & 0 deletions iterator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package pgtalk

import (
"errors"
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

func TestResultIterator_Next(t *testing.T) {
t.Run("happy flow", func(t *testing.T) {
rows := &mockRows{
nextResult: true,
scanError: nil,
closeCalled: false,
}
selectors := []ColumnAccessor{
mockColumnAccessor{
fieldValueToScan: func(entity any) any {
return &entity.(*testEntity).ID
},
},
}
i := &resultIterator[testEntity]{
rows: rows,
orderedSelectors: selectors,
}
if !i.HasNext() {
t.Fatal("expected next")
}
entity, err := i.Next()
if err != nil {
t.Fatal(err)
}
if entity.ID != "test-id" {
t.Errorf("expected test-id, got %s", entity.ID)
}
if i.HasNext() {
t.Fatal("unexpected next")
}
if !rows.closeCalled {
t.Error("expected close to be called")
}
})
}

func TestResultIterator_Err(t *testing.T) {
t.Run("query error", func(t *testing.T) {
i := &resultIterator[testEntity]{
queryError: errors.New("query error"),
}
if err := i.Err(); err == nil || err.Error() != "query error" {
t.Errorf("expected query error, got %v", err)
}
})
t.Run("rows error", func(t *testing.T) {
rows := &mockRows{
err: errors.New("rows error"),
}
i := &resultIterator[testEntity]{
rows: rows,
}
if err := i.Err(); err == nil || err.Error() != "rows error" {
t.Errorf("expected rows error, got %v", err)
}
})
}

func TestResultIterator_GetParams(t *testing.T) {
params := []any{"param1", 2}
i := &resultIterator[testEntity]{
params: params,
}
p := i.GetParams()
if len(p) != 2 {
t.Errorf("expected 2 params, got %d", len(p))
}
if p[1] != "param1" {
t.Errorf("expected param1, got %v", p[1])
}
if p[2] != 2 {
t.Errorf("expected 2, got %v", p[2])
}
}

// mockColumnAccessor is a mock for the ColumnAccessor interface
type mockColumnAccessor struct {
fieldValueToScan func(entity any) any
}

func (m mockColumnAccessor) SQLOn(w WriteContext) {}
func (m mockColumnAccessor) Name() string { return "" }
func (m mockColumnAccessor) ValueToInsert() any { return nil }
func (m mockColumnAccessor) Column() ColumnInfo { return ColumnInfo{} }
func (m mockColumnAccessor) FieldValueToScan(entity any) any { return m.fieldValueToScan(entity) }
func (m mockColumnAccessor) AppendScannable(list []any) []any { return list }
func (m mockColumnAccessor) Get(values map[string]any) any { return nil }
func (m mockColumnAccessor) SetSource(parameterIndex int) string { return "" }

// mockRows is a mock for the pgx.Rows interface
type mockRows struct {
nextResult bool
scanError error
closeCalled bool
err error
}

func (m *mockRows) Close() { m.closeCalled = true }
func (m *mockRows) Err() error { return m.err }
func (m *mockRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }
func (m *mockRows) FieldDescriptions() []pgconn.FieldDescription { return nil }
func (m *mockRows) Next() bool { return m.nextResult }
func (m *mockRows) Scan(dest ...any) error {
if m.scanError != nil {
return m.scanError
}
// simulate scanning a value
if len(dest) > 0 {
if id, ok := dest[0].(*string); ok {
*id = "test-id"
}
}
m.nextResult = false // only one row
return nil
}
func (m *mockRows) RawValues() [][]byte { return nil }
func (m *mockRows) Conn() *pgx.Conn { return nil }
func (m *mockRows) Values() ([]any, error) { return nil, nil }

// testEntity is a simple struct for testing
type testEntity struct {
ID string
}
25 changes: 23 additions & 2 deletions mutationset.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,31 @@ func (m MutationSet[T]) Exec(ctx context.Context, conn querier, parameters ...*Q
return &resultIterator[T]{queryError: err, commandTag: ct, params: params}
}
rows, err := conn.Query(ctx, query, params...)
if err == nil && !m.canProduceResults() {
if err != nil {
return &resultIterator[T]{queryError: err}
}
if !m.canProduceResults() {
rows.Close()
}
return &resultIterator[T]{queryError: err, rows: rows, selectors: m.returning, params: params}
// order the selectors once
fds := rows.FieldDescriptions()
ordered := make([]ColumnAccessor, len(fds))

// create a map for faster lookup
selectorMap := make(map[string]ColumnAccessor, len(m.returning))
for _, sel := range m.returning {
selectorMap[sel.Column().columnName] = sel
}

for i, fd := range fds {
sel, ok := selectorMap[fd.Name]
if !ok {
// this should not happen
return &resultIterator[T]{queryError: fmt.Errorf("selector not found for column %s", fd.Name)}
}
ordered[i] = sel
}
return &resultIterator[T]{rows: rows, orderedSelectors: ordered, params: params}
}

// valuesToInsert returns the parameters values for the mutation query.
Expand Down
29 changes: 25 additions & 4 deletions queryset.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,32 @@ func (q QuerySet[T]) Exists() unaryExpression {
func (d QuerySet[T]) Iterate(ctx context.Context, conn querier, parameters ...*QueryParameter) (*resultIterator[T], error) {
params := argumentValues(parameters)
rows, err := conn.Query(ctx, SQL(d), params...)
if err != nil {
return nil, err
}
// order the selectors once
fds := rows.FieldDescriptions()
ordered := make([]ColumnAccessor, len(fds))

// create a map for faster lookup
selectorMap := make(map[string]ColumnAccessor, len(d.selectors))
for _, sel := range d.selectors {
selectorMap[sel.Column().columnName] = sel
}

for i, fd := range fds {
sel, ok := selectorMap[fd.Name]
if !ok {
return nil, fmt.Errorf("no selector found for field '%s'", fd.Name)
}
ordered[i] = sel
}

return &resultIterator[T]{
queryError: err,
rows: rows,
selectors: d.selectors,
params: params,
queryError: err,
rows: rows,
orderedSelectors: ordered,
params: params,
}, err
}

Expand Down
32 changes: 30 additions & 2 deletions sql_writing.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,32 @@ func (w wc) TableAlias(tableName, defaultAlias string) string {
return defaultAlias
}

// onelineWriter is a writer that replaces newlines and tabs with spaces,
// and collapses multiple spaces into a single space.
type onelineWriter struct {
b *strings.Builder
lastCharWasSpace bool
}

func newOnelineWriter(b *strings.Builder) *onelineWriter {
return &onelineWriter{b: b, lastCharWasSpace: true}
}

func (w *onelineWriter) Write(p []byte) (n int, err error) {
for _, b := range p {
if b == '\n' || b == '\t' || b == ' ' {
if !w.lastCharWasSpace {
w.b.WriteByte(' ')
w.lastCharWasSpace = true
}
} else {
w.b.WriteByte(b)
w.lastCharWasSpace = false
}
}
return len(p), nil
}

// IndentedSQL returns source with tabs and lines trying to have a formatted view.
func IndentedSQL(some SQLWriter) string {
buf := new(bytes.Buffer)
Expand All @@ -58,6 +84,8 @@ func IndentedSQL(some SQLWriter) string {

// SQL returns source as a oneliner without tabs or line ends.
func SQL(some SQLWriter) string {
src := IndentedSQL(some)
return strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(src, "\t", " "), "\n", " "), " ", " ")
var b strings.Builder
w := newOnelineWriter(&b)
some.SQLOn(NewWriteContext(w))
Comment thread
emicklei marked this conversation as resolved.
return strings.TrimRight(b.String(), " ")
}