Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions pkg/api/grpc_http_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ import (
"google.golang.org/grpc/credentials/insecure"
)

func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFunc, oauthHandlers map[string]http.HandlerFunc) (http.Handler, error) {
type HTTPRoute struct {
Method string
Path string
Handler http.HandlerFunc
}

func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFunc, routes []HTTPRoute) (http.Handler, error) {
mux := runtime.NewServeMux()
var opts []grpc.DialOption

Expand Down Expand Up @@ -53,9 +59,8 @@ func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFu
if metricsHandler != nil {
handleGET(mux, "/metrics", metricsHandler)
}
// Register fosite oauth endpoints
for path, h := range oauthHandlers {
handlePOST(mux, path, h)
for _, route := range routes {
handleRoute(mux, route)
}

if conf.ServeDebug {
Expand Down Expand Up @@ -88,19 +93,14 @@ func GRPCGateway(ctx context.Context, conf Config, metricsHandler http.HandlerFu
}

func handleGET(mux *runtime.ServeMux, path string, handler http.HandlerFunc) {
err := mux.HandlePath("GET", path, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
handler(w, r)
})
if err != nil {
panic(fmt.Errorf("%w: unable to register http handler %s", err, path))
}
handleRoute(mux, HTTPRoute{Method: http.MethodGet, Path: path, Handler: handler})
}

func handlePOST(mux *runtime.ServeMux, path string, handler http.HandlerFunc) {
err := mux.HandlePath("POST", path, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
handler(w, r)
func handleRoute(mux *runtime.ServeMux, route HTTPRoute) {
err := mux.HandlePath(route.Method, route.Path, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
route.Handler(w, r)
})
if err != nil {
panic(fmt.Errorf("%w: unable to register http handler %s", err, path))
panic(fmt.Errorf("%w: unable to register %s http handler %s", err, route.Method, route.Path))
}
}
14 changes: 8 additions & 6 deletions pkg/app/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ func Start(ctx context.Context, conf config.Config, build config.Build) error {
if conf.Metrics.Enabled {
metricsHandler = promhttp.Handler().ServeHTTP
}
oauthHandlers := map[string]http.HandlerFunc{
"/api/oauth/token": authServer.TokenEndpoint,
"/api/oauth/auth": authServer.AuthEndpoint,
"/api/oauth/revoke": authServer.RevokeEndpoint,
"/api/oauth/introspect": authServer.IntrospectionEndpoint,
httpRoutes := []api.HTTPRoute{
{Method: http.MethodPost, Path: "/api/oauth/token", Handler: authServer.TokenEndpoint},
{Method: http.MethodPost, Path: "/api/oauth/auth", Handler: authServer.AuthEndpoint},
{Method: http.MethodPost, Path: "/api/oauth/revoke", Handler: authServer.RevokeEndpoint},
{Method: http.MethodPost, Path: "/api/oauth/introspect", Handler: authServer.IntrospectionEndpoint},
{Method: http.MethodGet, Path: "/.well-known/ceph-api", Handler: authServer.DiscoveryEndpoint},
{Method: http.MethodGet, Path: "/.well-known/ceph-api/jwks.json", Handler: authServer.JWKSEndpoint},
}
httpServer, err := api.GRPCGateway(ctx, conf.Api, metricsHandler, oauthHandlers)
httpServer, err := api.GRPCGateway(ctx, conf.Api, metricsHandler, httpRoutes)
if err != nil {
return err
}
Expand Down
75 changes: 75 additions & 0 deletions pkg/auth/discovery_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package auth

import (
"encoding/base64"
"encoding/json"
"math/big"
"net/http"
)

const (
tokenEndpoint = "/api/oauth/token"
revokeEndpoint = "/api/oauth/revoke"
jwksURI = "/.well-known/ceph-api/jwks.json"
)

type discoveryDocument struct {
ClusterID string `json:"cluster_id"`
ClusterName string `json:"cluster_name"`
Auth discoveryAuth `json:"auth"`
JWKSURI string `json:"jwks_uri"`
}

type discoveryAuth struct {
Issuer string `json:"issuer"`
Audience string `json:"audience"`
TokenEndpoint string `json:"token_endpoint"`
RevokeEndpoint string `json:"revoke_endpoint"`
Modes []string `json:"modes"`
}

type jwksDocument struct {
Keys []jwk `json:"keys"`
}

type jwk struct {
Kty string `json:"kty"`
Alg string `json:"alg"`
Use string `json:"use"`
Kid string `json:"kid"`
N string `json:"n"`
E string `json:"e"`
}

func (s *Server) DiscoveryEndpoint(w http.ResponseWriter, r *http.Request) {
writeJSON(w, discoveryDocument{
Auth: discoveryAuth{
Issuer: s.issuer,
Audience: s.clientID,
TokenEndpoint: tokenEndpoint,
RevokeEndpoint: revokeEndpoint,
Modes: []string{"password"},
},
JWKSURI: jwksURI,
})
}

func (s *Server) JWKSEndpoint(w http.ResponseWriter, r *http.Request) {
pub := s.publicKey
w.Header().Set("Cache-Control", "max-age=300")
writeJSON(w, jwksDocument{Keys: []jwk{{
Kty: "RSA",
Alg: "RS256",
Use: "sig",
Kid: s.keyID,
N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(pub.E)).Bytes()),
}}})
}

func writeJSON(w http.ResponseWriter, value any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(value); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
101 changes: 101 additions & 0 deletions pkg/auth/discovery_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package auth

import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)

func TestDiscoveryEndpoint(t *testing.T) {
server := newTestServer(t)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/.well-known/ceph-api", nil)

server.DiscoveryEndpoint(recorder, req)

if recorder.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
}
if got, want := recorder.Header().Get("Content-Type"), "application/json"; got != want {
t.Fatalf("Content-Type = %q, want %q", got, want)
}

var doc discoveryDocument
if err := json.Unmarshal(recorder.Body.Bytes(), &doc); err != nil {
t.Fatalf("decode response: %v", err)
}
if doc.Auth.Issuer != "http://issuer.example" {
t.Fatalf("issuer = %q", doc.Auth.Issuer)
}
if doc.Auth.Audience != "ceph-api" {
t.Fatalf("audience = %q", doc.Auth.Audience)
}
if doc.Auth.TokenEndpoint != tokenEndpoint {
t.Fatalf("token endpoint = %q", doc.Auth.TokenEndpoint)
}
if doc.Auth.RevokeEndpoint != revokeEndpoint {
t.Fatalf("revoke endpoint = %q", doc.Auth.RevokeEndpoint)
}
if doc.JWKSURI != jwksURI {
t.Fatalf("jwks uri = %q", doc.JWKSURI)
}
if len(doc.Auth.Modes) != 1 || doc.Auth.Modes[0] != "password" {
t.Fatalf("modes = %v", doc.Auth.Modes)
}
}

func TestJWKSEndpoint(t *testing.T) {
server := newTestServer(t)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/.well-known/ceph-api/jwks.json", nil)

server.JWKSEndpoint(recorder, req)

if recorder.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
}
if got, want := recorder.Header().Get("Cache-Control"), "max-age=300"; got != want {
t.Fatalf("Cache-Control = %q, want %q", got, want)
}
if got, want := recorder.Header().Get("Content-Type"), "application/json"; got != want {
t.Fatalf("Content-Type = %q, want %q", got, want)
}

var doc jwksDocument
if err := json.Unmarshal(recorder.Body.Bytes(), &doc); err != nil {
t.Fatalf("decode response: %v", err)
}
if len(doc.Keys) != 1 {
t.Fatalf("keys len = %d, want 1", len(doc.Keys))
}
key := doc.Keys[0]
if key.Kty != "RSA" || key.Alg != "RS256" || key.Use != "sig" {
t.Fatalf("unexpected key metadata: %+v", key)
}
if key.Kid != server.keyID {
t.Fatalf("kid = %q, want %q", key.Kid, server.keyID)
}
if key.N == "" || key.E == "" {
t.Fatalf("empty key material: %+v", key)
}
}

func newTestServer(t *testing.T) *Server {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate key: %v", err)
}
kid, err := computeKID(&priv.PublicKey)
if err != nil {
t.Fatalf("compute kid: %v", err)
}
server, err := NewServer(Config{ClientID: "ceph-api", Issuer: "http://issuer.example"}, nil, priv, kid)
if err != nil {
t.Fatalf("NewServer() error = %v", err)
}
return server
}