diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index ea97f05..e912800 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -17,3 +17,9 @@ jobs: - name: Test run: go test -v ./... + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: emicklei/pgtalk diff --git a/README.md b/README.md index 56aa81d..2e67570 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ # pgtalk +[![Go](https://github.com/emicklei/pgtalk/actions/workflows/go-test.yml/badge.svg)](https://github.com/emicklei/pgtalk/actions/workflows/go-test.yml) +[![Go Report Card](https://goreportcard.com/badge/github.com/emicklei/pgtalk)](https://goreportcard.com/report/github.com/emicklei/pgtalk) [![GoDoc](https://pkg.go.dev/badge/github.com/emicklei/pgtalk)](https://pkg.go.dev/github.com/emicklei/pgtalk) +[![codecov](https://codecov.io/gh/emicklei/pgtalk/branch/master/graph/badge.svg)](https://codecov.io/gh/emicklei/pgtalk) More type safe SQL query building and execution using Go code generated (pgtalk-gen) from PostgreSQL table definitions. After code generation, you get a Go type for each table or view with functions to create a QuerySet or MutationSet value. diff --git a/iterator.go b/iterator.go index 4288027..fc58e1a 100644 --- a/iterator.go +++ b/iterator.go @@ -6,11 +6,11 @@ import ( ) type resultIterator[T any] struct { - queryError error - commandTag pgconn.CommandTag - rows pgx.Rows - selectors []ColumnAccessor - params []any + queryError error + commandTag pgconn.CommandTag + rows pgx.Rows + orderedSelectors []ColumnAccessor + params []any } // Close closes the rows, making the connection ready for use again. It is safe @@ -55,15 +55,9 @@ func (i *resultIterator[T]) HasNext() bool { func (i *resultIterator[T]) Next() (*T, error) { entity := new(T) - list := i.rows.FieldDescriptions() - // order of list is not the same as selectors? toScan := []any{} - for _, each := range list { - for _, other := range i.selectors { - if other.Column().columnName == each.Name { - toScan = append(toScan, other.FieldValueToScan(entity)) - } - } + for _, each := range i.orderedSelectors { + toScan = append(toScan, each.FieldValueToScan(entity)) } if err := i.rows.Scan(toScan...); err != nil { return nil, err diff --git a/iterator_test.go b/iterator_test.go new file mode 100644 index 0000000..e6c4ed4 --- /dev/null +++ b/iterator_test.go @@ -0,0 +1,134 @@ +package pgtalk + +import ( + "errors" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +func TestResultIterator_Next(t *testing.T) { + t.Run("happy flow", func(t *testing.T) { + rows := &mockRows{ + nextResult: true, + scanError: nil, + closeCalled: false, + } + selectors := []ColumnAccessor{ + mockColumnAccessor{ + fieldValueToScan: func(entity any) any { + return &entity.(*testEntity).ID + }, + }, + } + i := &resultIterator[testEntity]{ + rows: rows, + orderedSelectors: selectors, + } + if !i.HasNext() { + t.Fatal("expected next") + } + entity, err := i.Next() + if err != nil { + t.Fatal(err) + } + if entity.ID != "test-id" { + t.Errorf("expected test-id, got %s", entity.ID) + } + if i.HasNext() { + t.Fatal("unexpected next") + } + if !rows.closeCalled { + t.Error("expected close to be called") + } + }) +} + +func TestResultIterator_Err(t *testing.T) { + t.Run("query error", func(t *testing.T) { + i := &resultIterator[testEntity]{ + queryError: errors.New("query error"), + } + if err := i.Err(); err == nil || err.Error() != "query error" { + t.Errorf("expected query error, got %v", err) + } + }) + t.Run("rows error", func(t *testing.T) { + rows := &mockRows{ + err: errors.New("rows error"), + } + i := &resultIterator[testEntity]{ + rows: rows, + } + if err := i.Err(); err == nil || err.Error() != "rows error" { + t.Errorf("expected rows error, got %v", err) + } + }) +} + +func TestResultIterator_GetParams(t *testing.T) { + params := []any{"param1", 2} + i := &resultIterator[testEntity]{ + params: params, + } + p := i.GetParams() + if len(p) != 2 { + t.Errorf("expected 2 params, got %d", len(p)) + } + if p[1] != "param1" { + t.Errorf("expected param1, got %v", p[1]) + } + if p[2] != 2 { + t.Errorf("expected 2, got %v", p[2]) + } +} + +// mockColumnAccessor is a mock for the ColumnAccessor interface +type mockColumnAccessor struct { + fieldValueToScan func(entity any) any +} + +func (m mockColumnAccessor) SQLOn(w WriteContext) {} +func (m mockColumnAccessor) Name() string { return "" } +func (m mockColumnAccessor) ValueToInsert() any { return nil } +func (m mockColumnAccessor) Column() ColumnInfo { return ColumnInfo{} } +func (m mockColumnAccessor) FieldValueToScan(entity any) any { return m.fieldValueToScan(entity) } +func (m mockColumnAccessor) AppendScannable(list []any) []any { return list } +func (m mockColumnAccessor) Get(values map[string]any) any { return nil } +func (m mockColumnAccessor) SetSource(parameterIndex int) string { return "" } + +// mockRows is a mock for the pgx.Rows interface +type mockRows struct { + nextResult bool + scanError error + closeCalled bool + err error +} + +func (m *mockRows) Close() { m.closeCalled = true } +func (m *mockRows) Err() error { return m.err } +func (m *mockRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } +func (m *mockRows) FieldDescriptions() []pgconn.FieldDescription { return nil } +func (m *mockRows) Next() bool { return m.nextResult } +func (m *mockRows) Scan(dest ...any) error { + if m.scanError != nil { + return m.scanError + } + // simulate scanning a value + if len(dest) > 0 { + if id, ok := dest[0].(*string); ok { + *id = "test-id" + } + } + m.nextResult = false // only one row + return nil +} +func (m *mockRows) RawValues() [][]byte { return nil } +func (m *mockRows) Conn() *pgx.Conn { return nil } +func (m *mockRows) Values() ([]any, error) { return nil, nil } + +// testEntity is a simple struct for testing +type testEntity struct { + ID string +} diff --git a/mutationset.go b/mutationset.go index 967374b..4aba9e1 100644 --- a/mutationset.go +++ b/mutationset.go @@ -103,10 +103,31 @@ func (m MutationSet[T]) Exec(ctx context.Context, conn querier, parameters ...*Q return &resultIterator[T]{queryError: err, commandTag: ct, params: params} } rows, err := conn.Query(ctx, query, params...) - if err == nil && !m.canProduceResults() { + if err != nil { + return &resultIterator[T]{queryError: err} + } + if !m.canProduceResults() { rows.Close() } - return &resultIterator[T]{queryError: err, rows: rows, selectors: m.returning, params: params} + // order the selectors once + fds := rows.FieldDescriptions() + ordered := make([]ColumnAccessor, len(fds)) + + // create a map for faster lookup + selectorMap := make(map[string]ColumnAccessor, len(m.returning)) + for _, sel := range m.returning { + selectorMap[sel.Column().columnName] = sel + } + + for i, fd := range fds { + sel, ok := selectorMap[fd.Name] + if !ok { + // this should not happen + return &resultIterator[T]{queryError: fmt.Errorf("selector not found for column %s", fd.Name)} + } + ordered[i] = sel + } + return &resultIterator[T]{rows: rows, orderedSelectors: ordered, params: params} } // valuesToInsert returns the parameters values for the mutation query. diff --git a/queryset.go b/queryset.go index 85b89f3..2f4e7b4 100644 --- a/queryset.go +++ b/queryset.go @@ -131,11 +131,32 @@ func (q QuerySet[T]) Exists() unaryExpression { func (d QuerySet[T]) Iterate(ctx context.Context, conn querier, parameters ...*QueryParameter) (*resultIterator[T], error) { params := argumentValues(parameters) rows, err := conn.Query(ctx, SQL(d), params...) + if err != nil { + return nil, err + } + // order the selectors once + fds := rows.FieldDescriptions() + ordered := make([]ColumnAccessor, len(fds)) + + // create a map for faster lookup + selectorMap := make(map[string]ColumnAccessor, len(d.selectors)) + for _, sel := range d.selectors { + selectorMap[sel.Column().columnName] = sel + } + + for i, fd := range fds { + sel, ok := selectorMap[fd.Name] + if !ok { + return nil, fmt.Errorf("no selector found for field '%s'", fd.Name) + } + ordered[i] = sel + } + return &resultIterator[T]{ - queryError: err, - rows: rows, - selectors: d.selectors, - params: params, + queryError: err, + rows: rows, + orderedSelectors: ordered, + params: params, }, err } diff --git a/sql_writing.go b/sql_writing.go index 7aba5e9..0d7dd98 100644 --- a/sql_writing.go +++ b/sql_writing.go @@ -49,6 +49,32 @@ func (w wc) TableAlias(tableName, defaultAlias string) string { return defaultAlias } +// onelineWriter is a writer that replaces newlines and tabs with spaces, +// and collapses multiple spaces into a single space. +type onelineWriter struct { + b *strings.Builder + lastCharWasSpace bool +} + +func newOnelineWriter(b *strings.Builder) *onelineWriter { + return &onelineWriter{b: b, lastCharWasSpace: true} +} + +func (w *onelineWriter) Write(p []byte) (n int, err error) { + for _, b := range p { + if b == '\n' || b == '\t' || b == ' ' { + if !w.lastCharWasSpace { + w.b.WriteByte(' ') + w.lastCharWasSpace = true + } + } else { + w.b.WriteByte(b) + w.lastCharWasSpace = false + } + } + return len(p), nil +} + // IndentedSQL returns source with tabs and lines trying to have a formatted view. func IndentedSQL(some SQLWriter) string { buf := new(bytes.Buffer) @@ -58,6 +84,8 @@ func IndentedSQL(some SQLWriter) string { // SQL returns source as a oneliner without tabs or line ends. func SQL(some SQLWriter) string { - src := IndentedSQL(some) - return strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(src, "\t", " "), "\n", " "), " ", " ") + var b strings.Builder + w := newOnelineWriter(&b) + some.SQLOn(NewWriteContext(w)) + return strings.TrimRight(b.String(), " ") }