Skip to content

Commit 43a4ecb

Browse files
committed
feat: oapi codegen response validator
1 parent d5bc443 commit 43a4ecb

8 files changed

Lines changed: 232 additions & 2 deletions

File tree

api/v3/oasmiddleware/validator.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ func ValidateRequest(validationRouter routers.Router, opts ValidateRequestOption
6666
type ValidateResponseOption struct {
6767
// ResponseValidationErrorHook is called when the route response body is not validated
6868
ResponseValidationErrorHook ResponseValidationFunc
69+
// RouteFilterHook is called after the route is found; return false to skip validation for that route.
70+
// If nil, all matched routes are validated.
71+
RouteFilterHook func(*routers.Route) bool
6972
// FilterOptions are the openapi3filter option to pass to the underlying lib
7073
FilterOptions *openapi3filter.Options
7174
}
@@ -75,8 +78,6 @@ type ValidateResponseOption struct {
7578
func ValidateResponse(validationRouter routers.Router, opts ValidateResponseOption) func(h http.Handler) http.Handler {
7679
return func(h http.Handler) http.Handler {
7780
fn := func(w http.ResponseWriter, r *http.Request) {
78-
var err error
79-
8081
route, pathParams, err := validationRouter.FindRoute(r)
8182

8283
if err != nil {
@@ -85,6 +86,11 @@ func ValidateResponse(validationRouter routers.Router, opts ValidateResponseOpti
8586
opts.ResponseValidationErrorHook(err, r)
8687
}
8788
} else {
89+
if opts.RouteFilterHook != nil && !opts.RouteFilterHook(route) {
90+
h.ServeHTTP(w, r)
91+
return
92+
}
93+
8894
// need to wrap std lib response to access the body
8995
rww := NewResponseWriterWrapper(w)
9096

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package oasmiddleware_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/getkin/kin-openapi/routers"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
api "github.com/openmeterio/openmeter/api/v3"
13+
"github.com/openmeterio/openmeter/api/v3/oasmiddleware"
14+
)
15+
16+
// TestValidateResponse_Violation proves that the ValidateResponse middleware fires its
17+
// error hook when a handler returns a response body that violates the OpenAPI spec.
18+
//
19+
// GET /openmeter/addons requires a 200 body with both "data" and "meta" fields.
20+
// Returning {} omits both required fields and must trigger a violation.
21+
func TestValidateResponse_Violation(t *testing.T) {
22+
swagger, err := api.GetSwagger()
23+
require.NoError(t, err)
24+
25+
swagger.Servers = nil
26+
27+
router, err := oasmiddleware.NewValidationRouter(t.Context(), swagger, &oasmiddleware.ValidationRouterOpts{
28+
DeleteServers: true,
29+
})
30+
require.NoError(t, err)
31+
32+
var gotErr error
33+
mw := oasmiddleware.ValidateResponse(router, oasmiddleware.ValidateResponseOption{
34+
ResponseValidationErrorHook: func(err error, r *http.Request) {
35+
gotErr = err
36+
},
37+
})
38+
39+
badHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
40+
w.Header().Set("Content-Type", "application/json")
41+
w.WriteHeader(http.StatusOK)
42+
_, _ = w.Write([]byte(`{}`))
43+
})
44+
45+
req := httptest.NewRequest(http.MethodGet, "/openmeter/addons", nil)
46+
rec := httptest.NewRecorder()
47+
48+
mw(badHandler).ServeHTTP(rec, req)
49+
50+
assert.Error(t, gotErr, "expected a validation error for missing required fields")
51+
t.Logf("validation error (expected): %v", gotErr)
52+
}
53+
54+
// TestValidateResponse_Clean proves that a well-formed response does not trigger the error hook.
55+
func TestValidateResponse_Clean(t *testing.T) {
56+
swagger, err := api.GetSwagger()
57+
require.NoError(t, err)
58+
59+
swagger.Servers = nil
60+
61+
router, err := oasmiddleware.NewValidationRouter(t.Context(), swagger, &oasmiddleware.ValidationRouterOpts{
62+
DeleteServers: true,
63+
})
64+
require.NoError(t, err)
65+
66+
var gotErr error
67+
mw := oasmiddleware.ValidateResponse(router, oasmiddleware.ValidateResponseOption{
68+
ResponseValidationErrorHook: func(err error, r *http.Request) {
69+
gotErr = err
70+
},
71+
})
72+
73+
goodHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
74+
w.Header().Set("Content-Type", "application/json")
75+
w.WriteHeader(http.StatusOK)
76+
_, _ = w.Write([]byte(`{"data":[],"meta":{"page":{"number":0,"size":100,"total":0}}}`))
77+
})
78+
79+
req := httptest.NewRequest(http.MethodGet, "/openmeter/addons", nil)
80+
rec := httptest.NewRecorder()
81+
82+
mw(goodHandler).ServeHTTP(rec, req)
83+
84+
assert.NoError(t, gotErr, "expected no validation error for a well-formed response")
85+
}
86+
87+
// TestValidateResponse_RouteFilterSkipsValidation proves that when RouteFilterHook returns false,
88+
// the response body is neither buffered nor validated — even if it would otherwise violate the spec.
89+
// The filter is the per-route gate that lets callers (e.g. unstable-only mode) avoid the
90+
// buffering overhead on routes they don't care about.
91+
func TestValidateResponse_RouteFilterSkipsValidation(t *testing.T) {
92+
swagger, err := api.GetSwagger()
93+
require.NoError(t, err)
94+
95+
swagger.Servers = nil
96+
97+
router, err := oasmiddleware.NewValidationRouter(t.Context(), swagger, &oasmiddleware.ValidationRouterOpts{
98+
DeleteServers: true,
99+
})
100+
require.NoError(t, err)
101+
102+
var (
103+
gotErr error
104+
filteredRoute *routers.Route
105+
)
106+
mw := oasmiddleware.ValidateResponse(router, oasmiddleware.ValidateResponseOption{
107+
RouteFilterHook: func(route *routers.Route) bool {
108+
filteredRoute = route
109+
return false
110+
},
111+
ResponseValidationErrorHook: func(err error, r *http.Request) {
112+
gotErr = err
113+
},
114+
})
115+
116+
// Body that would fail validation if the filter let it through.
117+
badHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118+
w.Header().Set("Content-Type", "application/json")
119+
w.WriteHeader(http.StatusOK)
120+
_, _ = w.Write([]byte(`{}`))
121+
})
122+
123+
req := httptest.NewRequest(http.MethodGet, "/openmeter/addons", nil)
124+
rec := httptest.NewRecorder()
125+
126+
mw(badHandler).ServeHTTP(rec, req)
127+
128+
require.NotNil(t, filteredRoute, "RouteFilterHook should have been invoked with the matched route")
129+
assert.Equal(t, "/openmeter/addons", filteredRoute.Path)
130+
assert.NoError(t, gotErr, "validation error hook must not fire when the filter returns false")
131+
assert.Equal(t, http.StatusOK, rec.Code, "client response should still be served")
132+
assert.Equal(t, `{}`, rec.Body.String(), "client response body should still be served")
133+
}

api/v3/server/server.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package server
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
78
"log/slog"
89
"net/http"
910

1011
"github.com/getkin/kin-openapi/openapi3"
1112
"github.com/getkin/kin-openapi/openapi3filter"
13+
"github.com/getkin/kin-openapi/routers"
1214
"github.com/go-chi/chi/v5"
1315
"github.com/samber/lo"
1416

@@ -72,6 +74,7 @@ type Config struct {
7274
Middlewares []server.MiddlewareFunc
7375
PostAuthMiddlewares []server.MiddlewareFunc
7476
Credits config.CreditsConfiguration
77+
ResponseValidation config.ResponseValidationConfig
7578

7679
// services
7780
AddonService addon.Service
@@ -384,6 +387,15 @@ func (s *Server) RegisterRoutes(r chi.Router) error {
384387
validationMiddleware,
385388
}
386389

390+
if s.ResponseValidation.Mode.Enabled() {
391+
middlewares = append(middlewares, oasmiddleware.ValidateResponse(validationRouter, oasmiddleware.ValidateResponseOption{
392+
RouteFilterHook: buildResponseValidationRouteFilter(s.ResponseValidation),
393+
ResponseValidationErrorHook: func(err error, r *http.Request) {
394+
slog.WarnContext(r.Context(), "response validation failed", slog.String("method", r.Method), slog.String("path", r.URL.Path), slog.Any("error", err))
395+
},
396+
}))
397+
}
398+
387399
postAuthMiddlewares := lo.Map(s.PostAuthMiddlewares, func(mwf server.MiddlewareFunc, _ int) api.MiddlewareFunc {
388400
return api.MiddlewareFunc(mwf)
389401
})
@@ -399,3 +411,30 @@ func (s *Server) RegisterRoutes(r chi.Router) error {
399411

400412
return nil
401413
}
414+
415+
// buildResponseValidationRouteFilter returns a route filter for response validation.
416+
// In "all" mode the filter is nil (every route is validated). In "unstable" mode only
417+
// operations marked x-unstable: true in the spec are validated.
418+
func buildResponseValidationRouteFilter(cfg config.ResponseValidationConfig) func(*routers.Route) bool {
419+
if cfg.Mode != config.ResponseValidationModeUnstable {
420+
return nil
421+
}
422+
return func(route *routers.Route) bool {
423+
if route.Operation == nil {
424+
return false
425+
}
426+
extVal, ok := route.Operation.Extensions["x-unstable"]
427+
if !ok {
428+
return false
429+
}
430+
switch v := extVal.(type) {
431+
case json.RawMessage:
432+
var b bool
433+
return json.Unmarshal(v, &b) == nil && b
434+
case bool:
435+
return v
436+
default:
437+
return false
438+
}
439+
}
440+
}

app/config/config_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ func TestComplete(t *testing.T) {
418418
ReadTimeout: 60 * time.Second,
419419
WriteTimeout: 90 * time.Second,
420420
IdleTimeout: 120 * time.Second,
421+
ResponseValidation: ResponseValidationConfig{
422+
Mode: ResponseValidationModeOff,
423+
},
421424
},
422425
ProgressManager: ProgressManagerConfiguration{
423426
Enabled: false,

app/config/server.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,37 @@ type ServerConfig struct {
1313
ReadTimeout time.Duration
1414
WriteTimeout time.Duration
1515
IdleTimeout time.Duration
16+
17+
ResponseValidation ResponseValidationConfig
18+
}
19+
20+
// ResponseValidationConfig controls optional post-response OpenAPI validation on the v3 API.
21+
type ResponseValidationConfig struct {
22+
Mode ResponseValidationMode
23+
}
24+
25+
type ResponseValidationMode string
26+
27+
const (
28+
// ResponseValidationModeOff disables response validation. This is the default.
29+
ResponseValidationModeOff ResponseValidationMode = "off"
30+
// ResponseValidationModeUnstable validates only routes marked x-unstable: true in the spec.
31+
ResponseValidationModeUnstable ResponseValidationMode = "unstable"
32+
// ResponseValidationModeAll validates every route in the v3 spec.
33+
ResponseValidationModeAll ResponseValidationMode = "all"
34+
)
35+
36+
func (m ResponseValidationMode) Enabled() bool {
37+
return m != "" && m != ResponseValidationModeOff
38+
}
39+
40+
func (m ResponseValidationMode) Validate() error {
41+
switch m {
42+
case "", ResponseValidationModeOff, ResponseValidationModeUnstable, ResponseValidationModeAll:
43+
return nil
44+
default:
45+
return errors.New("invalid response validation mode (allowed: off, unstable, all)")
46+
}
1647
}
1748

1849
func (c ServerConfig) Validate() error {
@@ -34,6 +65,10 @@ func (c ServerConfig) Validate() error {
3465
errs = append(errs, errors.New("idleTimeout must be non-negative"))
3566
}
3667

68+
if err := c.ResponseValidation.Mode.Validate(); err != nil {
69+
errs = append(errs, err)
70+
}
71+
3772
return errors.Join(errs...)
3873
}
3974

@@ -45,4 +80,6 @@ func ConfigureServer(v *viper.Viper, prefixes ...string) {
4580
v.SetDefault(prefixer("readTimeout"), 60*time.Second)
4681
v.SetDefault(prefixer("writeTimeout"), 90*time.Second)
4782
v.SetDefault(prefixer("idleTimeout"), 120*time.Second)
83+
84+
v.SetDefault(prefixer("responseValidation.mode"), string(ResponseValidationModeOff))
4885
}

cmd/server/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ func main() {
192192
},
193193
RouterHooks: lo.FromPtr(app.RouterHooks),
194194
PostAuthMiddlewares: app.PostAuthMiddlewares,
195+
ResponseValidation: conf.Server.ResponseValidation,
195196
})
196197
if err != nil {
197198
logger.Error("failed to create server", "error", err)

config.example.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ server:
3333
writeTimeout: 90s
3434
# idleTimeout is the maximum time a keep-alive connection remains idle before being closed.
3535
idleTimeout: 120s
36+
responseValidation:
37+
# mode selects which v3 routes have their response bodies validated against the OpenAPI spec.
38+
# Violations are logged as warnings; the client response is never affected.
39+
# Validation buffers the full response body in memory.
40+
# off — validation disabled (default)
41+
# unstable — validate only routes marked x-unstable: true in the spec
42+
# all — validate every v3 route (highest overhead)
43+
mode: "off"
3644

3745
#ingest:
3846
# kafka:

openmeter/server/server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
"github.com/openmeterio/openmeter/api"
1919
v3server "github.com/openmeterio/openmeter/api/v3/server"
20+
appconfig "github.com/openmeterio/openmeter/app/config"
2021
"github.com/openmeterio/openmeter/openmeter/namespace/namespacedriver"
2122
"github.com/openmeterio/openmeter/openmeter/portal/authenticator"
2223
"github.com/openmeterio/openmeter/openmeter/server/router"
@@ -67,6 +68,7 @@ type Config struct {
6768
RouterConfig router.Config
6869
RouterHooks RouterHooks
6970
PostAuthMiddlewares PostAuthMiddlewares
71+
ResponseValidation appconfig.ResponseValidationConfig
7072
}
7173

7274
func NewServer(config *Config) (*Server, error) {
@@ -136,6 +138,7 @@ func NewServer(config *Config) (*Server, error) {
136138
FeatureConnector: config.RouterConfig.FeatureConnector,
137139
Middlewares: v3Middlewares,
138140
PostAuthMiddlewares: config.PostAuthMiddlewares,
141+
ResponseValidation: config.ResponseValidation,
139142
})
140143
if err != nil {
141144
slog.Error("failed to create v3 API", "error", err)

0 commit comments

Comments
 (0)