From 6b483a3b50206b9cf5519963b7505146469f0246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 00:12:21 +0200 Subject: [PATCH 1/8] feat(contract): add pagination utilities Add offset-based and cursor-based pagination support: - Page[T] struct with NewPage constructor (computes LastPage, clamps inputs) - Cursor[T] struct with NewCursor (built-in base64-JSON encoding) and NewCursorWith (custom CursorEncoder interface) - CursorValue/CursorValueWith for decoding cursor strings - Request helpers: Pagination, PaginationWith, CursorPagination, CursorPaginationWith (query param parsing with defaults and clamping) - CursorEncoder mock for testing 100% test coverage on all new code. --- contract/mock/pagination.go | 146 +++++++++++++ contract/pagination.go | 219 ++++++++++++++++++++ contract/pagination_test.go | 305 ++++++++++++++++++++++++++++ contract/request/pagination.go | 61 ++++++ contract/request/pagination_test.go | 195 ++++++++++++++++++ 5 files changed, 926 insertions(+) create mode 100644 contract/mock/pagination.go create mode 100644 contract/pagination.go create mode 100644 contract/pagination_test.go create mode 100644 contract/request/pagination.go create mode 100644 contract/request/pagination_test.go diff --git a/contract/mock/pagination.go b/contract/mock/pagination.go new file mode 100644 index 0000000..6e1c8a2 --- /dev/null +++ b/contract/mock/pagination.go @@ -0,0 +1,146 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mock + +import ( + mock "github.com/stretchr/testify/mock" +) + +// NewCursorEncoderMock creates a new instance of CursorEncoderMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCursorEncoderMock(t interface { + mock.TestingT + Cleanup(func()) +}) *CursorEncoderMock { + mock := &CursorEncoderMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// CursorEncoderMock is an autogenerated mock type for the CursorEncoder type +type CursorEncoderMock struct { + mock.Mock +} + +type CursorEncoderMock_Expecter struct { + mock *mock.Mock +} + +func (_m *CursorEncoderMock) EXPECT() *CursorEncoderMock_Expecter { + return &CursorEncoderMock_Expecter{mock: &_m.Mock} +} + +// Encode provides a mock function for the type CursorEncoderMock +func (_mock *CursorEncoderMock) Encode(value any) (string, error) { + ret := _mock.Called(value) + + if len(ret) == 0 { + panic("no return value specified for Encode") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(any) (string, error)); ok { + return returnFunc(value) + } + if returnFunc, ok := ret.Get(0).(func(any) string); ok { + r0 = returnFunc(value) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(any) error); ok { + r1 = returnFunc(value) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// CursorEncoderMock_Encode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Encode' +type CursorEncoderMock_Encode_Call struct { + *mock.Call +} + +// Encode is a helper method to define mock.On call +// - value any +func (_e *CursorEncoderMock_Expecter) Encode(value interface{}) *CursorEncoderMock_Encode_Call { + return &CursorEncoderMock_Encode_Call{Call: _e.mock.On("Encode", value)} +} + +func (_c *CursorEncoderMock_Encode_Call) Run(run func(value any)) *CursorEncoderMock_Encode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0]) + }) + return _c +} + +func (_c *CursorEncoderMock_Encode_Call) Return(_a0 string, _a1 error) *CursorEncoderMock_Encode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CursorEncoderMock_Encode_Call) RunAndReturn(run func(any) (string, error)) *CursorEncoderMock_Encode_Call { + _c.Call.Return(run) + return _c +} + +// Decode provides a mock function for the type CursorEncoderMock +func (_mock *CursorEncoderMock) Decode(cursor string) (any, error) { + ret := _mock.Called(cursor) + + if len(ret) == 0 { + panic("no return value specified for Decode") + } + + var r0 any + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (any, error)); ok { + return returnFunc(cursor) + } + if returnFunc, ok := ret.Get(0).(func(string) any); ok { + r0 = returnFunc(cursor) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0) + } + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(cursor) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// CursorEncoderMock_Decode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Decode' +type CursorEncoderMock_Decode_Call struct { + *mock.Call +} + +// Decode is a helper method to define mock.On call +// - cursor string +func (_e *CursorEncoderMock_Expecter) Decode(cursor interface{}) *CursorEncoderMock_Decode_Call { + return &CursorEncoderMock_Decode_Call{Call: _e.mock.On("Decode", cursor)} +} + +func (_c *CursorEncoderMock_Decode_Call) Run(run func(cursor string)) *CursorEncoderMock_Decode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *CursorEncoderMock_Decode_Call) Return(_a0 any, _a1 error) *CursorEncoderMock_Decode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CursorEncoderMock_Decode_Call) RunAndReturn(run func(string) (any, error)) *CursorEncoderMock_Decode_Call { + _c.Call.Return(run) + return _c +} diff --git a/contract/pagination.go b/contract/pagination.go new file mode 100644 index 0000000..c1b9c81 --- /dev/null +++ b/contract/pagination.go @@ -0,0 +1,219 @@ +package contract + +import ( + "encoding/base64" + "encoding/json" + "errors" +) + +// ErrCursorEncode is returned when a cursor value fails to encode. +var ErrCursorEncode = errors.New("failed to encode cursor") + +// ErrCursorDecode is returned when a cursor string fails to decode. +var ErrCursorDecode = errors.New("failed to decode cursor") + +// CursorEncoder defines the encoding and decoding of cursor values +// into opaque cursor strings. Implementations control how cursor +// values are serialized for transport and deserialized on receipt. +type CursorEncoder interface { + // Encode converts a cursor value into an opaque cursor string. + Encode(value any) (string, error) + + // Decode converts an opaque cursor string back into a cursor value. + Decode(cursor string) (any, error) +} + +// Page represents an offset-based paginated result set. +type Page[T any] struct { + Items []T `json:"items"` + Total int64 `json:"total"` + PerPage int `json:"per_page"` + CurrentPage int `json:"current_page"` + LastPage int `json:"last_page"` +} + +// Cursor represents a cursor-based paginated result set. +type Cursor[T any] struct { + Items []T `json:"items"` + PerPage int `json:"per_page"` + NextCursor string `json:"next_cursor,omitempty"` + PrevCursor string `json:"prev_cursor,omitempty"` +} + +// NewPage creates a new [Page] from the given items, total count, +// current page number, and items per page. It computes the last +// page automatically. The current page is clamped to [1, LastPage]. +func NewPage[T any](items []T, total int64, page, perPage int) Page[T] { + if perPage < 1 { + perPage = 1 + } + + lastPage := int((total + int64(perPage) - 1) / int64(perPage)) + + if lastPage < 1 { + lastPage = 1 + } + + if page < 1 { + page = 1 + } + + if page > lastPage { + page = lastPage + } + + if items == nil { + items = []T{} + } + + return Page[T]{ + Items: items, + Total: total, + PerPage: perPage, + CurrentPage: page, + LastPage: lastPage, + } +} + +// NewCursor creates a new [Cursor] from the given items using a +// built-in base64-JSON encoding for cursor values. The extract +// function determines which value from each item becomes the cursor. +// When hasNext is true, the last item's extracted value becomes the +// next cursor. When hasPrev is true, the first item's extracted +// value becomes the previous cursor. +func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, extract func(T) any) (Cursor[T], error) { + if items == nil { + items = []T{} + } + + result := Cursor[T]{ + Items: items, + PerPage: perPage, + } + + if len(items) == 0 { + return result, nil + } + + if hasNext { + encoded, err := base64JSONEncode(extract(items[len(items)-1])) + + if err != nil { + return result, errors.Join(ErrCursorEncode, err) + } + + result.NextCursor = encoded + } + + if hasPrev { + encoded, err := base64JSONEncode(extract(items[0])) + + if err != nil { + return result, errors.Join(ErrCursorEncode, err) + } + + result.PrevCursor = encoded + } + + return result, nil +} + +// NewCursorWith creates a new [Cursor] from the given items using +// a custom [CursorEncoder]. When hasNext is true, the last item is +// passed to the encoder to produce the next cursor. When hasPrev is +// true, the first item is passed to produce the previous cursor. +func NewCursorWith[T any](items []T, perPage int, hasNext, hasPrev bool, encoder CursorEncoder) (Cursor[T], error) { + if items == nil { + items = []T{} + } + + result := Cursor[T]{ + Items: items, + PerPage: perPage, + } + + if len(items) == 0 { + return result, nil + } + + if hasNext { + encoded, err := encoder.Encode(items[len(items)-1]) + + if err != nil { + return result, errors.Join(ErrCursorEncode, err) + } + + result.NextCursor = encoded + } + + if hasPrev { + encoded, err := encoder.Encode(items[0]) + + if err != nil { + return result, errors.Join(ErrCursorEncode, err) + } + + result.PrevCursor = encoded + } + + return result, nil +} + +// CursorValue decodes a cursor string produced by [NewCursor] back +// into the original cursor value using the built-in base64-JSON encoding. +func CursorValue[T any](cursor string) (T, error) { + var zero T + + raw, err := base64JSONDecode(cursor) + + if err != nil { + return zero, errors.Join(ErrCursorDecode, err) + } + + value, ok := raw.(T) + + if !ok { + return zero, ErrCursorDecode + } + + return value, nil +} + +// CursorValueWith decodes a cursor string using a custom [CursorEncoder]. +func CursorValueWith(cursor string, encoder CursorEncoder) (any, error) { + value, err := encoder.Decode(cursor) + + if err != nil { + return nil, errors.Join(ErrCursorDecode, err) + } + + return value, nil +} + +// base64JSONEncode encodes a value as JSON then base64url. +func base64JSONEncode(value any) (string, error) { + data, err := json.Marshal(value) + + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(data), nil +} + +// base64JSONDecode decodes a base64url string then unmarshals as JSON. +func base64JSONDecode(cursor string) (any, error) { + data, err := base64.RawURLEncoding.DecodeString(cursor) + + if err != nil { + return nil, err + } + + var value any + + if err := json.Unmarshal(data, &value); err != nil { + return nil, err + } + + return value, nil +} diff --git a/contract/pagination_test.go b/contract/pagination_test.go new file mode 100644 index 0000000..9d57bca --- /dev/null +++ b/contract/pagination_test.go @@ -0,0 +1,305 @@ +package contract_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/studiolambda/cosmos/contract" +) + +func TestNewPageComputesLastPage(t *testing.T) { + t.Parallel() + + page := contract.NewPage([]string{"a", "b"}, 10, 1, 5) + + require.Equal(t, 2, page.LastPage) +} + +func TestNewPageComputesLastPageWithRemainder(t *testing.T) { + t.Parallel() + + page := contract.NewPage([]string{"a", "b"}, 11, 1, 5) + + require.Equal(t, 3, page.LastPage) +} + +func TestNewPageClampsPageBelowOne(t *testing.T) { + t.Parallel() + + page := contract.NewPage([]string{"a"}, 10, 0, 5) + + require.Equal(t, 1, page.CurrentPage) +} + +func TestNewPageClampsPageAboveLastPage(t *testing.T) { + t.Parallel() + + page := contract.NewPage([]string{}, 10, 99, 5) + + require.Equal(t, 2, page.CurrentPage) +} + +func TestNewPageClampsPerPageBelowOne(t *testing.T) { + t.Parallel() + + page := contract.NewPage([]string{"a"}, 5, 1, 0) + + require.Equal(t, 1, page.PerPage) +} + +func TestNewPageZeroTotalSetsLastPageOne(t *testing.T) { + t.Parallel() + + page := contract.NewPage([]string{}, 0, 1, 10) + + require.Equal(t, 1, page.LastPage) + require.Equal(t, 1, page.CurrentPage) +} + +func TestNewPageNilItemsBecomesEmptySlice(t *testing.T) { + t.Parallel() + + page := contract.NewPage[string](nil, 0, 1, 10) + + require.NotNil(t, page.Items) + require.Empty(t, page.Items) +} + +func TestNewPagePreservesItems(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + page := contract.NewPage(items, 100, 3, 10) + + require.Equal(t, items, page.Items) + require.Equal(t, int64(100), page.Total) + require.Equal(t, 3, page.CurrentPage) + require.Equal(t, 10, page.PerPage) +} + +func TestNewCursorEncodesNextCursor(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.NewCursor(items, 3, true, false, func(item int) any { return item }) + + require.NoError(t, err) + require.NotEmpty(t, cursor.NextCursor) + require.Empty(t, cursor.PrevCursor) +} + +func TestNewCursorEncodesPrevCursor(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.NewCursor(items, 3, false, true, func(item int) any { return item }) + + require.NoError(t, err) + require.Empty(t, cursor.NextCursor) + require.NotEmpty(t, cursor.PrevCursor) +} + +func TestNewCursorEncodesBothCursors(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.NewCursor(items, 3, true, true, func(item int) any { return item }) + + require.NoError(t, err) + require.NotEmpty(t, cursor.NextCursor) + require.NotEmpty(t, cursor.PrevCursor) +} + +func TestNewCursorEmptyItemsNoCursors(t *testing.T) { + t.Parallel() + + cursor, err := contract.NewCursor([]int{}, 10, true, true, func(item int) any { return item }) + + require.NoError(t, err) + require.Empty(t, cursor.NextCursor) + require.Empty(t, cursor.PrevCursor) +} + +func TestNewCursorNilItemsBecomesEmptySlice(t *testing.T) { + t.Parallel() + + cursor, err := contract.NewCursor[int](nil, 10, false, false, func(item int) any { return item }) + + require.NoError(t, err) + require.NotNil(t, cursor.Items) + require.Empty(t, cursor.Items) +} + +func TestNewCursorPreservesPerPage(t *testing.T) { + t.Parallel() + + cursor, err := contract.NewCursor([]int{1}, 25, false, false, func(item int) any { return item }) + + require.NoError(t, err) + require.Equal(t, 25, cursor.PerPage) +} + +func TestNewCursorEncodeErrorReturnsErrCursorEncode(t *testing.T) { + t.Parallel() + + items := []int{1} + + // Channels cannot be JSON-encoded. + _, err := contract.NewCursor(items, 10, true, false, func(item int) any { + return make(chan int) + }) + + require.ErrorIs(t, err, contract.ErrCursorEncode) +} + +func TestNewCursorPrevEncodeErrorReturnsErrCursorEncode(t *testing.T) { + t.Parallel() + + items := []int{1} + + _, err := contract.NewCursor(items, 10, false, true, func(item int) any { + return make(chan int) + }) + + require.ErrorIs(t, err, contract.ErrCursorEncode) +} + +func TestCursorValueRoundTrip(t *testing.T) { + t.Parallel() + + items := []int{10, 20, 30} + cursor, err := contract.NewCursor(items, 3, true, false, func(item int) any { return item }) + + require.NoError(t, err) + + value, err := contract.CursorValue[float64](cursor.NextCursor) + + require.NoError(t, err) + require.Equal(t, float64(30), value) +} + +func TestCursorValueInvalidBase64ReturnsErrCursorDecode(t *testing.T) { + t.Parallel() + + _, err := contract.CursorValue[int]("not-valid-base64!!!") + + require.ErrorIs(t, err, contract.ErrCursorDecode) +} + +func TestCursorValueTypeMismatchReturnsErrCursorDecode(t *testing.T) { + t.Parallel() + + items := []string{"hello"} + cursor, err := contract.NewCursor(items, 1, true, false, func(item string) any { return item }) + + require.NoError(t, err) + + _, err = contract.CursorValue[int](cursor.NextCursor) + + require.ErrorIs(t, err, contract.ErrCursorDecode) +} + +type failEncoder struct{} + +func (failEncoder) Encode(value any) (string, error) { + return "", errors.New("encode failed") +} + +func (failEncoder) Decode(cursor string) (any, error) { + return nil, errors.New("decode failed") +} + +type idEncoder struct{} + +func (idEncoder) Encode(value any) (string, error) { + return "custom-cursor", nil +} + +func (idEncoder) Decode(cursor string) (any, error) { + return cursor, nil +} + +func TestNewCursorWithCustomEncoder(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.NewCursorWith(items, 3, true, false, idEncoder{}) + + require.NoError(t, err) + require.Equal(t, "custom-cursor", cursor.NextCursor) +} + +func TestNewCursorWithEncoderErrorReturnsErrCursorEncode(t *testing.T) { + t.Parallel() + + items := []int{1} + _, err := contract.NewCursorWith(items, 1, true, false, failEncoder{}) + + require.ErrorIs(t, err, contract.ErrCursorEncode) +} + +func TestNewCursorWithPrevEncoderErrorReturnsErrCursorEncode(t *testing.T) { + t.Parallel() + + items := []int{1} + _, err := contract.NewCursorWith(items, 1, false, true, failEncoder{}) + + require.ErrorIs(t, err, contract.ErrCursorEncode) +} + +func TestNewCursorWithEmptyItems(t *testing.T) { + t.Parallel() + + cursor, err := contract.NewCursorWith([]int{}, 10, true, true, idEncoder{}) + + require.NoError(t, err) + require.Empty(t, cursor.NextCursor) + require.Empty(t, cursor.PrevCursor) +} + +func TestNewCursorWithNilItemsBecomesEmptySlice(t *testing.T) { + t.Parallel() + + cursor, err := contract.NewCursorWith[int](nil, 10, false, false, idEncoder{}) + + require.NoError(t, err) + require.NotNil(t, cursor.Items) +} + +func TestNewCursorWithEncodesBothCursors(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.NewCursorWith(items, 3, true, true, idEncoder{}) + + require.NoError(t, err) + require.Equal(t, "custom-cursor", cursor.NextCursor) + require.Equal(t, "custom-cursor", cursor.PrevCursor) +} + +func TestCursorValueInvalidJSONReturnsErrCursorDecode(t *testing.T) { + t.Parallel() + + // Valid base64 but invalid JSON. + _, err := contract.CursorValue[int]("bm90LWpzb24") + + require.ErrorIs(t, err, contract.ErrCursorDecode) +} + +func TestCursorValueWithCustomDecoder(t *testing.T) { + t.Parallel() + + value, err := contract.CursorValueWith("test-cursor", idEncoder{}) + + require.NoError(t, err) + require.Equal(t, "test-cursor", value) +} + +func TestCursorValueWithDecoderErrorReturnsErrCursorDecode(t *testing.T) { + t.Parallel() + + _, err := contract.CursorValueWith("anything", failEncoder{}) + + require.ErrorIs(t, err, contract.ErrCursorDecode) +} diff --git a/contract/request/pagination.go b/contract/request/pagination.go new file mode 100644 index 0000000..7e29ba0 --- /dev/null +++ b/contract/request/pagination.go @@ -0,0 +1,61 @@ +package request + +import "net/http" + +// Pagination extracts the page number and per-page count from the +// request query parameters "page" and "per_page". It applies sensible +// defaults: page 1, 25 items per page, and a maximum of 100 items +// per page. Use [PaginationWith] for custom defaults and limits. +func Pagination(r *http.Request) (page, perPage int) { + return PaginationWith(r, 1, 25, 100) +} + +// PaginationWith extracts the page number and per-page count from the +// request query parameters "page" and "per_page" using the provided +// defaults and maximum per-page limit. The page is floored at 1 and +// the per-page is clamped between 1 and maxPerPage. +func PaginationWith(r *http.Request, defaultPage, defaultPerPage, maxPerPage int) (page, perPage int) { + page = QueryIntOr(r, "page", defaultPage) + perPage = QueryIntOr(r, "per_page", defaultPerPage) + + if page < 1 { + page = 1 + } + + if perPage < 1 { + perPage = 1 + } + + if perPage > maxPerPage { + perPage = maxPerPage + } + + return page, perPage +} + +// CursorPagination extracts the cursor string and per-page count from +// the request query parameters "cursor" and "per_page". It applies +// sensible defaults: 25 items per page and a maximum of 100 items per +// page. Use [CursorPaginationWith] for custom defaults and limits. +func CursorPagination(r *http.Request) (cursor string, perPage int) { + return CursorPaginationWith(r, 25, 100) +} + +// CursorPaginationWith extracts the cursor string and per-page count +// from the request query parameters "cursor" and "per_page" using the +// provided defaults and maximum per-page limit. The per-page is clamped +// between 1 and maxPerPage. +func CursorPaginationWith(r *http.Request, defaultPerPage, maxPerPage int) (cursor string, perPage int) { + cursor = Query(r, "cursor") + perPage = QueryIntOr(r, "per_page", defaultPerPage) + + if perPage < 1 { + perPage = 1 + } + + if perPage > maxPerPage { + perPage = maxPerPage + } + + return cursor, perPage +} diff --git a/contract/request/pagination_test.go b/contract/request/pagination_test.go new file mode 100644 index 0000000..6cb6e38 --- /dev/null +++ b/contract/request/pagination_test.go @@ -0,0 +1,195 @@ +package request_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/studiolambda/cosmos/contract/request" +) + +func TestPaginationReturnsDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + page, perPage := request.Pagination(r) + + require.Equal(t, 1, page) + require.Equal(t, 25, perPage) +} + +func TestPaginationParsesQueryParams(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=3&per_page=50", nil) + + page, perPage := request.Pagination(r) + + require.Equal(t, 3, page) + require.Equal(t, 50, perPage) +} + +func TestPaginationClampsPerPageToMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=999", nil) + + _, perPage := request.Pagination(r) + + require.Equal(t, 100, perPage) +} + +func TestPaginationClampsPageBelowOne(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=0", nil) + + page, _ := request.Pagination(r) + + require.Equal(t, 1, page) +} + +func TestPaginationClampsNegativePage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=-5", nil) + + page, _ := request.Pagination(r) + + require.Equal(t, 1, page) +} + +func TestPaginationClampsPerPageBelowOne(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=0", nil) + + _, perPage := request.Pagination(r) + + require.Equal(t, 1, perPage) +} + +func TestPaginationWithCustomDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + page, perPage := request.PaginationWith(r, 2, 10, 50) + + require.Equal(t, 2, page) + require.Equal(t, 10, perPage) +} + +func TestPaginationWithCustomMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=100", nil) + + _, perPage := request.PaginationWith(r, 1, 10, 50) + + require.Equal(t, 50, perPage) +} + +func TestPaginationIgnoresInvalidPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=abc", nil) + + page, _ := request.Pagination(r) + + require.Equal(t, 1, page) +} + +func TestPaginationIgnoresInvalidPerPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=abc", nil) + + _, perPage := request.Pagination(r) + + require.Equal(t, 25, perPage) +} + +func TestCursorPaginationReturnsDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + cursor, perPage := request.CursorPagination(r) + + require.Empty(t, cursor) + require.Equal(t, 25, perPage) +} + +func TestCursorPaginationParsesCursor(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?cursor=abc123&per_page=50", nil) + + cursor, perPage := request.CursorPagination(r) + + require.Equal(t, "abc123", cursor) + require.Equal(t, 50, perPage) +} + +func TestCursorPaginationClampsPerPageToMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=999", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 100, perPage) +} + +func TestCursorPaginationClampsPerPageBelowOne(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=0", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 1, perPage) +} + +func TestCursorPaginationWithCustomDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + _, perPage := request.CursorPaginationWith(r, 10, 50) + + require.Equal(t, 10, perPage) +} + +func TestCursorPaginationWithCustomMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=100", nil) + + _, perPage := request.CursorPaginationWith(r, 10, 50) + + require.Equal(t, 50, perPage) +} + +func TestCursorPaginationIgnoresInvalidPerPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=abc", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 25, perPage) +} + +func TestCursorPaginationNegativePerPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=-5", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 1, perPage) +} From e7f45e091725bf50d256ecbfd11d1dfbbd4b4bc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 00:16:09 +0200 Subject: [PATCH 2/8] chore(contract): go fix --- contract/pagination.go | 6 +----- contract/session.go | 5 ++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/contract/pagination.go b/contract/pagination.go index c1b9c81..1e58f58 100644 --- a/contract/pagination.go +++ b/contract/pagination.go @@ -48,11 +48,7 @@ func NewPage[T any](items []T, total int64, page, perPage int) Page[T] { perPage = 1 } - lastPage := int((total + int64(perPage) - 1) / int64(perPage)) - - if lastPage < 1 { - lastPage = 1 - } + lastPage := max(int((total+int64(perPage)-1)/int64(perPage)), 1) if page < 1 { page = 1 diff --git a/contract/session.go b/contract/session.go index 337a9c4..2181241 100644 --- a/contract/session.go +++ b/contract/session.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/base64" + "maps" "sync" "time" ) @@ -106,9 +107,7 @@ func (session *Session) All() map[string]any { result := make(map[string]any, len(session.storage)) - for k, v := range session.storage { - result[k] = v - } + maps.Copy(result, session.storage) return result } From a9645eca328b1804bf031be8b8c7d9372f84d5ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 00:20:44 +0200 Subject: [PATCH 3/8] refactor(contract): use min/max builtins for pagination clamping --- contract/request/pagination.go | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/contract/request/pagination.go b/contract/request/pagination.go index 7e29ba0..083e8cd 100644 --- a/contract/request/pagination.go +++ b/contract/request/pagination.go @@ -15,20 +15,8 @@ func Pagination(r *http.Request) (page, perPage int) { // defaults and maximum per-page limit. The page is floored at 1 and // the per-page is clamped between 1 and maxPerPage. func PaginationWith(r *http.Request, defaultPage, defaultPerPage, maxPerPage int) (page, perPage int) { - page = QueryIntOr(r, "page", defaultPage) - perPage = QueryIntOr(r, "per_page", defaultPerPage) - - if page < 1 { - page = 1 - } - - if perPage < 1 { - perPage = 1 - } - - if perPage > maxPerPage { - perPage = maxPerPage - } + page = max(QueryIntOr(r, "page", defaultPage), 1) + perPage = min(max(QueryIntOr(r, "per_page", defaultPerPage), 1), maxPerPage) return page, perPage } @@ -47,15 +35,7 @@ func CursorPagination(r *http.Request) (cursor string, perPage int) { // between 1 and maxPerPage. func CursorPaginationWith(r *http.Request, defaultPerPage, maxPerPage int) (cursor string, perPage int) { cursor = Query(r, "cursor") - perPage = QueryIntOr(r, "per_page", defaultPerPage) - - if perPage < 1 { - perPage = 1 - } - - if perPage > maxPerPage { - perPage = maxPerPage - } + perPage = min(max(QueryIntOr(r, "per_page", defaultPerPage), 1), maxPerPage) return cursor, perPage } From 0f3fbf2dd7a9e47aaf73f7d7ed9a63ecb46e03ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 00:23:16 +0200 Subject: [PATCH 4/8] refactor(contract): use min/max builtins in NewPage --- contract/pagination.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/contract/pagination.go b/contract/pagination.go index 1e58f58..a8a0d34 100644 --- a/contract/pagination.go +++ b/contract/pagination.go @@ -44,19 +44,9 @@ type Cursor[T any] struct { // current page number, and items per page. It computes the last // page automatically. The current page is clamped to [1, LastPage]. func NewPage[T any](items []T, total int64, page, perPage int) Page[T] { - if perPage < 1 { - perPage = 1 - } - + perPage = max(perPage, 1) lastPage := max(int((total+int64(perPage)-1)/int64(perPage)), 1) - - if page < 1 { - page = 1 - } - - if page > lastPage { - page = lastPage - } + page = min(max(page, 1), lastPage) if items == nil { items = []T{} From 15b93f2170eb1e29b5a055664d152a768e1848eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 00:38:13 +0200 Subject: [PATCH 5/8] refactor(contract): move Hooks from interface to concrete struct Move the Hooks implementation from framework to contract as a concrete struct, mirroring the Session pattern. Only one meaningful implementation exists, so the interface indirection added no value. - contract/hooks.go: interface replaced with mutex-protected struct - contract/request/hooks.go: returns *contract.Hooks instead of interface - framework/hooks_writer.go: embeds *contract.Hooks - framework/handler.go: uses contract.NewHooks() - Deleted contract/mock/hooks.go (use contract.NewHooks() in tests) - Deleted framework/hooks.go and framework/hooks_test.go (moved to contract) BREAKING CHANGE: contract.Hooks is now a concrete *Hooks struct. Type assertions change from .(contract.Hooks) to .(*contract.Hooks). request.Hooks() and request.TryHooks() now return *contract.Hooks. --- contract/hooks.go | 107 ++++++++--- contract/hooks_test.go | 143 +++++++++++++++ contract/mock/hooks.go | 319 --------------------------------- contract/request/hooks.go | 8 +- contract/request/hooks_test.go | 20 +-- framework/handler.go | 2 +- framework/handler_test.go | 8 +- framework/hooks.go | 102 ----------- framework/hooks_test.go | 139 -------------- framework/hooks_writer.go | 6 +- framework/hooks_writer_test.go | 31 ++-- 11 files changed, 263 insertions(+), 622 deletions(-) delete mode 100644 contract/mock/hooks.go delete mode 100644 framework/hooks.go delete mode 100644 framework/hooks_test.go diff --git a/contract/hooks.go b/contract/hooks.go index 86431ba..62c9897 100644 --- a/contract/hooks.go +++ b/contract/hooks.go @@ -1,6 +1,10 @@ package contract -import "net/http" +import ( + "net/http" + "slices" + "sync" +) // hooksKey is the unexported type used as the context key // for storing and retrieving [Hooks] from a request context. @@ -26,28 +30,91 @@ type BeforeWriteHeaderHook = func(w http.ResponseWriter, status int) // slice that is about to be sent. type BeforeWriteHook = func(w http.ResponseWriter, content []byte) -// Hooks defines the contract for registering and retrieving -// lifecycle callbacks during HTTP request processing. Middleware -// and handlers use these hooks to observe response events. -type Hooks interface { - // AfterResponse registers one or more callbacks to be invoked - // after the HTTP response has been fully written. - AfterResponse(callbacks ...AfterResponseHook) +// Hooks provides lifecycle hook registration for the HTTP +// request/response cycle. Middleware and handlers can attach +// callbacks that fire before headers are written, before the +// body is written, and after the response completes. +// +// All methods are safe for concurrent use. +type Hooks struct { + mutex sync.Mutex + afterResponseHooks []AfterResponseHook + beforeWriteHeaderHooks []BeforeWriteHeaderHook + beforeWriteHooks []BeforeWriteHook +} + +// NewHooks creates a [Hooks] instance with empty callback slices +// ready to accept registrations via the Before* and After* methods. +func NewHooks() *Hooks { + return &Hooks{ + beforeWriteHeaderHooks: []BeforeWriteHeaderHook{}, + beforeWriteHooks: []BeforeWriteHook{}, + afterResponseHooks: []AfterResponseHook{}, + } +} + +// AfterResponse registers one or more callbacks to be invoked +// after the HTTP response has been fully written. +func (hooks *Hooks) AfterResponse(callbacks ...AfterResponseHook) { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + hooks.afterResponseHooks = append(hooks.afterResponseHooks, callbacks...) +} + +// AfterResponseFuncs returns a reversed clone of the registered +// AfterResponse callbacks. The reversal ensures that the most +// recently registered callback executes first (LIFO order). +func (hooks *Hooks) AfterResponseFuncs() []AfterResponseHook { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + clone := slices.Clone(hooks.afterResponseHooks) + slices.Reverse(clone) + + return clone +} + +// BeforeWrite registers one or more callbacks to be invoked +// just before response body bytes are written. +func (hooks *Hooks) BeforeWrite(callbacks ...BeforeWriteHook) { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + hooks.beforeWriteHooks = append(hooks.beforeWriteHooks, callbacks...) +} + +// BeforeWriteFuncs returns a reversed clone of the registered +// BeforeWrite callbacks. The reversal ensures that the most +// recently registered callback executes first (LIFO order). +func (hooks *Hooks) BeforeWriteFuncs() []BeforeWriteHook { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() - // AfterResponseFuncs returns all registered after-response callbacks. - AfterResponseFuncs() []AfterResponseHook + clone := slices.Clone(hooks.beforeWriteHooks) + slices.Reverse(clone) - // BeforeWrite registers one or more callbacks to be invoked - // just before response body bytes are written. - BeforeWrite(callbacks ...BeforeWriteHook) + return clone +} + +// BeforeWriteHeader registers one or more callbacks to be invoked +// just before the response status code is written. +func (hooks *Hooks) BeforeWriteHeader(callbacks ...BeforeWriteHeaderHook) { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + hooks.beforeWriteHeaderHooks = append(hooks.beforeWriteHeaderHooks, callbacks...) +} - // BeforeWriteFuncs returns all registered before-write callbacks. - BeforeWriteFuncs() []BeforeWriteHook +// BeforeWriteHeaderFuncs returns a reversed clone of the registered +// BeforeWriteHeader callbacks. The reversal ensures that the most +// recently registered callback executes first (LIFO order). +func (hooks *Hooks) BeforeWriteHeaderFuncs() []BeforeWriteHeaderHook { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() - // BeforeWriteHeader registers one or more callbacks to be invoked - // just before the response status code is written. - BeforeWriteHeader(callbacks ...BeforeWriteHeaderHook) + clone := slices.Clone(hooks.beforeWriteHeaderHooks) + slices.Reverse(clone) - // BeforeWriteHeaderFuncs returns all registered before-write-header callbacks. - BeforeWriteHeaderFuncs() []BeforeWriteHeaderHook + return clone } diff --git a/contract/hooks_test.go b/contract/hooks_test.go index 3408d95..d3ffde6 100644 --- a/contract/hooks_test.go +++ b/contract/hooks_test.go @@ -1,6 +1,9 @@ package contract_test import ( + "net/http" + "net/http/httptest" + "sync" "testing" "github.com/stretchr/testify/require" @@ -20,3 +23,143 @@ func TestHooksKeyIsDistinctType(t *testing.T) { require.NotEqual(t, other, contract.HooksKey) } + +func TestNewHooksReturnsNonNil(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + require.NotNil(t, hooks) +} + +func TestHooksAfterResponseRegisters(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var called bool + hooks.AfterResponse(func(err error) { called = true }) + + fns := hooks.AfterResponseFuncs() + + require.Len(t, fns, 1) + + fns[0](nil) + + require.True(t, called) +} + +func TestHooksAfterResponseFuncsReturnsLIFO(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var order []int + hooks.AfterResponse(func(err error) { order = append(order, 1) }) + hooks.AfterResponse(func(err error) { order = append(order, 2) }) + + for _, fn := range hooks.AfterResponseFuncs() { + fn(nil) + } + + require.Equal(t, []int{2, 1}, order) +} + +func TestHooksBeforeWriteRegisters(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var called bool + hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { called = true }) + + fns := hooks.BeforeWriteFuncs() + + require.Len(t, fns, 1) + + fns[0](httptest.NewRecorder(), nil) + + require.True(t, called) +} + +func TestHooksBeforeWriteFuncsReturnsLIFO(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var order []int + hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { order = append(order, 1) }) + hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { order = append(order, 2) }) + + for _, fn := range hooks.BeforeWriteFuncs() { + fn(httptest.NewRecorder(), nil) + } + + require.Equal(t, []int{2, 1}, order) +} + +func TestHooksBeforeWriteHeaderRegisters(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var called bool + hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { called = true }) + + fns := hooks.BeforeWriteHeaderFuncs() + + require.Len(t, fns, 1) + + fns[0](httptest.NewRecorder(), 200) + + require.True(t, called) +} + +func TestHooksBeforeWriteHeaderFuncsReturnsLIFO(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var order []int + hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { order = append(order, 1) }) + hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { order = append(order, 2) }) + + for _, fn := range hooks.BeforeWriteHeaderFuncs() { + fn(httptest.NewRecorder(), 200) + } + + require.Equal(t, []int{2, 1}, order) +} + +func TestHooksEmptyFuncsReturnsEmptySlice(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + require.Empty(t, hooks.AfterResponseFuncs()) + require.Empty(t, hooks.BeforeWriteFuncs()) + require.Empty(t, hooks.BeforeWriteHeaderFuncs()) +} + +func TestHooksConcurrentAccess(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var wg sync.WaitGroup + + for range 100 { + wg.Add(1) + + go func() { + defer wg.Done() + + hooks.AfterResponse(func(err error) {}) + hooks.AfterResponseFuncs() + }() + } + + wg.Wait() + + require.Len(t, hooks.AfterResponseFuncs(), 100) +} diff --git a/contract/mock/hooks.go b/contract/mock/hooks.go deleted file mode 100644 index ed46244..0000000 --- a/contract/mock/hooks.go +++ /dev/null @@ -1,319 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mock - -import ( - mock "github.com/stretchr/testify/mock" - "github.com/studiolambda/cosmos/contract" -) - -// NewHooksMock creates a new instance of HooksMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewHooksMock(t interface { - mock.TestingT - Cleanup(func()) -}) *HooksMock { - mock := &HooksMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// HooksMock is an autogenerated mock type for the Hooks type -type HooksMock struct { - mock.Mock -} - -type HooksMock_Expecter struct { - mock *mock.Mock -} - -func (_m *HooksMock) EXPECT() *HooksMock_Expecter { - return &HooksMock_Expecter{mock: &_m.Mock} -} - -// AfterResponse provides a mock function for the type HooksMock -func (_mock *HooksMock) AfterResponse(callbacks ...contract.AfterResponseHook) { - if len(callbacks) > 0 { - _mock.Called(callbacks) - } else { - _mock.Called() - } - - return -} - -// HooksMock_AfterResponse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AfterResponse' -type HooksMock_AfterResponse_Call struct { - *mock.Call -} - -// AfterResponse is a helper method to define mock.On call -// - callbacks ...contract.AfterResponseHook -func (_e *HooksMock_Expecter) AfterResponse(callbacks ...interface{}) *HooksMock_AfterResponse_Call { - return &HooksMock_AfterResponse_Call{Call: _e.mock.On("AfterResponse", - append([]interface{}{}, callbacks...)...)} -} - -func (_c *HooksMock_AfterResponse_Call) Run(run func(callbacks ...contract.AfterResponseHook)) *HooksMock_AfterResponse_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 []contract.AfterResponseHook - var variadicArgs []contract.AfterResponseHook - if len(args) > 0 { - variadicArgs = args[0].([]contract.AfterResponseHook) - } - arg0 = variadicArgs - run( - arg0..., - ) - }) - return _c -} - -func (_c *HooksMock_AfterResponse_Call) Return() *HooksMock_AfterResponse_Call { - _c.Call.Return() - return _c -} - -func (_c *HooksMock_AfterResponse_Call) RunAndReturn(run func(callbacks ...contract.AfterResponseHook)) *HooksMock_AfterResponse_Call { - _c.Run(run) - return _c -} - -// AfterResponseFuncs provides a mock function for the type HooksMock -func (_mock *HooksMock) AfterResponseFuncs() []contract.AfterResponseHook { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for AfterResponseFuncs") - } - - var r0 []contract.AfterResponseHook - if returnFunc, ok := ret.Get(0).(func() []contract.AfterResponseHook); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]contract.AfterResponseHook) - } - } - return r0 -} - -// HooksMock_AfterResponseFuncs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AfterResponseFuncs' -type HooksMock_AfterResponseFuncs_Call struct { - *mock.Call -} - -// AfterResponseFuncs is a helper method to define mock.On call -func (_e *HooksMock_Expecter) AfterResponseFuncs() *HooksMock_AfterResponseFuncs_Call { - return &HooksMock_AfterResponseFuncs_Call{Call: _e.mock.On("AfterResponseFuncs")} -} - -func (_c *HooksMock_AfterResponseFuncs_Call) Run(run func()) *HooksMock_AfterResponseFuncs_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *HooksMock_AfterResponseFuncs_Call) Return(vs []contract.AfterResponseHook) *HooksMock_AfterResponseFuncs_Call { - _c.Call.Return(vs) - return _c -} - -func (_c *HooksMock_AfterResponseFuncs_Call) RunAndReturn(run func() []contract.AfterResponseHook) *HooksMock_AfterResponseFuncs_Call { - _c.Call.Return(run) - return _c -} - -// BeforeWrite provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWrite(callbacks ...contract.BeforeWriteHook) { - if len(callbacks) > 0 { - _mock.Called(callbacks) - } else { - _mock.Called() - } - - return -} - -// HooksMock_BeforeWrite_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWrite' -type HooksMock_BeforeWrite_Call struct { - *mock.Call -} - -// BeforeWrite is a helper method to define mock.On call -// - callbacks ...contract.BeforeWriteHook -func (_e *HooksMock_Expecter) BeforeWrite(callbacks ...interface{}) *HooksMock_BeforeWrite_Call { - return &HooksMock_BeforeWrite_Call{Call: _e.mock.On("BeforeWrite", - append([]interface{}{}, callbacks...)...)} -} - -func (_c *HooksMock_BeforeWrite_Call) Run(run func(callbacks ...contract.BeforeWriteHook)) *HooksMock_BeforeWrite_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 []contract.BeforeWriteHook - var variadicArgs []contract.BeforeWriteHook - if len(args) > 0 { - variadicArgs = args[0].([]contract.BeforeWriteHook) - } - arg0 = variadicArgs - run( - arg0..., - ) - }) - return _c -} - -func (_c *HooksMock_BeforeWrite_Call) Return() *HooksMock_BeforeWrite_Call { - _c.Call.Return() - return _c -} - -func (_c *HooksMock_BeforeWrite_Call) RunAndReturn(run func(callbacks ...contract.BeforeWriteHook)) *HooksMock_BeforeWrite_Call { - _c.Run(run) - return _c -} - -// BeforeWriteFuncs provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWriteFuncs() []contract.BeforeWriteHook { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for BeforeWriteFuncs") - } - - var r0 []contract.BeforeWriteHook - if returnFunc, ok := ret.Get(0).(func() []contract.BeforeWriteHook); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]contract.BeforeWriteHook) - } - } - return r0 -} - -// HooksMock_BeforeWriteFuncs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWriteFuncs' -type HooksMock_BeforeWriteFuncs_Call struct { - *mock.Call -} - -// BeforeWriteFuncs is a helper method to define mock.On call -func (_e *HooksMock_Expecter) BeforeWriteFuncs() *HooksMock_BeforeWriteFuncs_Call { - return &HooksMock_BeforeWriteFuncs_Call{Call: _e.mock.On("BeforeWriteFuncs")} -} - -func (_c *HooksMock_BeforeWriteFuncs_Call) Run(run func()) *HooksMock_BeforeWriteFuncs_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *HooksMock_BeforeWriteFuncs_Call) Return(vs []contract.BeforeWriteHook) *HooksMock_BeforeWriteFuncs_Call { - _c.Call.Return(vs) - return _c -} - -func (_c *HooksMock_BeforeWriteFuncs_Call) RunAndReturn(run func() []contract.BeforeWriteHook) *HooksMock_BeforeWriteFuncs_Call { - _c.Call.Return(run) - return _c -} - -// BeforeWriteHeader provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWriteHeader(callbacks ...contract.BeforeWriteHeaderHook) { - if len(callbacks) > 0 { - _mock.Called(callbacks) - } else { - _mock.Called() - } - - return -} - -// HooksMock_BeforeWriteHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWriteHeader' -type HooksMock_BeforeWriteHeader_Call struct { - *mock.Call -} - -// BeforeWriteHeader is a helper method to define mock.On call -// - callbacks ...contract.BeforeWriteHeaderHook -func (_e *HooksMock_Expecter) BeforeWriteHeader(callbacks ...interface{}) *HooksMock_BeforeWriteHeader_Call { - return &HooksMock_BeforeWriteHeader_Call{Call: _e.mock.On("BeforeWriteHeader", - append([]interface{}{}, callbacks...)...)} -} - -func (_c *HooksMock_BeforeWriteHeader_Call) Run(run func(callbacks ...contract.BeforeWriteHeaderHook)) *HooksMock_BeforeWriteHeader_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 []contract.BeforeWriteHeaderHook - var variadicArgs []contract.BeforeWriteHeaderHook - if len(args) > 0 { - variadicArgs = args[0].([]contract.BeforeWriteHeaderHook) - } - arg0 = variadicArgs - run( - arg0..., - ) - }) - return _c -} - -func (_c *HooksMock_BeforeWriteHeader_Call) Return() *HooksMock_BeforeWriteHeader_Call { - _c.Call.Return() - return _c -} - -func (_c *HooksMock_BeforeWriteHeader_Call) RunAndReturn(run func(callbacks ...contract.BeforeWriteHeaderHook)) *HooksMock_BeforeWriteHeader_Call { - _c.Run(run) - return _c -} - -// BeforeWriteHeaderFuncs provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for BeforeWriteHeaderFuncs") - } - - var r0 []contract.BeforeWriteHeaderHook - if returnFunc, ok := ret.Get(0).(func() []contract.BeforeWriteHeaderHook); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]contract.BeforeWriteHeaderHook) - } - } - return r0 -} - -// HooksMock_BeforeWriteHeaderFuncs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWriteHeaderFuncs' -type HooksMock_BeforeWriteHeaderFuncs_Call struct { - *mock.Call -} - -// BeforeWriteHeaderFuncs is a helper method to define mock.On call -func (_e *HooksMock_Expecter) BeforeWriteHeaderFuncs() *HooksMock_BeforeWriteHeaderFuncs_Call { - return &HooksMock_BeforeWriteHeaderFuncs_Call{Call: _e.mock.On("BeforeWriteHeaderFuncs")} -} - -func (_c *HooksMock_BeforeWriteHeaderFuncs_Call) Run(run func()) *HooksMock_BeforeWriteHeaderFuncs_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *HooksMock_BeforeWriteHeaderFuncs_Call) Return(vs []contract.BeforeWriteHeaderHook) *HooksMock_BeforeWriteHeaderFuncs_Call { - _c.Call.Return(vs) - return _c -} - -func (_c *HooksMock_BeforeWriteHeaderFuncs_Call) RunAndReturn(run func() []contract.BeforeWriteHeaderHook) *HooksMock_BeforeWriteHeaderFuncs_Call { - _c.Call.Return(run) - return _c -} diff --git a/contract/request/hooks.go b/contract/request/hooks.go index 7737acc..00cf465 100644 --- a/contract/request/hooks.go +++ b/contract/request/hooks.go @@ -22,8 +22,8 @@ var ErrNoHooksMiddleware = problem.Problem{ // WARNING: This function panics when hooks are missing. Use // [TryHooks] for a non-panicking alternative, or ensure the // [framework.Recover] middleware is in place. -func Hooks(r *http.Request) contract.Hooks { - if hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks); ok { +func Hooks(r *http.Request) *contract.Hooks { + if hooks, ok := r.Context().Value(contract.HooksKey).(*contract.Hooks); ok { return hooks } @@ -34,8 +34,8 @@ func Hooks(r *http.Request) contract.Hooks { // context without panicking. The boolean return value indicates // whether hooks were found. This is the safe alternative to // [Hooks] for use outside the framework handler chain. -func TryHooks(r *http.Request) (contract.Hooks, bool) { - hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks) +func TryHooks(r *http.Request) (*contract.Hooks, bool) { + hooks, ok := r.Context().Value(contract.HooksKey).(*contract.Hooks) return hooks, ok } diff --git a/contract/request/hooks_test.go b/contract/request/hooks_test.go index 1b6a57e..83d7022 100644 --- a/contract/request/hooks_test.go +++ b/contract/request/hooks_test.go @@ -11,24 +11,12 @@ import ( "github.com/studiolambda/cosmos/contract/request" ) -// stubHooks is a minimal implementation of contract.Hooks for testing. -type stubHooks struct{} - -func (stubHooks) AfterResponse(...contract.AfterResponseHook) {} -func (stubHooks) AfterResponseFuncs() []contract.AfterResponseHook { return nil } -func (stubHooks) BeforeWrite(...contract.BeforeWriteHook) {} -func (stubHooks) BeforeWriteFuncs() []contract.BeforeWriteHook { return nil } -func (stubHooks) BeforeWriteHeader(...contract.BeforeWriteHeaderHook) {} -func (stubHooks) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { - return nil -} - func TestHooksReturnsHooksFromContext(t *testing.T) { t.Parallel() - hooks := stubHooks{} + hooks := contract.NewHooks() ctx := context.WithValue( - context.Background(), contract.HooksKey, contract.Hooks(hooks), + context.Background(), contract.HooksKey, hooks, ) r := httptest.NewRequest( http.MethodGet, "/", nil, @@ -65,9 +53,9 @@ func TestHooksPanicsWithErrNoHooksMiddleware(t *testing.T) { func TestTryHooksReturnsTrueWhenPresent(t *testing.T) { t.Parallel() - hooks := stubHooks{} + hooks := contract.NewHooks() ctx := context.WithValue( - context.Background(), contract.HooksKey, contract.Hooks(hooks), + context.Background(), contract.HooksKey, hooks, ) r := httptest.NewRequest( http.MethodGet, "/", nil, diff --git a/framework/handler.go b/framework/handler.go index db2007f..824b499 100644 --- a/framework/handler.go +++ b/framework/handler.go @@ -101,7 +101,7 @@ func (handler Handler) ServeHTTP( w http.ResponseWriter, r *http.Request, ) { - hooks := NewHooks() + hooks := contract.NewHooks() wrapped := NewResponseWriter(w, hooks) ctx := context.WithValue(r.Context(), contract.HooksKey, hooks) err := handler(wrapped, r.WithContext(ctx)) diff --git a/framework/handler_test.go b/framework/handler_test.go index a8a79fe..d1a2232 100644 --- a/framework/handler_test.go +++ b/framework/handler_test.go @@ -159,7 +159,7 @@ func TestServeHTTPAfterResponseHooksRun(t *testing.T) { var receivedErr atomic.Value h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks := r.Context().Value(contract.HooksKey).(*contract.Hooks) hooks.AfterResponse(func(err error) { hookCalled.Store(true) receivedErr.Store(err) @@ -180,7 +180,7 @@ func TestServeHTTPAfterResponseHookPanicRecovered(t *testing.T) { t.Parallel() h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks := r.Context().Value(contract.HooksKey).(*contract.Hooks) hooks.AfterResponse(func(err error) { panic("hook panic") }) @@ -203,7 +203,7 @@ func TestServeHTTPHooksInContext(t *testing.T) { var foundHooks atomic.Bool h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks, ok := r.Context().Value(contract.HooksKey).(*contract.Hooks) foundHooks.Store(ok && hooks != nil) return nil @@ -223,7 +223,7 @@ func TestServeHTTPAfterResponseHookReceivesNilOnSuccess(t *testing.T) { errChan := make(chan error, 1) h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks := r.Context().Value(contract.HooksKey).(*contract.Hooks) hooks.AfterResponse(func(err error) { hookCalled.Store(true) errChan <- err diff --git a/framework/hooks.go b/framework/hooks.go deleted file mode 100644 index 532524e..0000000 --- a/framework/hooks.go +++ /dev/null @@ -1,102 +0,0 @@ -package framework - -import ( - "slices" - "sync" - - "github.com/studiolambda/cosmos/contract" -) - -// Hooks provides lifecycle hook registration for the HTTP -// request/response cycle. Middleware and handlers can attach -// callbacks that fire before headers are written, before the -// body is written, and after the response completes. -// -// All methods are safe for concurrent use. -type Hooks struct { - // mutex guards all hook slices. - mutex sync.Mutex - afterResponseHooks []contract.AfterResponseHook - beforeWriteHeaderHooks []contract.BeforeWriteHeaderHook - beforeWriteHooks []contract.BeforeWriteHook -} - -// NewHooks creates a Hooks instance with empty callback slices -// ready to accept registrations via the Before* and After* methods. -func NewHooks() *Hooks { - return &Hooks{ - beforeWriteHeaderHooks: []contract.BeforeWriteHeaderHook{}, - beforeWriteHooks: []contract.BeforeWriteHook{}, - afterResponseHooks: []contract.AfterResponseHook{}, - } -} - -// BeforeWriteHeader registers one or more callbacks that will be -// invoked just before the response status code is written. This -// is the last opportunity to inspect or modify headers. -func (hooks *Hooks) BeforeWriteHeader(callbacks ...contract.BeforeWriteHeaderHook) { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - hooks.beforeWriteHeaderHooks = append(hooks.beforeWriteHeaderHooks, callbacks...) -} - -// BeforeWriteHeaderFuncs returns a reversed clone of the registered -// BeforeWriteHeader callbacks. The reversal ensures that the most -// recently registered callback executes first (LIFO order). -func (hooks *Hooks) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - clone := slices.Clone(hooks.beforeWriteHeaderHooks) - slices.Reverse(clone) - - return clone -} - -// BeforeWrite registers one or more callbacks that will be -// invoked just before the response body bytes are written. -// This is useful for logging, metrics, or content transformation. -func (hooks *Hooks) BeforeWrite(callbacks ...contract.BeforeWriteHook) { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - hooks.beforeWriteHooks = append(hooks.beforeWriteHooks, callbacks...) -} - -// BeforeWriteFuncs returns a reversed clone of the registered -// BeforeWrite callbacks. The reversal ensures that the most -// recently registered callback executes first (LIFO order). -func (hooks *Hooks) BeforeWriteFuncs() []contract.BeforeWriteHook { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - clone := slices.Clone(hooks.beforeWriteHooks) - slices.Reverse(clone) - - return clone -} - -// AfterResponse registers one or more callbacks that will be -// invoked after the handler has completed and all response data -// has been written. The callback receives the handler's error -// (or nil if the handler succeeded). -func (hooks *Hooks) AfterResponse(callbacks ...contract.AfterResponseHook) { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - hooks.afterResponseHooks = append(hooks.afterResponseHooks, callbacks...) -} - -// AfterResponseFuncs returns a reversed clone of the registered -// AfterResponse callbacks. The reversal ensures that the most -// recently registered callback executes first (LIFO order). -func (hooks *Hooks) AfterResponseFuncs() []contract.AfterResponseHook { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - clone := slices.Clone(hooks.afterResponseHooks) - slices.Reverse(clone) - - return clone -} diff --git a/framework/hooks_test.go b/framework/hooks_test.go deleted file mode 100644 index 108badf..0000000 --- a/framework/hooks_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package framework_test - -import ( - "net/http" - "testing" - - "github.com/studiolambda/cosmos/contract" - "github.com/studiolambda/cosmos/framework" - - "github.com/stretchr/testify/require" -) - -func TestNewHooksEmpty(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - - require.Empty(t, hooks.BeforeWriteHeaderFuncs()) - require.Empty(t, hooks.BeforeWriteFuncs()) - require.Empty(t, hooks.AfterResponseFuncs()) -} - -func TestBeforeWriteHeaderRegistersAndReturnsReversed(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - first := contract.BeforeWriteHeaderHook( - func(w http.ResponseWriter, status int) { order = append(order, 1) }, - ) - - second := contract.BeforeWriteHeaderHook( - func(w http.ResponseWriter, status int) { order = append(order, 2) }, - ) - - hooks.BeforeWriteHeader(first, second) - - funcs := hooks.BeforeWriteHeaderFuncs() - - require.Len(t, funcs, 2) - - for _, fn := range funcs { - fn(nil, 0) - } - - // Reversed: second fires first. - require.Equal(t, []int{2, 1}, order) -} - -func TestBeforeWriteRegistersAndReturnsReversed(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - first := contract.BeforeWriteHook( - func(w http.ResponseWriter, content []byte) { order = append(order, 1) }, - ) - - second := contract.BeforeWriteHook( - func(w http.ResponseWriter, content []byte) { order = append(order, 2) }, - ) - - hooks.BeforeWrite(first, second) - - funcs := hooks.BeforeWriteFuncs() - - require.Len(t, funcs, 2) - - for _, fn := range funcs { - fn(nil, nil) - } - - require.Equal(t, []int{2, 1}, order) -} - -func TestAfterResponseRegistersAndReturnsReversed(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - first := contract.AfterResponseHook( - func(err error) { order = append(order, 1) }, - ) - - second := contract.AfterResponseHook( - func(err error) { order = append(order, 2) }, - ) - - hooks.AfterResponse(first, second) - - funcs := hooks.AfterResponseFuncs() - - require.Len(t, funcs, 2) - - for _, fn := range funcs { - fn(nil) - } - - require.Equal(t, []int{2, 1}, order) -} - -func TestHooksMultipleRegistrations(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - hooks.AfterResponse(func(err error) { order = append(order, 1) }) - hooks.AfterResponse(func(err error) { order = append(order, 2) }) - hooks.AfterResponse(func(err error) { order = append(order, 3) }) - - funcs := hooks.AfterResponseFuncs() - - require.Len(t, funcs, 3) - - for _, fn := range funcs { - fn(nil) - } - - // LIFO: 3, 2, 1. - require.Equal(t, []int{3, 2, 1}, order) -} - -func TestHooksFuncsReturnClone(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - - hooks.AfterResponse(func(err error) {}) - - funcs := hooks.AfterResponseFuncs() - funcs[0] = nil - - // Original should be unaffected by mutation of the returned slice. - require.NotNil(t, hooks.AfterResponseFuncs()[0]) -} diff --git a/framework/hooks_writer.go b/framework/hooks_writer.go index 71cbde0..1594d4c 100644 --- a/framework/hooks_writer.go +++ b/framework/hooks_writer.go @@ -4,6 +4,8 @@ import ( "log/slog" "net/http" "sync/atomic" + + "github.com/studiolambda/cosmos/contract" ) // ResponseWriter wraps an http.ResponseWriter to intercept @@ -14,7 +16,7 @@ import ( // sync/atomic for safe concurrent access. type ResponseWriter struct { http.ResponseWriter - *Hooks + *contract.Hooks writeHeaderCalled atomic.Bool } @@ -40,7 +42,7 @@ type WrappedResponseWriter interface { // the given hooks on write operations. If the underlying writer // implements http.Flusher, the returned value also satisfies // http.Flusher via ResponseWriterFlusher. -func NewResponseWriter(writer http.ResponseWriter, hooks *Hooks) WrappedResponseWriter { +func NewResponseWriter(writer http.ResponseWriter, hooks *contract.Hooks) WrappedResponseWriter { wrapped := &ResponseWriter{ ResponseWriter: writer, Hooks: hooks, diff --git a/framework/hooks_writer_test.go b/framework/hooks_writer_test.go index e3ba629..e68fcda 100644 --- a/framework/hooks_writer_test.go +++ b/framework/hooks_writer_test.go @@ -6,6 +6,7 @@ import ( "sync/atomic" "testing" + "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/framework" "github.com/stretchr/testify/require" @@ -27,7 +28,7 @@ func (writer *flusherWriter) Flush() { func TestNewResponseWriterNonFlusher(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(&plainWriter{rec}, hooks) @@ -39,7 +40,7 @@ func TestNewResponseWriterNonFlusher(t *testing.T) { func TestNewResponseWriterFlusher(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter( &flusherWriter{ResponseWriter: rec}, @@ -56,7 +57,7 @@ func TestNewResponseWriterFlusher(t *testing.T) { func TestWriteHeaderCalledInitiallyFalse(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -66,7 +67,7 @@ func TestWriteHeaderCalledInitiallyFalse(t *testing.T) { func TestWriteHeaderSetsCalledFlag(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -78,7 +79,7 @@ func TestWriteHeaderSetsCalledFlag(t *testing.T) { func TestWriteHeaderFiresBeforeWriteHeaderHooks(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -99,7 +100,7 @@ func TestWriteHeaderFiresBeforeWriteHeaderHooks(t *testing.T) { func TestWriteHeaderSecondCallIsNoop(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -121,7 +122,7 @@ func TestWriteHeaderSecondCallIsNoop(t *testing.T) { func TestWriteFiresBeforeWriteHooks(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -145,7 +146,7 @@ func TestWriteFiresBeforeWriteHooks(t *testing.T) { func TestWriteAutoCallsWriteHeaderWith200(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -159,7 +160,7 @@ func TestWriteAutoCallsWriteHeaderWith200(t *testing.T) { func TestWriteAfterWriteHeaderDoesNotCallAgain(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -183,7 +184,7 @@ func TestWriteAfterWriteHeaderDoesNotCallAgain(t *testing.T) { func TestBeforeWriteHeaderHookPanicIsRecovered(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -202,7 +203,7 @@ func TestBeforeWriteHeaderHookPanicIsRecovered(t *testing.T) { func TestBeforeWriteHookPanicIsRecovered(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -226,7 +227,7 @@ func TestBeforeWriteHookPanicIsRecovered(t *testing.T) { func TestResponseWriterUnwrapReturnsUnderlying(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() plain := &plainWriter{rec} wrapped := framework.NewResponseWriter(plain, hooks) @@ -244,7 +245,7 @@ func TestResponseWriterUnwrapReturnsUnderlying(t *testing.T) { func TestResponseControllerFlushThroughWrappedWriter(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() fw := &flusherWriter{ResponseWriter: rec} wrapped := framework.NewResponseWriter(fw, hooks) @@ -259,7 +260,7 @@ func TestResponseControllerFlushThroughWrappedWriter(t *testing.T) { func TestResponseWriterFlusherUnwrapReturnsUnderlying(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() fw := &flusherWriter{ResponseWriter: rec} wrapped := framework.NewResponseWriter(fw, hooks) @@ -277,7 +278,7 @@ func TestResponseWriterFlusherUnwrapReturnsUnderlying(t *testing.T) { func TestWriteHeaderHookReceivesUnderlyingWriter(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) From 3ca81f8cbe846afa338957b40cdba88537937c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 12:06:31 +0200 Subject: [PATCH 6/8] refactor(contract): simplify cursor API with encode function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace CursorExtractor interface and CursorExtractorFunc with a plain encode func(T) (string, error) parameter on NewCursor. Remove CursorEncoder interface, NewCursorWith, CursorValue, and CursorValueWith. New API: - NewCursor[T](items, perPage, hasNext, hasPrev, encode func(T) (string, error)) - MarshalCursor[V](value) — base64-JSON encode helper - UnmarshalCursor[V](cursor) — base64-JSON decode helper The encode function gives full control at the call site without requiring methods on domain types or hidden encoding behavior. --- contract/mock/pagination.go | 146 ------------------------------- contract/pagination.go | 121 +++++-------------------- contract/pagination_test.go | 170 +++++++++++++++--------------------- 3 files changed, 92 insertions(+), 345 deletions(-) delete mode 100644 contract/mock/pagination.go diff --git a/contract/mock/pagination.go b/contract/mock/pagination.go deleted file mode 100644 index 6e1c8a2..0000000 --- a/contract/mock/pagination.go +++ /dev/null @@ -1,146 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mock - -import ( - mock "github.com/stretchr/testify/mock" -) - -// NewCursorEncoderMock creates a new instance of CursorEncoderMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewCursorEncoderMock(t interface { - mock.TestingT - Cleanup(func()) -}) *CursorEncoderMock { - mock := &CursorEncoderMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// CursorEncoderMock is an autogenerated mock type for the CursorEncoder type -type CursorEncoderMock struct { - mock.Mock -} - -type CursorEncoderMock_Expecter struct { - mock *mock.Mock -} - -func (_m *CursorEncoderMock) EXPECT() *CursorEncoderMock_Expecter { - return &CursorEncoderMock_Expecter{mock: &_m.Mock} -} - -// Encode provides a mock function for the type CursorEncoderMock -func (_mock *CursorEncoderMock) Encode(value any) (string, error) { - ret := _mock.Called(value) - - if len(ret) == 0 { - panic("no return value specified for Encode") - } - - var r0 string - var r1 error - if returnFunc, ok := ret.Get(0).(func(any) (string, error)); ok { - return returnFunc(value) - } - if returnFunc, ok := ret.Get(0).(func(any) string); ok { - r0 = returnFunc(value) - } else { - r0 = ret.Get(0).(string) - } - if returnFunc, ok := ret.Get(1).(func(any) error); ok { - r1 = returnFunc(value) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// CursorEncoderMock_Encode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Encode' -type CursorEncoderMock_Encode_Call struct { - *mock.Call -} - -// Encode is a helper method to define mock.On call -// - value any -func (_e *CursorEncoderMock_Expecter) Encode(value interface{}) *CursorEncoderMock_Encode_Call { - return &CursorEncoderMock_Encode_Call{Call: _e.mock.On("Encode", value)} -} - -func (_c *CursorEncoderMock_Encode_Call) Run(run func(value any)) *CursorEncoderMock_Encode_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0]) - }) - return _c -} - -func (_c *CursorEncoderMock_Encode_Call) Return(_a0 string, _a1 error) *CursorEncoderMock_Encode_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *CursorEncoderMock_Encode_Call) RunAndReturn(run func(any) (string, error)) *CursorEncoderMock_Encode_Call { - _c.Call.Return(run) - return _c -} - -// Decode provides a mock function for the type CursorEncoderMock -func (_mock *CursorEncoderMock) Decode(cursor string) (any, error) { - ret := _mock.Called(cursor) - - if len(ret) == 0 { - panic("no return value specified for Decode") - } - - var r0 any - var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (any, error)); ok { - return returnFunc(cursor) - } - if returnFunc, ok := ret.Get(0).(func(string) any); ok { - r0 = returnFunc(cursor) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0) - } - } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(cursor) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// CursorEncoderMock_Decode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Decode' -type CursorEncoderMock_Decode_Call struct { - *mock.Call -} - -// Decode is a helper method to define mock.On call -// - cursor string -func (_e *CursorEncoderMock_Expecter) Decode(cursor interface{}) *CursorEncoderMock_Decode_Call { - return &CursorEncoderMock_Decode_Call{Call: _e.mock.On("Decode", cursor)} -} - -func (_c *CursorEncoderMock_Decode_Call) Run(run func(cursor string)) *CursorEncoderMock_Decode_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) - }) - return _c -} - -func (_c *CursorEncoderMock_Decode_Call) Return(_a0 any, _a1 error) *CursorEncoderMock_Decode_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *CursorEncoderMock_Decode_Call) RunAndReturn(run func(string) (any, error)) *CursorEncoderMock_Decode_Call { - _c.Call.Return(run) - return _c -} diff --git a/contract/pagination.go b/contract/pagination.go index a8a0d34..73a924b 100644 --- a/contract/pagination.go +++ b/contract/pagination.go @@ -12,17 +12,6 @@ var ErrCursorEncode = errors.New("failed to encode cursor") // ErrCursorDecode is returned when a cursor string fails to decode. var ErrCursorDecode = errors.New("failed to decode cursor") -// CursorEncoder defines the encoding and decoding of cursor values -// into opaque cursor strings. Implementations control how cursor -// values are serialized for transport and deserialized on receipt. -type CursorEncoder interface { - // Encode converts a cursor value into an opaque cursor string. - Encode(value any) (string, error) - - // Decode converts an opaque cursor string back into a cursor value. - Decode(cursor string) (any, error) -} - // Page represents an offset-based paginated result set. type Page[T any] struct { Items []T `json:"items"` @@ -61,13 +50,12 @@ func NewPage[T any](items []T, total int64, page, perPage int) Page[T] { } } -// NewCursor creates a new [Cursor] from the given items using a -// built-in base64-JSON encoding for cursor values. The extract -// function determines which value from each item becomes the cursor. -// When hasNext is true, the last item's extracted value becomes the -// next cursor. When hasPrev is true, the first item's extracted -// value becomes the previous cursor. -func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, extract func(T) any) (Cursor[T], error) { +// NewCursor creates a new [Cursor] from the given items. The encode +// function determines how each item is transformed into an opaque +// cursor string. When hasNext is true, the last item is encoded to +// produce the next cursor. When hasPrev is true, the first item is +// encoded to produce the previous cursor. +func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, encode func(T) (string, error)) (Cursor[T], error) { if items == nil { items = []T{} } @@ -82,7 +70,7 @@ func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, extract fun } if hasNext { - encoded, err := base64JSONEncode(extract(items[len(items)-1])) + encoded, err := encode(items[len(items)-1]) if err != nil { return result, errors.Join(ErrCursorEncode, err) @@ -92,7 +80,7 @@ func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, extract fun } if hasPrev { - encoded, err := base64JSONEncode(extract(items[0])) + encoded, err := encode(items[0]) if err != nil { return result, errors.Join(ErrCursorEncode, err) @@ -104,80 +92,10 @@ func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, extract fun return result, nil } -// NewCursorWith creates a new [Cursor] from the given items using -// a custom [CursorEncoder]. When hasNext is true, the last item is -// passed to the encoder to produce the next cursor. When hasPrev is -// true, the first item is passed to produce the previous cursor. -func NewCursorWith[T any](items []T, perPage int, hasNext, hasPrev bool, encoder CursorEncoder) (Cursor[T], error) { - if items == nil { - items = []T{} - } - - result := Cursor[T]{ - Items: items, - PerPage: perPage, - } - - if len(items) == 0 { - return result, nil - } - - if hasNext { - encoded, err := encoder.Encode(items[len(items)-1]) - - if err != nil { - return result, errors.Join(ErrCursorEncode, err) - } - - result.NextCursor = encoded - } - - if hasPrev { - encoded, err := encoder.Encode(items[0]) - - if err != nil { - return result, errors.Join(ErrCursorEncode, err) - } - - result.PrevCursor = encoded - } - - return result, nil -} - -// CursorValue decodes a cursor string produced by [NewCursor] back -// into the original cursor value using the built-in base64-JSON encoding. -func CursorValue[T any](cursor string) (T, error) { - var zero T - - raw, err := base64JSONDecode(cursor) - - if err != nil { - return zero, errors.Join(ErrCursorDecode, err) - } - - value, ok := raw.(T) - - if !ok { - return zero, ErrCursorDecode - } - - return value, nil -} - -// CursorValueWith decodes a cursor string using a custom [CursorEncoder]. -func CursorValueWith(cursor string, encoder CursorEncoder) (any, error) { - value, err := encoder.Decode(cursor) - - if err != nil { - return nil, errors.Join(ErrCursorDecode, err) - } - - return value, nil -} - -// base64JSONEncode encodes a value as JSON then base64url. -func base64JSONEncode(value any) (string, error) { +// MarshalCursor encodes a value into an opaque cursor string using +// JSON serialization and base64url encoding. Use this as the encoding +// helper inside the encode function passed to [NewCursor]. +func MarshalCursor[V any](value V) (string, error) { data, err := json.Marshal(value) if err != nil { @@ -187,18 +105,21 @@ func base64JSONEncode(value any) (string, error) { return base64.RawURLEncoding.EncodeToString(data), nil } -// base64JSONDecode decodes a base64url string then unmarshals as JSON. -func base64JSONDecode(cursor string) (any, error) { +// UnmarshalCursor decodes an opaque cursor string back into a typed +// value. It reverses the encoding performed by [MarshalCursor]. +func UnmarshalCursor[V any](cursor string) (V, error) { + var value V + data, err := base64.RawURLEncoding.DecodeString(cursor) if err != nil { - return nil, err + var zero V + return zero, errors.Join(ErrCursorDecode, err) } - var value any - if err := json.Unmarshal(data, &value); err != nil { - return nil, err + var zero V + return zero, errors.Join(ErrCursorDecode, err) } return value, nil diff --git a/contract/pagination_test.go b/contract/pagination_test.go index 9d57bca..9cf39ff 100644 --- a/contract/pagination_test.go +++ b/contract/pagination_test.go @@ -82,7 +82,9 @@ func TestNewCursorEncodesNextCursor(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, true, false, func(item int) any { return item }) + cursor, err := contract.NewCursor(items, 3, true, false, func(item int) (string, error) { + return contract.MarshalCursor(item) + }) require.NoError(t, err) require.NotEmpty(t, cursor.NextCursor) @@ -93,7 +95,9 @@ func TestNewCursorEncodesPrevCursor(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, false, true, func(item int) any { return item }) + cursor, err := contract.NewCursor(items, 3, false, true, func(item int) (string, error) { + return contract.MarshalCursor(item) + }) require.NoError(t, err) require.Empty(t, cursor.NextCursor) @@ -104,7 +108,9 @@ func TestNewCursorEncodesBothCursors(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, true, true, func(item int) any { return item }) + cursor, err := contract.NewCursor(items, 3, true, true, func(item int) (string, error) { + return contract.MarshalCursor(item) + }) require.NoError(t, err) require.NotEmpty(t, cursor.NextCursor) @@ -114,7 +120,9 @@ func TestNewCursorEncodesBothCursors(t *testing.T) { func TestNewCursorEmptyItemsNoCursors(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursor([]int{}, 10, true, true, func(item int) any { return item }) + cursor, err := contract.NewCursor([]int{}, 10, true, true, func(item int) (string, error) { + return contract.MarshalCursor(item) + }) require.NoError(t, err) require.Empty(t, cursor.NextCursor) @@ -124,7 +132,9 @@ func TestNewCursorEmptyItemsNoCursors(t *testing.T) { func TestNewCursorNilItemsBecomesEmptySlice(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursor[int](nil, 10, false, false, func(item int) any { return item }) + cursor, err := contract.NewCursor[int](nil, 10, false, false, func(item int) (string, error) { + return contract.MarshalCursor(item) + }) require.NoError(t, err) require.NotNil(t, cursor.Items) @@ -134,20 +144,19 @@ func TestNewCursorNilItemsBecomesEmptySlice(t *testing.T) { func TestNewCursorPreservesPerPage(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursor([]int{1}, 25, false, false, func(item int) any { return item }) + cursor, err := contract.NewCursor([]int{1}, 25, false, false, func(item int) (string, error) { + return contract.MarshalCursor(item) + }) require.NoError(t, err) require.Equal(t, 25, cursor.PerPage) } -func TestNewCursorEncodeErrorReturnsErrCursorEncode(t *testing.T) { +func TestNewCursorNextEncodeErrorReturnsErrCursorEncode(t *testing.T) { t.Parallel() - items := []int{1} - - // Channels cannot be JSON-encoded. - _, err := contract.NewCursor(items, 10, true, false, func(item int) any { - return make(chan int) + _, err := contract.NewCursor([]int{1}, 10, true, false, func(item int) (string, error) { + return "", errors.New("encode failed") }) require.ErrorIs(t, err, contract.ErrCursorEncode) @@ -156,150 +165,113 @@ func TestNewCursorEncodeErrorReturnsErrCursorEncode(t *testing.T) { func TestNewCursorPrevEncodeErrorReturnsErrCursorEncode(t *testing.T) { t.Parallel() - items := []int{1} - - _, err := contract.NewCursor(items, 10, false, true, func(item int) any { - return make(chan int) + _, err := contract.NewCursor([]int{1}, 10, false, true, func(item int) (string, error) { + return "", errors.New("encode failed") }) require.ErrorIs(t, err, contract.ErrCursorEncode) } -func TestCursorValueRoundTrip(t *testing.T) { +func TestNewCursorCustomEncoder(t *testing.T) { t.Parallel() - items := []int{10, 20, 30} - cursor, err := contract.NewCursor(items, 3, true, false, func(item int) any { return item }) - - require.NoError(t, err) - - value, err := contract.CursorValue[float64](cursor.NextCursor) + items := []int{1, 2, 3} + cursor, err := contract.NewCursor(items, 3, true, false, func(item int) (string, error) { + return "custom-cursor", nil + }) require.NoError(t, err) - require.Equal(t, float64(30), value) -} - -func TestCursorValueInvalidBase64ReturnsErrCursorDecode(t *testing.T) { - t.Parallel() - - _, err := contract.CursorValue[int]("not-valid-base64!!!") - - require.ErrorIs(t, err, contract.ErrCursorDecode) + require.Equal(t, "custom-cursor", cursor.NextCursor) } -func TestCursorValueTypeMismatchReturnsErrCursorDecode(t *testing.T) { +func TestMarshalUnmarshalCursorRoundTrip(t *testing.T) { t.Parallel() - items := []string{"hello"} - cursor, err := contract.NewCursor(items, 1, true, false, func(item string) any { return item }) + encoded, err := contract.MarshalCursor(42) require.NoError(t, err) - _, err = contract.CursorValue[int](cursor.NextCursor) - - require.ErrorIs(t, err, contract.ErrCursorDecode) -} + value, err := contract.UnmarshalCursor[int](encoded) -type failEncoder struct{} - -func (failEncoder) Encode(value any) (string, error) { - return "", errors.New("encode failed") -} - -func (failEncoder) Decode(cursor string) (any, error) { - return nil, errors.New("decode failed") -} - -type idEncoder struct{} - -func (idEncoder) Encode(value any) (string, error) { - return "custom-cursor", nil -} - -func (idEncoder) Decode(cursor string) (any, error) { - return cursor, nil + require.NoError(t, err) + require.Equal(t, 42, value) } -func TestNewCursorWithCustomEncoder(t *testing.T) { +func TestMarshalCursorStringRoundTrip(t *testing.T) { t.Parallel() - items := []int{1, 2, 3} - cursor, err := contract.NewCursorWith(items, 3, true, false, idEncoder{}) + encoded, err := contract.MarshalCursor("hello") require.NoError(t, err) - require.Equal(t, "custom-cursor", cursor.NextCursor) -} - -func TestNewCursorWithEncoderErrorReturnsErrCursorEncode(t *testing.T) { - t.Parallel() - items := []int{1} - _, err := contract.NewCursorWith(items, 1, true, false, failEncoder{}) + value, err := contract.UnmarshalCursor[string](encoded) - require.ErrorIs(t, err, contract.ErrCursorEncode) + require.NoError(t, err) + require.Equal(t, "hello", value) } -func TestNewCursorWithPrevEncoderErrorReturnsErrCursorEncode(t *testing.T) { +func TestMarshalCursorStructRoundTrip(t *testing.T) { t.Parallel() - items := []int{1} - _, err := contract.NewCursorWith(items, 1, false, true, failEncoder{}) + type key struct { + ID int `json:"id"` + Name string `json:"name"` + } - require.ErrorIs(t, err, contract.ErrCursorEncode) -} + original := key{ID: 7, Name: "test"} + encoded, err := contract.MarshalCursor(original) -func TestNewCursorWithEmptyItems(t *testing.T) { - t.Parallel() + require.NoError(t, err) - cursor, err := contract.NewCursorWith([]int{}, 10, true, true, idEncoder{}) + value, err := contract.UnmarshalCursor[key](encoded) require.NoError(t, err) - require.Empty(t, cursor.NextCursor) - require.Empty(t, cursor.PrevCursor) + require.Equal(t, original, value) } -func TestNewCursorWithNilItemsBecomesEmptySlice(t *testing.T) { +func TestMarshalCursorUnencodableReturnsError(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursorWith[int](nil, 10, false, false, idEncoder{}) + _, err := contract.MarshalCursor(make(chan int)) - require.NoError(t, err) - require.NotNil(t, cursor.Items) + require.Error(t, err) } -func TestNewCursorWithEncodesBothCursors(t *testing.T) { +func TestUnmarshalCursorInvalidBase64ReturnsErrCursorDecode(t *testing.T) { t.Parallel() - items := []int{1, 2, 3} - cursor, err := contract.NewCursorWith(items, 3, true, true, idEncoder{}) + _, err := contract.UnmarshalCursor[int]("not-valid-base64!!!") - require.NoError(t, err) - require.Equal(t, "custom-cursor", cursor.NextCursor) - require.Equal(t, "custom-cursor", cursor.PrevCursor) + require.ErrorIs(t, err, contract.ErrCursorDecode) } -func TestCursorValueInvalidJSONReturnsErrCursorDecode(t *testing.T) { +func TestUnmarshalCursorInvalidJSONReturnsErrCursorDecode(t *testing.T) { t.Parallel() // Valid base64 but invalid JSON. - _, err := contract.CursorValue[int]("bm90LWpzb24") + _, err := contract.UnmarshalCursor[int]("bm90LWpzb24") require.ErrorIs(t, err, contract.ErrCursorDecode) } -func TestCursorValueWithCustomDecoder(t *testing.T) { +func TestNewCursorWithMarshalCursorEndToEnd(t *testing.T) { t.Parallel() - value, err := contract.CursorValueWith("test-cursor", idEncoder{}) + type User struct { + ID int + Name string + } - require.NoError(t, err) - require.Equal(t, "test-cursor", value) -} + users := []User{{ID: 10, Name: "Alice"}, {ID: 20, Name: "Bob"}} -func TestCursorValueWithDecoderErrorReturnsErrCursorDecode(t *testing.T) { - t.Parallel() + cursor, err := contract.NewCursor(users, 2, true, false, func(u User) (string, error) { + return contract.MarshalCursor(u.ID) + }) - _, err := contract.CursorValueWith("anything", failEncoder{}) + require.NoError(t, err) - require.ErrorIs(t, err, contract.ErrCursorDecode) + id, err := contract.UnmarshalCursor[int](cursor.NextCursor) + + require.NoError(t, err) + require.Equal(t, 20, id) } From d6bf221f9023074386c64efb61671c0185ee96c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 13:13:15 +0200 Subject: [PATCH 7/8] fix(contract): removed unecesary test --- contract/response/static_test.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/contract/response/static_test.go b/contract/response/static_test.go index 9ecd2cb..7df5ecb 100644 --- a/contract/response/static_test.go +++ b/contract/response/static_test.go @@ -438,16 +438,6 @@ func TestSafeRedirectRejectsUnparseableURL(t *testing.T) { require.ErrorIs(t, err, response.ErrUnsafeRedirect) } -func TestErrUnsafeRedirectMessage(t *testing.T) { - t.Parallel() - - require.Equal( - t, - "unsafe redirect URL: must be a relative path", - response.ErrUnsafeRedirect.Error(), - ) -} - func TestStringTemplateBuffersBeforeWritingStatus(t *testing.T) { t.Parallel() From 2caa54d13f129a6d75ae1ad3daae14a7d6fea7f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Sun, 31 May 2026 20:14:16 +0200 Subject: [PATCH 8/8] docs(contract): add godoc examples for pagination functions Add usage examples in doc comments for Paginate, CursorPaginate, CursorEncode, and CursorDecode. Rename functions from NewPage/NewCursor to Paginate/CursorPaginate and MarshalCursor/UnmarshalCursor to CursorEncode/CursorDecode for better API clarity. --- contract/pagination.go | 53 ++++++++++--- contract/pagination_example_test.go | 81 +++++++++++++++++++ contract/pagination_test.go | 118 ++++++++++++++-------------- 3 files changed, 183 insertions(+), 69 deletions(-) create mode 100644 contract/pagination_example_test.go diff --git a/contract/pagination.go b/contract/pagination.go index 73a924b..931bd05 100644 --- a/contract/pagination.go +++ b/contract/pagination.go @@ -29,10 +29,20 @@ type Cursor[T any] struct { PrevCursor string `json:"prev_cursor,omitempty"` } -// NewPage creates a new [Page] from the given items, total count, +// Paginate creates a new [Page] from the given items, total count, // current page number, and items per page. It computes the last // page automatically. The current page is clamped to [1, LastPage]. -func NewPage[T any](items []T, total int64, page, perPage int) Page[T] { +// +// page, perPage := request.Pagination(r) +// +// var users []User +// db.Select(ctx, "SELECT * FROM users LIMIT $1 OFFSET $2", &users, perPage, (page-1)*perPage) +// +// var total int64 +// db.Find(ctx, "SELECT COUNT(*) FROM users", &total) +// +// result := contract.Paginate(users, total, page, perPage) +func Paginate[T any](items []T, total int64, page, perPage int) Page[T] { perPage = max(perPage, 1) lastPage := max(int((total+int64(perPage)-1)/int64(perPage)), 1) page = min(max(page, 1), lastPage) @@ -50,12 +60,31 @@ func NewPage[T any](items []T, total int64, page, perPage int) Page[T] { } } -// NewCursor creates a new [Cursor] from the given items. The encode +// CursorPaginate creates a new [Cursor] from the given items. The encode // function determines how each item is transformed into an opaque // cursor string. When hasNext is true, the last item is encoded to // produce the next cursor. When hasPrev is true, the first item is // encoded to produce the previous cursor. -func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, encode func(T) (string, error)) (Cursor[T], error) { +// +// cursor, perPage := request.CursorPagination(r) +// +// var startID int64 +// if cursor != "" { +// startID, _ = contract.CursorDecode[int64](cursor) +// } +// +// var items []FeedItem +// db.Select(ctx, "SELECT * FROM feed WHERE id > $1 ORDER BY id LIMIT $2", &items, startID, perPage+1) +// +// hasNext := len(items) > perPage +// if hasNext { +// items = items[:perPage] +// } +// +// result, err := contract.CursorPaginate(items, perPage, hasNext, cursor != "", func(item FeedItem) (string, error) { +// return contract.CursorEncode(item.ID) +// }) +func CursorPaginate[T any](items []T, perPage int, hasNext, hasPrev bool, encode func(T) (string, error)) (Cursor[T], error) { if items == nil { items = []T{} } @@ -92,10 +121,12 @@ func NewCursor[T any](items []T, perPage int, hasNext, hasPrev bool, encode func return result, nil } -// MarshalCursor encodes a value into an opaque cursor string using +// CursorEncode encodes a value into an opaque cursor string using // JSON serialization and base64url encoding. Use this as the encoding -// helper inside the encode function passed to [NewCursor]. -func MarshalCursor[V any](value V) (string, error) { +// helper inside the encode function passed to [CursorPaginate]. +// +// encoded, err := contract.CursorEncode(user.ID) +func CursorEncode[V any](value V) (string, error) { data, err := json.Marshal(value) if err != nil { @@ -105,9 +136,11 @@ func MarshalCursor[V any](value V) (string, error) { return base64.RawURLEncoding.EncodeToString(data), nil } -// UnmarshalCursor decodes an opaque cursor string back into a typed -// value. It reverses the encoding performed by [MarshalCursor]. -func UnmarshalCursor[V any](cursor string) (V, error) { +// CursorDecode decodes an opaque cursor string back into a typed +// value. It reverses the encoding performed by [CursorEncode]. +// +// id, err := contract.CursorDecode[int64](cursorString) +func CursorDecode[V any](cursor string) (V, error) { var value V data, err := base64.RawURLEncoding.DecodeString(cursor) diff --git a/contract/pagination_example_test.go b/contract/pagination_example_test.go new file mode 100644 index 0000000..c3afe16 --- /dev/null +++ b/contract/pagination_example_test.go @@ -0,0 +1,81 @@ +package contract_test + +import ( + "fmt" + + "github.com/studiolambda/cosmos/contract" +) + +func ExamplePaginate() { + items := []string{"a", "b", "c"} + + page := contract.Paginate(items, 10, 2, 3) + + fmt.Println(page.CurrentPage) + fmt.Println(page.LastPage) + fmt.Println(page.PerPage) + fmt.Println(page.Total) + fmt.Println(page.Items) + // Output: + // 2 + // 4 + // 3 + // 10 + // [a b c] +} + +func ExampleCursorPaginate() { + type User struct { + ID int + Name string + } + + users := []User{{ID: 10, Name: "Alice"}, {ID: 20, Name: "Bob"}} + + cursor, err := contract.CursorPaginate(users, 2, true, false, func(u User) (string, error) { + return contract.CursorEncode(u.ID) + }) + + if err != nil { + panic(err) + } + + fmt.Println(cursor.PerPage) + fmt.Println(cursor.NextCursor != "") + fmt.Println(cursor.PrevCursor) + + id, err := contract.CursorDecode[int](cursor.NextCursor) + + if err != nil { + panic(err) + } + + fmt.Println(id) + // Output: + // 2 + // true + // + // 20 +} + +func ExampleCursorEncode() { + encoded, err := contract.CursorEncode(42) + + if err != nil { + panic(err) + } + + fmt.Println(encoded) + // Output: NDI +} + +func ExampleCursorDecode() { + value, err := contract.CursorDecode[int]("NDI") + + if err != nil { + panic(err) + } + + fmt.Println(value) + // Output: 42 +} diff --git a/contract/pagination_test.go b/contract/pagination_test.go index 9cf39ff..1df13f7 100644 --- a/contract/pagination_test.go +++ b/contract/pagination_test.go @@ -8,69 +8,69 @@ import ( "github.com/studiolambda/cosmos/contract" ) -func TestNewPageComputesLastPage(t *testing.T) { +func TestPaginateComputesLastPage(t *testing.T) { t.Parallel() - page := contract.NewPage([]string{"a", "b"}, 10, 1, 5) + page := contract.Paginate([]string{"a", "b"}, 10, 1, 5) require.Equal(t, 2, page.LastPage) } -func TestNewPageComputesLastPageWithRemainder(t *testing.T) { +func TestPaginateComputesLastPageWithRemainder(t *testing.T) { t.Parallel() - page := contract.NewPage([]string{"a", "b"}, 11, 1, 5) + page := contract.Paginate([]string{"a", "b"}, 11, 1, 5) require.Equal(t, 3, page.LastPage) } -func TestNewPageClampsPageBelowOne(t *testing.T) { +func TestPaginateClampsPageBelowOne(t *testing.T) { t.Parallel() - page := contract.NewPage([]string{"a"}, 10, 0, 5) + page := contract.Paginate([]string{"a"}, 10, 0, 5) require.Equal(t, 1, page.CurrentPage) } -func TestNewPageClampsPageAboveLastPage(t *testing.T) { +func TestPaginateClampsPageAboveLastPage(t *testing.T) { t.Parallel() - page := contract.NewPage([]string{}, 10, 99, 5) + page := contract.Paginate([]string{}, 10, 99, 5) require.Equal(t, 2, page.CurrentPage) } -func TestNewPageClampsPerPageBelowOne(t *testing.T) { +func TestPaginateClampsPerPageBelowOne(t *testing.T) { t.Parallel() - page := contract.NewPage([]string{"a"}, 5, 1, 0) + page := contract.Paginate([]string{"a"}, 5, 1, 0) require.Equal(t, 1, page.PerPage) } -func TestNewPageZeroTotalSetsLastPageOne(t *testing.T) { +func TestPaginateZeroTotalSetsLastPageOne(t *testing.T) { t.Parallel() - page := contract.NewPage([]string{}, 0, 1, 10) + page := contract.Paginate([]string{}, 0, 1, 10) require.Equal(t, 1, page.LastPage) require.Equal(t, 1, page.CurrentPage) } -func TestNewPageNilItemsBecomesEmptySlice(t *testing.T) { +func TestPaginateNilItemsBecomesEmptySlice(t *testing.T) { t.Parallel() - page := contract.NewPage[string](nil, 0, 1, 10) + page := contract.Paginate[string](nil, 0, 1, 10) require.NotNil(t, page.Items) require.Empty(t, page.Items) } -func TestNewPagePreservesItems(t *testing.T) { +func TestPaginatePreservesItems(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - page := contract.NewPage(items, 100, 3, 10) + page := contract.Paginate(items, 100, 3, 10) require.Equal(t, items, page.Items) require.Equal(t, int64(100), page.Total) @@ -78,12 +78,12 @@ func TestNewPagePreservesItems(t *testing.T) { require.Equal(t, 10, page.PerPage) } -func TestNewCursorEncodesNextCursor(t *testing.T) { +func TestCursorPaginateEncodesNextCursor(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, true, false, func(item int) (string, error) { - return contract.MarshalCursor(item) + cursor, err := contract.CursorPaginate(items, 3, true, false, func(item int) (string, error) { + return contract.CursorEncode(item) }) require.NoError(t, err) @@ -91,12 +91,12 @@ func TestNewCursorEncodesNextCursor(t *testing.T) { require.Empty(t, cursor.PrevCursor) } -func TestNewCursorEncodesPrevCursor(t *testing.T) { +func TestCursorPaginateEncodesPrevCursor(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, false, true, func(item int) (string, error) { - return contract.MarshalCursor(item) + cursor, err := contract.CursorPaginate(items, 3, false, true, func(item int) (string, error) { + return contract.CursorEncode(item) }) require.NoError(t, err) @@ -104,12 +104,12 @@ func TestNewCursorEncodesPrevCursor(t *testing.T) { require.NotEmpty(t, cursor.PrevCursor) } -func TestNewCursorEncodesBothCursors(t *testing.T) { +func TestCursorPaginateEncodesBothCursors(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, true, true, func(item int) (string, error) { - return contract.MarshalCursor(item) + cursor, err := contract.CursorPaginate(items, 3, true, true, func(item int) (string, error) { + return contract.CursorEncode(item) }) require.NoError(t, err) @@ -117,11 +117,11 @@ func TestNewCursorEncodesBothCursors(t *testing.T) { require.NotEmpty(t, cursor.PrevCursor) } -func TestNewCursorEmptyItemsNoCursors(t *testing.T) { +func TestCursorPaginateEmptyItemsNoCursors(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursor([]int{}, 10, true, true, func(item int) (string, error) { - return contract.MarshalCursor(item) + cursor, err := contract.CursorPaginate([]int{}, 10, true, true, func(item int) (string, error) { + return contract.CursorEncode(item) }) require.NoError(t, err) @@ -129,11 +129,11 @@ func TestNewCursorEmptyItemsNoCursors(t *testing.T) { require.Empty(t, cursor.PrevCursor) } -func TestNewCursorNilItemsBecomesEmptySlice(t *testing.T) { +func TestCursorPaginateNilItemsBecomesEmptySlice(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursor[int](nil, 10, false, false, func(item int) (string, error) { - return contract.MarshalCursor(item) + cursor, err := contract.CursorPaginate[int](nil, 10, false, false, func(item int) (string, error) { + return contract.CursorEncode(item) }) require.NoError(t, err) @@ -141,42 +141,42 @@ func TestNewCursorNilItemsBecomesEmptySlice(t *testing.T) { require.Empty(t, cursor.Items) } -func TestNewCursorPreservesPerPage(t *testing.T) { +func TestCursorPaginatePreservesPerPage(t *testing.T) { t.Parallel() - cursor, err := contract.NewCursor([]int{1}, 25, false, false, func(item int) (string, error) { - return contract.MarshalCursor(item) + cursor, err := contract.CursorPaginate([]int{1}, 25, false, false, func(item int) (string, error) { + return contract.CursorEncode(item) }) require.NoError(t, err) require.Equal(t, 25, cursor.PerPage) } -func TestNewCursorNextEncodeErrorReturnsErrCursorEncode(t *testing.T) { +func TestCursorPaginateNextEncodeErrorReturnsErrCursorEncode(t *testing.T) { t.Parallel() - _, err := contract.NewCursor([]int{1}, 10, true, false, func(item int) (string, error) { + _, err := contract.CursorPaginate([]int{1}, 10, true, false, func(item int) (string, error) { return "", errors.New("encode failed") }) require.ErrorIs(t, err, contract.ErrCursorEncode) } -func TestNewCursorPrevEncodeErrorReturnsErrCursorEncode(t *testing.T) { +func TestCursorPaginatePrevEncodeErrorReturnsErrCursorEncode(t *testing.T) { t.Parallel() - _, err := contract.NewCursor([]int{1}, 10, false, true, func(item int) (string, error) { + _, err := contract.CursorPaginate([]int{1}, 10, false, true, func(item int) (string, error) { return "", errors.New("encode failed") }) require.ErrorIs(t, err, contract.ErrCursorEncode) } -func TestNewCursorCustomEncoder(t *testing.T) { +func TestCursorPaginateCustomEncoder(t *testing.T) { t.Parallel() items := []int{1, 2, 3} - cursor, err := contract.NewCursor(items, 3, true, false, func(item int) (string, error) { + cursor, err := contract.CursorPaginate(items, 3, true, false, func(item int) (string, error) { return "custom-cursor", nil }) @@ -184,33 +184,33 @@ func TestNewCursorCustomEncoder(t *testing.T) { require.Equal(t, "custom-cursor", cursor.NextCursor) } -func TestMarshalUnmarshalCursorRoundTrip(t *testing.T) { +func TestCursorEncodeDecoderRoundTrip(t *testing.T) { t.Parallel() - encoded, err := contract.MarshalCursor(42) + encoded, err := contract.CursorEncode(42) require.NoError(t, err) - value, err := contract.UnmarshalCursor[int](encoded) + value, err := contract.CursorDecode[int](encoded) require.NoError(t, err) require.Equal(t, 42, value) } -func TestMarshalCursorStringRoundTrip(t *testing.T) { +func TestCursorEncodeStringRoundTrip(t *testing.T) { t.Parallel() - encoded, err := contract.MarshalCursor("hello") + encoded, err := contract.CursorEncode("hello") require.NoError(t, err) - value, err := contract.UnmarshalCursor[string](encoded) + value, err := contract.CursorDecode[string](encoded) require.NoError(t, err) require.Equal(t, "hello", value) } -func TestMarshalCursorStructRoundTrip(t *testing.T) { +func TestCursorEncodeStructRoundTrip(t *testing.T) { t.Parallel() type key struct { @@ -219,42 +219,42 @@ func TestMarshalCursorStructRoundTrip(t *testing.T) { } original := key{ID: 7, Name: "test"} - encoded, err := contract.MarshalCursor(original) + encoded, err := contract.CursorEncode(original) require.NoError(t, err) - value, err := contract.UnmarshalCursor[key](encoded) + value, err := contract.CursorDecode[key](encoded) require.NoError(t, err) require.Equal(t, original, value) } -func TestMarshalCursorUnencodableReturnsError(t *testing.T) { +func TestCursorEncodeUnencodableReturnsError(t *testing.T) { t.Parallel() - _, err := contract.MarshalCursor(make(chan int)) + _, err := contract.CursorEncode(make(chan int)) require.Error(t, err) } -func TestUnmarshalCursorInvalidBase64ReturnsErrCursorDecode(t *testing.T) { +func TestCursorDecodeInvalidBase64ReturnsErrCursorDecode(t *testing.T) { t.Parallel() - _, err := contract.UnmarshalCursor[int]("not-valid-base64!!!") + _, err := contract.CursorDecode[int]("not-valid-base64!!!") require.ErrorIs(t, err, contract.ErrCursorDecode) } -func TestUnmarshalCursorInvalidJSONReturnsErrCursorDecode(t *testing.T) { +func TestCursorDecodeInvalidJSONReturnsErrCursorDecode(t *testing.T) { t.Parallel() // Valid base64 but invalid JSON. - _, err := contract.UnmarshalCursor[int]("bm90LWpzb24") + _, err := contract.CursorDecode[int]("bm90LWpzb24") require.ErrorIs(t, err, contract.ErrCursorDecode) } -func TestNewCursorWithMarshalCursorEndToEnd(t *testing.T) { +func TestCursorPaginateWithMarshalCursorEndToEnd(t *testing.T) { t.Parallel() type User struct { @@ -264,13 +264,13 @@ func TestNewCursorWithMarshalCursorEndToEnd(t *testing.T) { users := []User{{ID: 10, Name: "Alice"}, {ID: 20, Name: "Bob"}} - cursor, err := contract.NewCursor(users, 2, true, false, func(u User) (string, error) { - return contract.MarshalCursor(u.ID) + cursor, err := contract.CursorPaginate(users, 2, true, false, func(u User) (string, error) { + return contract.CursorEncode(u.ID) }) require.NoError(t, err) - id, err := contract.UnmarshalCursor[int](cursor.NextCursor) + id, err := contract.CursorDecode[int](cursor.NextCursor) require.NoError(t, err) require.Equal(t, 20, id)