Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/api/handler_balances_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ import (
)

func (m *MainHandler) listBalancesHandler(w http.ResponseWriter, r *http.Request) {
query := readPaginatedRequest(r, func(r *http.Request) wallet.ListBalances {
query, err := readPaginatedRequest(r, func(r *http.Request) wallet.ListBalances {
return wallet.ListBalances{
WalletID: chi.URLParam(r, "walletID"),
Metadata: getQueryMap(r.URL.Query(), "metadata"),
}
})
if err != nil {
badRequest(w, ErrorCodeValidation, err)
return
}

holds, err := m.manager.ListBalances(r.Context(), query)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/api/handler_holds_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ import (
)

func (m *MainHandler) listHoldsHandler(w http.ResponseWriter, r *http.Request) {
query := readPaginatedRequest(r, func(r *http.Request) wallet.ListHolds {
query, err := readPaginatedRequest(r, func(r *http.Request) wallet.ListHolds {
return wallet.ListHolds{
WalletID: r.URL.Query().Get("walletID"),
Metadata: getQueryMap(r.URL.Query(), "metadata"),
}
})
if err != nil {
badRequest(w, ErrorCodeValidation, err)
return
}

holds, err := m.manager.ListHolds(r.Context(), query)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/api/handler_transactions_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ import (
)

func (m *MainHandler) listTransactions(w http.ResponseWriter, r *http.Request) {
query := readPaginatedRequest[wallet.ListTransactions](r, func(r *http.Request) wallet.ListTransactions {
query, err := readPaginatedRequest[wallet.ListTransactions](r, func(r *http.Request) wallet.ListTransactions {
return wallet.ListTransactions{
WalletID: r.URL.Query().Get("walletID"),
}
})
if err != nil {
badRequest(w, ErrorCodeValidation, err)
return
}
transactions, err := m.manager.ListTransactions(r.Context(), query)
if err != nil {
internalError(w, r, err)
Expand Down
6 changes: 5 additions & 1 deletion pkg/api/handler_wallets_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ import (
)

func (m *MainHandler) listWalletsHandler(w http.ResponseWriter, r *http.Request) {
query := readPaginatedRequest[wallet.ListWallets](r, func(r *http.Request) wallet.ListWallets {
query, err := readPaginatedRequest[wallet.ListWallets](r, func(r *http.Request) wallet.ListWallets {
return wallet.ListWallets{
Metadata: getQueryMap(r.URL.Query(), "metadata"),
Name: r.URL.Query().Get("name"),
ExpandBalances: r.URL.Query().Get("expand") == "balances",
}
})
if err != nil {
badRequest(w, ErrorCodeValidation, err)
return
}
response, err := m.manager.ListWallets(r.Context(), query)
if err != nil {
internalError(w, r, err)
Expand Down
46 changes: 46 additions & 0 deletions pkg/api/router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package api

import (
"bytes"
"encoding/json"
"io"
"net/http"

"github.com/ThreeDotsLabs/watermill/message"
Expand All @@ -16,6 +19,47 @@ import (
"github.com/go-chi/chi/v5/middleware"
)

// maxRequestBodyBytes caps the size of request bodies the service will read.
// The JSON payloads handled here are small; this protects the service (and the
// audit middleware, which buffers the whole body) from memory-exhaustion DoS.
const maxRequestBodyBytes = 1 << 20 // 1 MiB

// limitRequestBody reads the request body up to maxRequestBodyBytes and rejects
// anything larger with 413. It buffers the (bounded) body and resets r.Body so
// the audit middleware and the handlers can still read it.
//
// We do this instead of http.MaxBytesReader because the audit middleware reads
// the body with io.ReadAll *before* the handler and turns any non-EOF error
// (including MaxBytesReader's overflow error) into a 500 — so an oversized body
// would surface as "500 failed to read request body" instead of 413.
func limitRequestBody(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Body != nil && r.Body != http.NoBody {
buf, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodyBytes+1))
_ = r.Body.Close()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(sharedapi.ErrorResponse{
ErrorCode: ErrorCodeValidation,
ErrorMessage: "failed to read request body",
})
return
}
if int64(len(buf)) > maxRequestBodyBytes {
w.WriteHeader(http.StatusRequestEntityTooLarge)
_ = json.NewEncoder(w).Encode(sharedapi.ErrorResponse{
ErrorCode: "REQUEST_TOO_LARGE",
ErrorMessage: "request body too large",
})
return
}
r.Body = io.NopCloser(bytes.NewReader(buf))
r.ContentLength = int64(len(buf))
}
handler.ServeHTTP(w, r)
})
}

func NewRouter(
manager *wallet.Manager,
healthController *sharedhealth.HealthController,
Expand All @@ -25,12 +69,14 @@ func NewRouter(
) *chi.Mux {
r := chi.NewRouter()

r.Use(middleware.Recoverer)
r.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
handler.ServeHTTP(w, r)
})
})
r.Use(limitRequestBody)
r.Use(httpaudit.Middleware(publisher, "audit-events", "wallets", nil))

r.Get("/_healthcheck", healthController.Check)
Expand Down
30 changes: 22 additions & 8 deletions pkg/api/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
Expand All @@ -14,7 +15,10 @@ import (
wallet "github.com/formancehq/wallets/pkg"
)

const defaultLimit = 15
const (
defaultLimit = 15
maxPageSize = 100
)

func notFound(w http.ResponseWriter) {
w.WriteHeader(http.StatusNotFound)
Expand Down Expand Up @@ -81,31 +85,41 @@ func parsePaginationToken(r *http.Request) string {
return r.URL.Query().Get("cursor")
}

func parsePageSize(r *http.Request) int {
func parsePageSize(r *http.Request) (int, error) {
pageSize := r.URL.Query().Get("pageSize")
if pageSize == "" {
return defaultLimit
return defaultLimit, nil
}

v, err := strconv.ParseInt(pageSize, 10, 32)
if err != nil {
panic(err)
return 0, fmt.Errorf("invalid pageSize: %w", err)
}
if v < 1 {
return 0, fmt.Errorf("pageSize must be a positive integer")
}
if v > maxPageSize {
v = maxPageSize
}
return int(v)
return int(v), nil
}

func readPaginatedRequest[T any](r *http.Request, f func(r *http.Request) T) wallet.ListQuery[T] {
func readPaginatedRequest[T any](r *http.Request, f func(r *http.Request) T) (wallet.ListQuery[T], error) {
pageSize, err := parsePageSize(r)
if err != nil {
return wallet.ListQuery[T]{}, err
}
var payload T
if f != nil {
payload = f(r)
}
return wallet.ListQuery[T]{
Pagination: wallet.Pagination{
Limit: parsePageSize(r),
Limit: pageSize,
PaginationToken: parsePaginationToken(r),
},
Payload: payload,
}
}, nil
}

func getQueryMap(m map[string][]string, key string) map[string]string {
Expand Down
68 changes: 68 additions & 0 deletions pkg/api/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,74 @@ import (
"github.com/stretchr/testify/require"
)

func TestParsePageSize(t *testing.T) {
t.Parallel()

for _, tc := range []struct {
query string
expected int
expectErr bool
}{
{query: "", expected: defaultLimit},
{query: "pageSize=20", expected: 20},
{query: "pageSize=100", expected: maxPageSize},
{query: "pageSize=100000", expected: maxPageSize},
{query: "pageSize=abc", expectErr: true},
{query: "pageSize=0", expectErr: true},
{query: "pageSize=-5", expectErr: true},
} {
tc := tc
t.Run(tc.query, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "/?"+tc.query, nil)
v, err := parsePageSize(req)
if tc.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.expected, v)
})
}
}

func TestListHandlerRejectsInvalidPageSize(t *testing.T) {
t.Parallel()

testEnv := newTestEnv(
WithListAccounts(func(ctx context.Context, ledger string, query wallet.ListAccountsQuery) (*wallet.AccountsCursorResponseCursor, error) {
return &wallet.AccountsCursorResponseCursor{}, nil
}),
)
req := httptest.NewRequest(http.MethodGet, "/wallets?pageSize=abc", nil)
rec := httptest.NewRecorder()
testEnv.Router().ServeHTTP(rec, req)

require.Equal(t, http.StatusBadRequest, rec.Result().StatusCode)
require.Equal(t, ErrorCodeValidation, readErrorResponse(t, rec).ErrorCode)
}

func TestRequestBodyTooLarge(t *testing.T) {
t.Parallel()

testEnv := newTestEnv(
WithAddMetadataToAccount(func(ctx context.Context, ledger, account, ik string, metadata map[string]string) error {
return nil
}),
)

// A body larger than maxRequestBodyBytes must be rejected with 413 before
// the audit middleware reads it — not surface as a 500.
body := bytes.NewReader(make([]byte, (1<<20)+1024))
req := httptest.NewRequest(http.MethodPost, "/wallets", body)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
testEnv.Router().ServeHTTP(rec, req)

require.Equal(t, http.StatusRequestEntityTooLarge, rec.Result().StatusCode)
require.Equal(t, "REQUEST_TOO_LARGE", readErrorResponse(t, rec).ErrorCode)
}

func readErrorResponse(t *testing.T, rec *httptest.ResponseRecorder) *sharedapi.ErrorResponse {
t.Helper()
ret := &sharedapi.ErrorResponse{}
Expand Down
Loading