From 927107b4bca8555b5305baa81e76aafca911d83f Mon Sep 17 00:00:00 2001 From: rgthelen Date: Tue, 7 Jan 2025 13:18:55 -0600 Subject: [PATCH] updates to handle token validation for app varients --- examples/rownd-test/README.md | 4 ++-- pkg/rownd/middleware/http.go | 3 ++- pkg/rownd/middleware/middleware.go | 3 ++- pkg/rownd/token.go | 28 ++++++++++++++++++++++++---- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/examples/rownd-test/README.md b/examples/rownd-test/README.md index 87895f0..2c07421 100644 --- a/examples/rownd-test/README.md +++ b/examples/rownd-test/README.md @@ -1,7 +1,7 @@ ```bash -git clone https://github.com/yourusername/rownd-test.git -cd rownd-test +git clone https://github.com/rownd/client-go.git +cd client-go/examples/rownd-test ``` 2. Create a `.env` file in the project root: diff --git a/pkg/rownd/middleware/http.go b/pkg/rownd/middleware/http.go index 932327d..fb9eac7 100644 --- a/pkg/rownd/middleware/http.go +++ b/pkg/rownd/middleware/http.go @@ -18,7 +18,8 @@ func WithAuthentication(handler Handler) func(next http.Handler) http.Handler { } ctx := r.Context() - validated, err := handler.Validator.Validate(ctx, token) + // Pass validation options to Validate + validated, err := handler.Validator.Validate(ctx, token, handler.ValidationOpts) if err != nil { handler.ErrorHandler(w, r, errors.New("Forbidden")) return diff --git a/pkg/rownd/middleware/middleware.go b/pkg/rownd/middleware/middleware.go index 95e8f36..23fdead 100644 --- a/pkg/rownd/middleware/middleware.go +++ b/pkg/rownd/middleware/middleware.go @@ -19,7 +19,8 @@ type ( type Handler struct { Validator rownd.TokenValidator TokenExtractor TokenExtractor - ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) + ErrorHandler ErrorHandler + ValidationOpts *rownd.TokenValidationOptions } func NewHandler(validator rownd.TokenValidator, opts ...HandlerOption) (*Handler, error) { diff --git a/pkg/rownd/token.go b/pkg/rownd/token.go index 6d476a9..502af2c 100644 --- a/pkg/rownd/token.go +++ b/pkg/rownd/token.go @@ -20,6 +20,11 @@ const ( AuthLevelVerified AuthLevel = "verified" ) +// TokenValidationOptions ... +type TokenValidationOptions struct { + VariantID string +} + // Token ... type Token struct { Token *jwt.Token `json:"-"` // The parsed JWT token @@ -100,7 +105,7 @@ func TokenFromCtx(ctx context.Context) *Token { // TokenValidator ... type TokenValidator interface { - Validate(ctx context.Context, token string) (*Token, error) + Validate(ctx context.Context, token string, opts *TokenValidationOptions) (*Token, error) } type tokenValidator struct { @@ -108,7 +113,7 @@ type tokenValidator struct { } // Validate ... -func (c *tokenValidator) Validate(ctx context.Context, token string) (*Token, error) { +func (c *tokenValidator) Validate(ctx context.Context, token string, opts *TokenValidationOptions) (*Token, error) { if token == "" { return nil, NewError(ErrAuthentication, "invalid token", nil) } @@ -175,6 +180,21 @@ func (c *tokenValidator) Validate(ctx context.Context, token string) (*Token, er return nil, NewError(ErrAuthentication, "invalid token audience", nil) } + // Add variant validation if specified + if opts != nil && opts.VariantID != "" { + expectedVariantAud := fmt.Sprintf("app_variant:%s", opts.VariantID) + hasValidVariant := false + for _, aud := range claims.Aud { + if aud == expectedVariantAud { + hasValidVariant = true + break + } + } + if !hasValidVariant { + return nil, NewError(ErrAuthentication, "invalid token variant audience", nil) + } + } + r := &Token{ Token: parsedToken, Claims: *claims, @@ -186,9 +206,9 @@ func (c *tokenValidator) Validate(ctx context.Context, token string) (*Token, er } // Add this method to expose token validation on the Client -func (c *Client) ValidateToken(ctx context.Context, token string) (*Token, error) { +func (c *Client) ValidateToken(ctx context.Context, token string, opts *TokenValidationOptions) (*Token, error) { validator := &tokenValidator{Client: c} - return validator.Validate(ctx, token) + return validator.Validate(ctx, token, opts) } // Add JWKS types