diff --git a/pkg/app/start.go b/pkg/app/start.go index 79a050d..0c219a0 100644 --- a/pkg/app/start.go +++ b/pkg/app/start.go @@ -91,7 +91,17 @@ func Start(ctx context.Context, conf config.Config, build config.Build) error { } usersAPI := api.NewUsersAPI(userSvc) - authServer, err := auth.NewServer(conf.Auth, userSvc) + keyStore := auth.NewKeyStore(radosSvc) + signingKey, signingKID, err := keyStore.LoadOrCreate(ctx) + if err != nil { + return fmt.Errorf("load JWT signing key: %w", err) + } + globalSecret, err := keyStore.LoadOrCreateGlobalSecret(ctx) + if err != nil { + return fmt.Errorf("load OAuth global secret: %w", err) + } + + authServer, err := auth.NewServer(conf.Auth, userSvc, signingKey, signingKID, globalSecret) if err != nil { return err } diff --git a/pkg/auth/keystore.go b/pkg/auth/keystore.go new file mode 100644 index 0000000..5aa64a6 --- /dev/null +++ b/pkg/auth/keystore.go @@ -0,0 +1,202 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + + "github.com/clyso/ceph-api/pkg/types" +) + +const ( + keyJWTActive = "ceph-api/auth/jwt-key/active" + keyGlobalSecretActive = "ceph-api/auth/global-secret/active" + globalSecretSize = 32 +) + +type MonCommander interface { + ExecMon(ctx context.Context, cmd string) ([]byte, error) + ExecMonWithInputBuff(ctx context.Context, cmd string, in []byte) ([]byte, error) +} + +type KeyStore struct { + mon MonCommander +} + +type persistedJWTKey struct { + KID string `json:"kid"` + PrivateDER string `json:"private_der"` +} + +type persistedGlobalSecret struct { + Secret string `json:"secret"` +} + +type keyStoreEnvelope struct { + Version uint64 `json:"version"` + Value json.RawMessage `json:"value"` +} + +func NewKeyStore(mon MonCommander) *KeyStore { + return &KeyStore{mon: mon} +} + +func (k *KeyStore) LoadOrCreate(ctx context.Context) (*rsa.PrivateKey, string, error) { + priv, kid, err := k.loadJWTKey(ctx) + if err == nil { + return priv, kid, nil + } + if !errors.Is(err, types.RadosErrorNotFound) { + return nil, "", err + } + + priv, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, "", fmt.Errorf("generate JWT signing key: %w", err) + } + kid, err = computeKID(&priv.PublicKey) + if err != nil { + return nil, "", err + } + // TODO: Re-load after first persist to detect concurrent starters racing on config-key set. + if err := k.storeJWTKey(ctx, priv, kid); err != nil { + return nil, "", err + } + return priv, kid, nil +} + +func (k *KeyStore) LoadOrCreateGlobalSecret(ctx context.Context) ([]byte, error) { + secret, err := k.loadGlobalSecret(ctx) + if err == nil { + return secret, nil + } + if !errors.Is(err, types.RadosErrorNotFound) { + return nil, err + } + + secret = make([]byte, globalSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, fmt.Errorf("generate OAuth global secret: %w", err) + } + // TODO: Re-load after first persist to detect concurrent starters racing on config-key set. + if err := k.storeGlobalSecret(ctx, secret); err != nil { + return nil, err + } + return secret, nil +} + +// storeEnvelope marshals value into the versioned keyStoreEnvelope and writes +// it to the given config-key. label is only used to contextualize errors. +func storeEnvelope(ctx context.Context, mon MonCommander, key string, value any, label string) error { + rec, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("encode %s: %w", label, err) + } + env, err := json.Marshal(keyStoreEnvelope{Version: 1, Value: json.RawMessage(rec)}) + if err != nil { + return fmt.Errorf("encode %s envelope: %w", label, err) + } + cmd, err := json.Marshal(map[string]string{"prefix": "config-key set", "key": key}) + if err != nil { + return err + } + if _, err := mon.ExecMonWithInputBuff(ctx, string(cmd), env); err != nil { + return fmt.Errorf("persist %s: %w", label, err) + } + return nil +} + +func loadEnvelope(ctx context.Context, mon MonCommander, key string, out any, label string) error { + cmd, err := json.Marshal(map[string]string{"prefix": "config-key get", "key": key}) + if err != nil { + return err + } + raw, err := mon.ExecMon(ctx, string(cmd)) + if err != nil { + return err + } + return decodeEnvelope(raw, out, label) +} + +func decodeEnvelope(raw []byte, out any, label string) error { + var env keyStoreEnvelope + if err := json.Unmarshal(raw, &env); err != nil { + return fmt.Errorf("decode persisted %s envelope: %w", label, err) + } + if env.Version == 0 || len(env.Value) == 0 { + return fmt.Errorf("decode persisted %s envelope: missing value", label) + } + if err := json.Unmarshal(env.Value, out); err != nil { + return fmt.Errorf("decode persisted %s: %w", label, err) + } + return nil +} + +func (k *KeyStore) loadJWTKey(ctx context.Context) (*rsa.PrivateKey, string, error) { + var rec persistedJWTKey + if err := loadEnvelope(ctx, k.mon, keyJWTActive, &rec, "JWT signing key"); err != nil { + return nil, "", err + } + der, err := base64.RawStdEncoding.DecodeString(rec.PrivateDER) + if err != nil { + return nil, "", fmt.Errorf("decode persisted JWT signing key DER: %w", err) + } + priv, err := x509.ParsePKCS1PrivateKey(der) + if err != nil { + return nil, "", fmt.Errorf("parse persisted JWT signing key: %w", err) + } + kid, err := computeKID(&priv.PublicKey) + if err != nil { + return nil, "", err + } + if rec.KID != "" && rec.KID != kid { + return nil, "", fmt.Errorf("persisted JWT signing key kid mismatch") + } + return priv, kid, nil +} + +func (k *KeyStore) storeJWTKey(ctx context.Context, priv *rsa.PrivateKey, kid string) error { + return storeEnvelope(ctx, k.mon, keyJWTActive, persistedJWTKey{ + KID: kid, + PrivateDER: base64.RawStdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(priv)), + }, "JWT signing key") +} + +func (k *KeyStore) loadGlobalSecret(ctx context.Context) ([]byte, error) { + var rec persistedGlobalSecret + if err := loadEnvelope(ctx, k.mon, keyGlobalSecretActive, &rec, "OAuth global secret"); err != nil { + return nil, err + } + secret, err := base64.RawStdEncoding.DecodeString(rec.Secret) + if err != nil { + return nil, fmt.Errorf("decode persisted OAuth global secret value: %w", err) + } + if len(secret) != globalSecretSize { + return nil, fmt.Errorf("decode persisted OAuth global secret: invalid size") + } + return secret, nil +} + +func (k *KeyStore) storeGlobalSecret(ctx context.Context, secret []byte) error { + return storeEnvelope(ctx, k.mon, keyGlobalSecretActive, persistedGlobalSecret{ + Secret: base64.RawStdEncoding.EncodeToString(secret), + }, "OAuth global secret") +} + +func computeKID(pub *rsa.PublicKey) (string, error) { + h := sha256.New() + if _, err := h.Write(pub.N.Bytes()); err != nil { + return "", err + } + if err := binary.Write(h, binary.BigEndian, int64(pub.E)); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(h.Sum(nil)[:16]), nil +} diff --git a/pkg/auth/keystore_test.go b/pkg/auth/keystore_test.go new file mode 100644 index 0000000..696ea9f --- /dev/null +++ b/pkg/auth/keystore_test.go @@ -0,0 +1,140 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/clyso/ceph-api/pkg/types" +) + +func TestKeyStoreLoadOrCreateCreatesAndReloadsKey(t *testing.T) { + ctx := context.Background() + mon := newFakeMonCommander() + store := NewKeyStore(mon) + + priv, kid, err := store.LoadOrCreate(ctx) + if err != nil { + t.Fatalf("LoadOrCreate() error = %v", err) + } + if priv == nil { + t.Fatal("LoadOrCreate() private key is nil") + } + if kid == "" { + t.Fatal("LoadOrCreate() kid is empty") + } + if mon.sets != 1 { + t.Fatalf("config-key set count = %d, want 1", mon.sets) + } + + reloadedPriv, reloadedKID, err := store.LoadOrCreate(ctx) + if err != nil { + t.Fatalf("second LoadOrCreate() error = %v", err) + } + if reloadedKID != kid { + t.Fatalf("reloaded kid = %q, want %q", reloadedKID, kid) + } + if reloadedPriv.N.Cmp(priv.N) != 0 || reloadedPriv.E != priv.E || reloadedPriv.D.Cmp(priv.D) != 0 { + t.Fatal("reloaded private key does not match persisted key") + } + if mon.sets != 1 { + t.Fatalf("config-key set count after reload = %d, want 1", mon.sets) + } +} + +func TestKeyStoreLoadOrCreateGlobalSecretCreatesAndReloadsSecret(t *testing.T) { + ctx := context.Background() + mon := newFakeMonCommander() + store := NewKeyStore(mon) + + secret, err := store.LoadOrCreateGlobalSecret(ctx) + if err != nil { + t.Fatalf("LoadOrCreateGlobalSecret() error = %v", err) + } + if len(secret) != globalSecretSize { + t.Fatalf("global secret len = %d, want %d", len(secret), globalSecretSize) + } + if mon.sets != 1 { + t.Fatalf("config-key set count = %d, want 1", mon.sets) + } + + reloadedSecret, err := store.LoadOrCreateGlobalSecret(ctx) + if err != nil { + t.Fatalf("second LoadOrCreateGlobalSecret() error = %v", err) + } + if string(reloadedSecret) != string(secret) { + t.Fatal("reloaded global secret does not match persisted secret") + } + if mon.sets != 1 { + t.Fatalf("config-key set count after reload = %d, want 1", mon.sets) + } +} + +func TestNewServerUsesProvidedKID(t *testing.T) { + ctx := context.Background() + store := NewKeyStore(newFakeMonCommander()) + priv, kid, err := store.LoadOrCreate(ctx) + if err != nil { + t.Fatalf("LoadOrCreate() error = %v", err) + } + + globalSecret, err := store.LoadOrCreateGlobalSecret(ctx) + if err != nil { + t.Fatalf("LoadOrCreateGlobalSecret() error = %v", err) + } + + server, err := NewServer(Config{ClientID: "ceph-api", Issuer: "test"}, nil, priv, kid, globalSecret) + if err != nil { + t.Fatalf("NewServer() error = %v", err) + } + + jwtSession := server.newSession("admin", nil) + got, _ := jwtSession.GetJWTHeader().Extra["kid"].(string) + if got != kid { + t.Fatalf("JWT header kid = %q, want %q", got, kid) + } +} + +type fakeMonCommander struct { + values map[string][]byte + sets int +} + +func newFakeMonCommander() *fakeMonCommander { + return &fakeMonCommander{values: map[string][]byte{}} +} + +func (f *fakeMonCommander) ExecMon(_ context.Context, cmd string) ([]byte, error) { + var req struct { + Prefix string `json:"prefix"` + Key string `json:"key"` + } + if err := json.Unmarshal([]byte(cmd), &req); err != nil { + return nil, err + } + if req.Prefix != "config-key get" { + return nil, errors.New("unexpected ExecMon command") + } + value, ok := f.values[req.Key] + if !ok { + return nil, types.RadosErrorNotFound + } + return append([]byte(nil), value...), nil +} + +func (f *fakeMonCommander) ExecMonWithInputBuff(_ context.Context, cmd string, in []byte) ([]byte, error) { + var req struct { + Prefix string `json:"prefix"` + Key string `json:"key"` + } + if err := json.Unmarshal([]byte(cmd), &req); err != nil { + return nil, err + } + if req.Prefix != "config-key set" { + return nil, errors.New("unexpected ExecMonWithInputBuff command") + } + f.values[req.Key] = append([]byte(nil), in...) + f.sets++ + return nil, nil +} diff --git a/pkg/auth/oauth_server.go b/pkg/auth/oauth_server.go index 7350254..f09719b 100644 --- a/pkg/auth/oauth_server.go +++ b/pkg/auth/oauth_server.go @@ -2,8 +2,8 @@ package auth import ( "context" - "crypto/rand" "crypto/rsa" + "fmt" "time" "github.com/clyso/ceph-api/pkg/user" @@ -32,17 +32,17 @@ type Server struct { userSvc *user.Service } -func NewServer(config Config, userSvc *user.Service) (*Server, error) { - var secret = make([]byte, 32) - _, err := rand.Read(secret) - if err != nil { - return nil, err +func NewServer(config Config, userSvc *user.Service, privateKey *rsa.PrivateKey, kid string, globalSecret []byte) (*Server, error) { + if privateKey == nil { + return nil, fmt.Errorf("JWT signing key is required") } - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err + if kid == "" { + return nil, fmt.Errorf("JWT signing key id is required") + } + if len(globalSecret) < globalSecretSize { + return nil, fmt.Errorf("OAuth global secret must be at least %d bytes", globalSecretSize) } + secret := append([]byte(nil), globalSecret...) defaultStor := storage.NewMemoryStore() defaultStor.Clients = map[string]fosite.Client{ @@ -88,7 +88,7 @@ func NewServer(config Config, userSvc *user.Service) (*Server, error) { return &Server{ publicKey: &privateKey.PublicKey, - keyID: "0", + keyID: kid, issuer: config.Issuer, clientID: config.ClientID, refreshTokenLifespan: config.RefreshTokenLifespan, diff --git a/pkg/rados/service.go b/pkg/rados/service.go index 32925db..62c5ef4 100644 --- a/pkg/rados/service.go +++ b/pkg/rados/service.go @@ -33,7 +33,7 @@ func (s *Svc) ExecMon(ctx context.Context, cmd string) ([]byte, error) { func (s *Svc) ExecMonWithInputBuff(ctx context.Context, cmd string, inputBuffer []byte) ([]byte, error) { logger := zerolog.Ctx(ctx).With().Str("mon_cmd", cmd).Logger() - logger.Debug().Str("mon_cmd_buf", string(inputBuffer)).Msg("executing mon command with input buffer") + logger.Debug().Int("mon_cmd_buf_len", len(inputBuffer)).Msg("executing mon command with input buffer") cmdRes, cmdStatus, err := s.conn.MonCommandWithInputBuffer([]byte(cmd), inputBuffer) if err != nil { logger.Err(err).Str("cmd_status", cmdStatus).Msg("mon command with input buffer executed with error")