diff --git a/pkg/api/grpc_http_gateway.go b/pkg/api/grpc_http_gateway.go index b7b5894..b3a3866 100644 --- a/pkg/api/grpc_http_gateway.go +++ b/pkg/api/grpc_http_gateway.go @@ -16,7 +16,13 @@ import ( "google.golang.org/grpc/credentials/insecure" ) -func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFunc, oauthHandlers map[string]http.HandlerFunc) (http.Handler, error) { +type HTTPRoute struct { + Method string + Path string + Handler http.HandlerFunc +} + +func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFunc, routes []HTTPRoute) (http.Handler, error) { mux := runtime.NewServeMux() var opts []grpc.DialOption @@ -53,9 +59,8 @@ func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFu if metricsHandler != nil { handleGET(mux, "/metrics", metricsHandler) } - // Register fosite oauth endpoints - for path, h := range oauthHandlers { - handlePOST(mux, path, h) + for _, route := range routes { + handleRoute(mux, route) } if conf.ServeDebug { @@ -88,19 +93,14 @@ func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFu } func handleGET(mux *runtime.ServeMux, path string, handler http.HandlerFunc) { - err := mux.HandlePath("GET", path, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { - handler(w, r) - }) - if err != nil { - panic(fmt.Errorf("%w: unable to register http handler %s", err, path)) - } + handleRoute(mux, HTTPRoute{Method: http.MethodGet, Path: path, Handler: handler}) } -func handlePOST(mux *runtime.ServeMux, path string, handler http.HandlerFunc) { - err := mux.HandlePath("POST", path, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { - handler(w, r) +func handleRoute(mux *runtime.ServeMux, route HTTPRoute) { + err := mux.HandlePath(route.Method, route.Path, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + route.Handler(w, r) }) if err != nil { - panic(fmt.Errorf("%w: unable to register http handler %s", err, path)) + panic(fmt.Errorf("%w: unable to register %s http handler %s", err, route.Method, route.Path)) } } diff --git a/pkg/app/start.go b/pkg/app/start.go index 0c219a0..f1abde9 100644 --- a/pkg/app/start.go +++ b/pkg/app/start.go @@ -120,13 +120,15 @@ func Start(ctx context.Context, conf config.Config, build config.Build) error { if conf.Metrics.Enabled { metricsHandler = promhttp.Handler().ServeHTTP } - oauthHandlers := map[string]http.HandlerFunc{ - "/api/oauth/token": authServer.TokenEndpoint, - "/api/oauth/auth": authServer.AuthEndpoint, - "/api/oauth/revoke": authServer.RevokeEndpoint, - "/api/oauth/introspect": authServer.IntrospectionEndpoint, + httpRoutes := []api.HTTPRoute{ + {Method: http.MethodPost, Path: "/api/oauth/token", Handler: authServer.TokenEndpoint}, + {Method: http.MethodPost, Path: "/api/oauth/auth", Handler: authServer.AuthEndpoint}, + {Method: http.MethodPost, Path: "/api/oauth/revoke", Handler: authServer.RevokeEndpoint}, + {Method: http.MethodPost, Path: "/api/oauth/introspect", Handler: authServer.IntrospectionEndpoint}, + {Method: http.MethodGet, Path: "/.well-known/ceph-api", Handler: authServer.DiscoveryEndpoint}, + {Method: http.MethodGet, Path: "/.well-known/ceph-api/jwks.json", Handler: authServer.JWKSEndpoint}, } - httpServer, err := api.GRPCGateway(ctx, conf.Api, metricsHandler, oauthHandlers) + httpServer, err := api.GRPCGateway(ctx, conf.Api, metricsHandler, httpRoutes) if err != nil { return err } diff --git a/pkg/auth/discovery_handler.go b/pkg/auth/discovery_handler.go new file mode 100644 index 0000000..781f322 --- /dev/null +++ b/pkg/auth/discovery_handler.go @@ -0,0 +1,75 @@ +package auth + +import ( + "encoding/base64" + "encoding/json" + "math/big" + "net/http" +) + +const ( + tokenEndpoint = "/api/oauth/token" + revokeEndpoint = "/api/oauth/revoke" + jwksURI = "/.well-known/ceph-api/jwks.json" +) + +type discoveryDocument struct { + ClusterID string `json:"cluster_id"` + ClusterName string `json:"cluster_name"` + Auth discoveryAuth `json:"auth"` + JWKSURI string `json:"jwks_uri"` +} + +type discoveryAuth struct { + Issuer string `json:"issuer"` + Audience string `json:"audience"` + TokenEndpoint string `json:"token_endpoint"` + RevokeEndpoint string `json:"revoke_endpoint"` + Modes []string `json:"modes"` +} + +type jwksDocument struct { + Keys []jwk `json:"keys"` +} + +type jwk struct { + Kty string `json:"kty"` + Alg string `json:"alg"` + Use string `json:"use"` + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` +} + +func (s *Server) DiscoveryEndpoint(w http.ResponseWriter, r *http.Request) { + writeJSON(w, discoveryDocument{ + Auth: discoveryAuth{ + Issuer: s.issuer, + Audience: s.clientID, + TokenEndpoint: tokenEndpoint, + RevokeEndpoint: revokeEndpoint, + Modes: []string{"password"}, + }, + JWKSURI: jwksURI, + }) +} + +func (s *Server) JWKSEndpoint(w http.ResponseWriter, r *http.Request) { + pub := s.publicKey + w.Header().Set("Cache-Control", "max-age=300") + writeJSON(w, jwksDocument{Keys: []jwk{{ + Kty: "RSA", + Alg: "RS256", + Use: "sig", + Kid: s.keyID, + N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(pub.E)).Bytes()), + }}}) +} + +func writeJSON(w http.ResponseWriter, value any) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(value); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/pkg/auth/discovery_handler_test.go b/pkg/auth/discovery_handler_test.go new file mode 100644 index 0000000..c19939c --- /dev/null +++ b/pkg/auth/discovery_handler_test.go @@ -0,0 +1,101 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDiscoveryEndpoint(t *testing.T) { + server := newTestServer(t) + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/ceph-api", nil) + + server.DiscoveryEndpoint(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK) + } + if got, want := recorder.Header().Get("Content-Type"), "application/json"; got != want { + t.Fatalf("Content-Type = %q, want %q", got, want) + } + + var doc discoveryDocument + if err := json.Unmarshal(recorder.Body.Bytes(), &doc); err != nil { + t.Fatalf("decode response: %v", err) + } + if doc.Auth.Issuer != "http://issuer.example" { + t.Fatalf("issuer = %q", doc.Auth.Issuer) + } + if doc.Auth.Audience != "ceph-api" { + t.Fatalf("audience = %q", doc.Auth.Audience) + } + if doc.Auth.TokenEndpoint != tokenEndpoint { + t.Fatalf("token endpoint = %q", doc.Auth.TokenEndpoint) + } + if doc.Auth.RevokeEndpoint != revokeEndpoint { + t.Fatalf("revoke endpoint = %q", doc.Auth.RevokeEndpoint) + } + if doc.JWKSURI != jwksURI { + t.Fatalf("jwks uri = %q", doc.JWKSURI) + } + if len(doc.Auth.Modes) != 1 || doc.Auth.Modes[0] != "password" { + t.Fatalf("modes = %v", doc.Auth.Modes) + } +} + +func TestJWKSEndpoint(t *testing.T) { + server := newTestServer(t) + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/ceph-api/jwks.json", nil) + + server.JWKSEndpoint(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK) + } + if got, want := recorder.Header().Get("Cache-Control"), "max-age=300"; got != want { + t.Fatalf("Cache-Control = %q, want %q", got, want) + } + if got, want := recorder.Header().Get("Content-Type"), "application/json"; got != want { + t.Fatalf("Content-Type = %q, want %q", got, want) + } + + var doc jwksDocument + if err := json.Unmarshal(recorder.Body.Bytes(), &doc); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(doc.Keys) != 1 { + t.Fatalf("keys len = %d, want 1", len(doc.Keys)) + } + key := doc.Keys[0] + if key.Kty != "RSA" || key.Alg != "RS256" || key.Use != "sig" { + t.Fatalf("unexpected key metadata: %+v", key) + } + if key.Kid != server.keyID { + t.Fatalf("kid = %q, want %q", key.Kid, server.keyID) + } + if key.N == "" || key.E == "" { + t.Fatalf("empty key material: %+v", key) + } +} + +func newTestServer(t *testing.T) *Server { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + kid, err := computeKID(&priv.PublicKey) + if err != nil { + t.Fatalf("compute kid: %v", err) + } + server, err := NewServer(Config{ClientID: "ceph-api", Issuer: "http://issuer.example"}, nil, priv, kid) + if err != nil { + t.Fatalf("NewServer() error = %v", err) + } + return server +}