Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pkg/app/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,17 @@ func Start(ctx context.Context, conf config.Config, build config.Build) error {
}
usersAPI := api.NewUsersAPI(userSvc)

authServer, err := auth.NewServer(conf.Auth, userSvc)
keyStore := auth.NewKeyStore(radosSvc)
signingKey, signingKID, err := keyStore.LoadOrCreate(ctx)
if err != nil {
return fmt.Errorf("load JWT signing key: %w", err)
}
globalSecret, err := keyStore.LoadOrCreateGlobalSecret(ctx)
if err != nil {
return fmt.Errorf("load OAuth global secret: %w", err)
}

authServer, err := auth.NewServer(conf.Auth, userSvc, signingKey, signingKID, globalSecret)
if err != nil {
return err
}
Expand Down
202 changes: 202 additions & 0 deletions pkg/auth/keystore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package auth

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"

"github.com/clyso/ceph-api/pkg/types"
)

const (
keyJWTActive = "ceph-api/auth/jwt-key/active"
keyGlobalSecretActive = "ceph-api/auth/global-secret/active"
globalSecretSize = 32
)

type MonCommander interface {
ExecMon(ctx context.Context, cmd string) ([]byte, error)
ExecMonWithInputBuff(ctx context.Context, cmd string, in []byte) ([]byte, error)
}

type KeyStore struct {
mon MonCommander
}

type persistedJWTKey struct {
KID string `json:"kid"`
PrivateDER string `json:"private_der"`
}

type persistedGlobalSecret struct {
Secret string `json:"secret"`
}

type keyStoreEnvelope struct {
Version uint64 `json:"version"`
Value json.RawMessage `json:"value"`
}

func NewKeyStore(mon MonCommander) *KeyStore {
return &KeyStore{mon: mon}
}

func (k *KeyStore) LoadOrCreate(ctx context.Context) (*rsa.PrivateKey, string, error) {
priv, kid, err := k.loadJWTKey(ctx)
if err == nil {
return priv, kid, nil
}
if !errors.Is(err, types.RadosErrorNotFound) {
return nil, "", err
}

priv, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, "", fmt.Errorf("generate JWT signing key: %w", err)
}
kid, err = computeKID(&priv.PublicKey)
if err != nil {
return nil, "", err
}
// TODO: Re-load after first persist to detect concurrent starters racing on config-key set.
if err := k.storeJWTKey(ctx, priv, kid); err != nil {
return nil, "", err
}
return priv, kid, nil
}

func (k *KeyStore) LoadOrCreateGlobalSecret(ctx context.Context) ([]byte, error) {
secret, err := k.loadGlobalSecret(ctx)
if err == nil {
return secret, nil
}
if !errors.Is(err, types.RadosErrorNotFound) {
return nil, err
}

secret = make([]byte, globalSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, fmt.Errorf("generate OAuth global secret: %w", err)
}
// TODO: Re-load after first persist to detect concurrent starters racing on config-key set.
if err := k.storeGlobalSecret(ctx, secret); err != nil {
return nil, err
}
return secret, nil
}

// storeEnvelope marshals value into the versioned keyStoreEnvelope and writes
// it to the given config-key. label is only used to contextualize errors.
func storeEnvelope(ctx context.Context, mon MonCommander, key string, value any, label string) error {
rec, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("encode %s: %w", label, err)
}
env, err := json.Marshal(keyStoreEnvelope{Version: 1, Value: json.RawMessage(rec)})
if err != nil {
return fmt.Errorf("encode %s envelope: %w", label, err)
}
cmd, err := json.Marshal(map[string]string{"prefix": "config-key set", "key": key})
if err != nil {
return err
}
if _, err := mon.ExecMonWithInputBuff(ctx, string(cmd), env); err != nil {
return fmt.Errorf("persist %s: %w", label, err)
}
return nil
}

func loadEnvelope(ctx context.Context, mon MonCommander, key string, out any, label string) error {
cmd, err := json.Marshal(map[string]string{"prefix": "config-key get", "key": key})
if err != nil {
return err
}
raw, err := mon.ExecMon(ctx, string(cmd))
if err != nil {
return err
}
return decodeEnvelope(raw, out, label)
}

func decodeEnvelope(raw []byte, out any, label string) error {
var env keyStoreEnvelope
if err := json.Unmarshal(raw, &env); err != nil {
return fmt.Errorf("decode persisted %s envelope: %w", label, err)
}
if env.Version == 0 || len(env.Value) == 0 {
return fmt.Errorf("decode persisted %s envelope: missing value", label)
}
if err := json.Unmarshal(env.Value, out); err != nil {
return fmt.Errorf("decode persisted %s: %w", label, err)
}
return nil
}

func (k *KeyStore) loadJWTKey(ctx context.Context) (*rsa.PrivateKey, string, error) {
var rec persistedJWTKey
if err := loadEnvelope(ctx, k.mon, keyJWTActive, &rec, "JWT signing key"); err != nil {
return nil, "", err
}
der, err := base64.RawStdEncoding.DecodeString(rec.PrivateDER)
if err != nil {
return nil, "", fmt.Errorf("decode persisted JWT signing key DER: %w", err)
}
priv, err := x509.ParsePKCS1PrivateKey(der)
if err != nil {
return nil, "", fmt.Errorf("parse persisted JWT signing key: %w", err)
}
kid, err := computeKID(&priv.PublicKey)
if err != nil {
return nil, "", err
}
if rec.KID != "" && rec.KID != kid {
return nil, "", fmt.Errorf("persisted JWT signing key kid mismatch")
}
return priv, kid, nil
}

func (k *KeyStore) storeJWTKey(ctx context.Context, priv *rsa.PrivateKey, kid string) error {
return storeEnvelope(ctx, k.mon, keyJWTActive, persistedJWTKey{
KID: kid,
PrivateDER: base64.RawStdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(priv)),
}, "JWT signing key")
}

func (k *KeyStore) loadGlobalSecret(ctx context.Context) ([]byte, error) {
var rec persistedGlobalSecret
if err := loadEnvelope(ctx, k.mon, keyGlobalSecretActive, &rec, "OAuth global secret"); err != nil {
return nil, err
}
secret, err := base64.RawStdEncoding.DecodeString(rec.Secret)
if err != nil {
return nil, fmt.Errorf("decode persisted OAuth global secret value: %w", err)
}
if len(secret) != globalSecretSize {
return nil, fmt.Errorf("decode persisted OAuth global secret: invalid size")
}
return secret, nil
}

func (k *KeyStore) storeGlobalSecret(ctx context.Context, secret []byte) error {
return storeEnvelope(ctx, k.mon, keyGlobalSecretActive, persistedGlobalSecret{
Secret: base64.RawStdEncoding.EncodeToString(secret),
}, "OAuth global secret")
}

func computeKID(pub *rsa.PublicKey) (string, error) {
h := sha256.New()
if _, err := h.Write(pub.N.Bytes()); err != nil {
return "", err
}
if err := binary.Write(h, binary.BigEndian, int64(pub.E)); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(h.Sum(nil)[:16]), nil
}
140 changes: 140 additions & 0 deletions pkg/auth/keystore_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package auth

import (
"context"
"encoding/json"
"errors"
"testing"

"github.com/clyso/ceph-api/pkg/types"
)

func TestKeyStoreLoadOrCreateCreatesAndReloadsKey(t *testing.T) {
ctx := context.Background()
mon := newFakeMonCommander()
store := NewKeyStore(mon)

priv, kid, err := store.LoadOrCreate(ctx)
if err != nil {
t.Fatalf("LoadOrCreate() error = %v", err)
}
if priv == nil {
t.Fatal("LoadOrCreate() private key is nil")
}
if kid == "" {
t.Fatal("LoadOrCreate() kid is empty")
}
if mon.sets != 1 {
t.Fatalf("config-key set count = %d, want 1", mon.sets)
}

reloadedPriv, reloadedKID, err := store.LoadOrCreate(ctx)
if err != nil {
t.Fatalf("second LoadOrCreate() error = %v", err)
}
if reloadedKID != kid {
t.Fatalf("reloaded kid = %q, want %q", reloadedKID, kid)
}
if reloadedPriv.N.Cmp(priv.N) != 0 || reloadedPriv.E != priv.E || reloadedPriv.D.Cmp(priv.D) != 0 {
t.Fatal("reloaded private key does not match persisted key")
}
if mon.sets != 1 {
t.Fatalf("config-key set count after reload = %d, want 1", mon.sets)
}
}

func TestKeyStoreLoadOrCreateGlobalSecretCreatesAndReloadsSecret(t *testing.T) {
ctx := context.Background()
mon := newFakeMonCommander()
store := NewKeyStore(mon)

secret, err := store.LoadOrCreateGlobalSecret(ctx)
if err != nil {
t.Fatalf("LoadOrCreateGlobalSecret() error = %v", err)
}
if len(secret) != globalSecretSize {
t.Fatalf("global secret len = %d, want %d", len(secret), globalSecretSize)
}
if mon.sets != 1 {
t.Fatalf("config-key set count = %d, want 1", mon.sets)
}

reloadedSecret, err := store.LoadOrCreateGlobalSecret(ctx)
if err != nil {
t.Fatalf("second LoadOrCreateGlobalSecret() error = %v", err)
}
if string(reloadedSecret) != string(secret) {
t.Fatal("reloaded global secret does not match persisted secret")
}
if mon.sets != 1 {
t.Fatalf("config-key set count after reload = %d, want 1", mon.sets)
}
}

func TestNewServerUsesProvidedKID(t *testing.T) {
ctx := context.Background()
store := NewKeyStore(newFakeMonCommander())
priv, kid, err := store.LoadOrCreate(ctx)
if err != nil {
t.Fatalf("LoadOrCreate() error = %v", err)
}

globalSecret, err := store.LoadOrCreateGlobalSecret(ctx)
if err != nil {
t.Fatalf("LoadOrCreateGlobalSecret() error = %v", err)
}

server, err := NewServer(Config{ClientID: "ceph-api", Issuer: "test"}, nil, priv, kid, globalSecret)
if err != nil {
t.Fatalf("NewServer() error = %v", err)
}

jwtSession := server.newSession("admin", nil)
got, _ := jwtSession.GetJWTHeader().Extra["kid"].(string)
if got != kid {
t.Fatalf("JWT header kid = %q, want %q", got, kid)
}
}

type fakeMonCommander struct {
values map[string][]byte
sets int
}

func newFakeMonCommander() *fakeMonCommander {
return &fakeMonCommander{values: map[string][]byte{}}
}

func (f *fakeMonCommander) ExecMon(_ context.Context, cmd string) ([]byte, error) {
var req struct {
Prefix string `json:"prefix"`
Key string `json:"key"`
}
if err := json.Unmarshal([]byte(cmd), &req); err != nil {
return nil, err
}
if req.Prefix != "config-key get" {
return nil, errors.New("unexpected ExecMon command")
}
value, ok := f.values[req.Key]
if !ok {
return nil, types.RadosErrorNotFound
}
return append([]byte(nil), value...), nil
}

func (f *fakeMonCommander) ExecMonWithInputBuff(_ context.Context, cmd string, in []byte) ([]byte, error) {
var req struct {
Prefix string `json:"prefix"`
Key string `json:"key"`
}
if err := json.Unmarshal([]byte(cmd), &req); err != nil {
return nil, err
}
if req.Prefix != "config-key set" {
return nil, errors.New("unexpected ExecMonWithInputBuff command")
}
f.values[req.Key] = append([]byte(nil), in...)
f.sets++
return nil, nil
}
Loading