diff --git a/runs/config/config.go b/runs/config/config.go index c887d9df55..c2b7ab3c53 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -58,10 +58,67 @@ type Config struct { SeedProjects []string `json:"seedProjects" pflag:",Projects to create by default at startup"` // Domains are injected into project responses (not stored per project row). - Domains []DomainConfig `json:"domains"` + // Excluded from pflags (slices of structs are unsupported by the generator); + // configure via the config file only. + Domains []DomainConfig `json:"domains" pflag:"-"` // TriggerScheduler configures the cron-based trigger scheduler worker. TriggerScheduler TriggerSchedulerConfig `json:"triggerScheduler"` + + // AuthMetadata configures the OAuth2 authorization-server metadata endpoint + // (the GetOAuth2Metadata RPC and /.well-known/oauth-authorization-server). + AuthMetadata AuthMetadataConfig `json:"authMetadata"` +} + +// AuthMetadataConfig controls how the runs service serves OAuth2 authorization +// server metadata. When ExternalAuthServerBaseURL is set, the service proxies +// the external authorization server's metadata document (e.g. Okta) so that +// clients discovering auth at this deployment are pointed at the external IdP +// and obtain externally-issued tokens. When empty, GetOAuth2Metadata returns +// Unimplemented (HTTP 501 for the well-known handler). +type AuthMetadataConfig struct { + // ExternalAuthServerBaseURL is the base URL of the external OAuth2 + // authorization server to proxy metadata from + // (e.g. "https://signin.example.com/oauth2/default"). Empty disables the + // endpoint (GetOAuth2Metadata returns Unimplemented). + ExternalAuthServerBaseURL string `json:"externalAuthServerBaseUrl" pflag:",Base URL of the external OAuth2 authorization server to proxy metadata from"` + + // ExternalMetadataURL optionally overrides the metadata path resolved + // against ExternalAuthServerBaseURL. Defaults to + // ".well-known/oauth-authorization-server". + ExternalMetadataURL string `json:"externalMetadataUrl" pflag:",Override for the external metadata path"` + + // RetryAttempts is how many times to try fetching external metadata (default 5). + RetryAttempts int `json:"retryAttempts" pflag:",Attempts to fetch external metadata"` + + // RetryDelay is the delay between fetch attempts (default 1s). + RetryDelay config.Duration `json:"retryDelay" pflag:",Delay between external metadata fetch attempts"` + + // AuthorizationMetadataKey is the header/metadata key clients should place + // tokens in, returned by GetPublicClientConfig (default "authorization"). + AuthorizationMetadataKey string `json:"authorizationMetadataKey" pflag:",Header key clients should use for tokens"` + + // FlyteClient is the public (CLI/SDK) OAuth2 client configuration returned + // by GetPublicClientConfig. + FlyteClient FlyteClientConfig `json:"flyteClient"` +} + +// FlyteClientConfig mirrors flyteadmin's appAuth.thirdPartyConfig.flyteClient: +// the public OAuth2 client (flytectl/pyflyte) settings advertised to SDKs via +// GetPublicClientConfig. +type FlyteClientConfig struct { + // ClientID is the public client id used by CLI/SDK login flows. + ClientID string `json:"clientId" pflag:",Public OAuth2 client id advertised to SDKs"` + + // RedirectURI is the callback the public client listens on during login. + RedirectURI string `json:"redirectUri" pflag:",Redirect URI for the public client login flow"` + + // Scopes are the OAuth2 scopes the public client should request. + Scopes []string `json:"scopes" pflag:",Scopes the public client should request"` + + // Audience is the intended audience for requested tokens (sent when the IdP + // requires it, e.g. Auth0/Okta custom authorization servers). + Audience string `json:"audience" pflag:",Audience for requested tokens"` } // ServerConfig holds HTTP server configuration @@ -82,13 +139,17 @@ type TriggerSchedulerConfig struct { Enabled bool `json:"enabled" pflag:",Enable the trigger scheduler worker"` // ResyncInterval is how often the scheduler re-reads active triggers from the DB. - ResyncInterval time.Duration `json:"resyncInterval" pflag:",How often to resync active triggers from the database"` + // Excluded from pflags (raw time.Duration is unsupported by the generator); + // configure via the config file only. + ResyncInterval time.Duration `json:"resyncInterval" pflag:"-"` // MaxCatchupRunsPerLoop caps how many catchup runs are fired per resync loop. MaxCatchupRunsPerLoop int `json:"maxCatchupRunsPerLoop" pflag:",Maximum catchup runs fired per resync loop"` // ExecutionQPS is the token-bucket rate for CreateRun calls (tokens/second). - ExecutionQPS float64 `json:"executionQps" pflag:",Rate limit for CreateRun calls (requests per second)"` + // Excluded from pflags (float64 is unsupported by the generator); configure + // via the config file only. + ExecutionQPS float64 `json:"executionQps" pflag:"-"` // ExecutionBurst is the token-bucket burst size. ExecutionBurst int `json:"executionBurst" pflag:",Burst size for CreateRun rate limiter"` diff --git a/runs/config/config_flags.go b/runs/config/config_flags.go index 7a5bbc271b..a32e66083d 100755 --- a/runs/config/config_flags.go +++ b/runs/config/config_flags.go @@ -69,5 +69,17 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "actionsServiceUrl"), defaultConfig.ActionsServiceURL, "URL of the actions service") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storagePrefix"), defaultConfig.StoragePrefix, "Base URI prefix for storing run inputs and outputs") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "seedProjects"), defaultConfig.SeedProjects, "Projects to create by default at startup") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "triggerScheduler.enabled"), defaultConfig.TriggerScheduler.Enabled, "Enable the trigger scheduler worker") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "triggerScheduler.maxCatchupRunsPerLoop"), defaultConfig.TriggerScheduler.MaxCatchupRunsPerLoop, "Maximum catchup runs fired per resync loop") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "triggerScheduler.executionBurst"), defaultConfig.TriggerScheduler.ExecutionBurst, "Burst size for CreateRun rate limiter") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.externalAuthServerBaseUrl"), defaultConfig.AuthMetadata.ExternalAuthServerBaseURL, "Base URL of the external OAuth2 authorization server to proxy metadata from") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.externalMetadataUrl"), defaultConfig.AuthMetadata.ExternalMetadataURL, "Override for the external metadata path") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "authMetadata.retryAttempts"), defaultConfig.AuthMetadata.RetryAttempts, "Attempts to fetch external metadata") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.retryDelay"), defaultConfig.AuthMetadata.RetryDelay.String(), "Delay between external metadata fetch attempts") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.authorizationMetadataKey"), defaultConfig.AuthMetadata.AuthorizationMetadataKey, "Header key clients should use for tokens") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.clientId"), defaultConfig.AuthMetadata.FlyteClient.ClientID, "Public OAuth2 client id advertised to SDKs") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.redirectUri"), defaultConfig.AuthMetadata.FlyteClient.RedirectURI, "Redirect URI for the public client login flow") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.scopes"), defaultConfig.AuthMetadata.FlyteClient.Scopes, "Scopes the public client should request") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.audience"), defaultConfig.AuthMetadata.FlyteClient.Audience, "Audience for requested tokens") return cmdFlags } diff --git a/runs/config/config_flags_test.go b/runs/config/config_flags_test.go index c9450dbb1d..936e24a753 100755 --- a/runs/config/config_flags_test.go +++ b/runs/config/config_flags_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - "github.com/mitchellh/mapstructure" + "github.com/go-viper/mapstructure/v2" "github.com/stretchr/testify/assert" ) @@ -354,11 +354,179 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Test_seedProjects", func(t *testing.T) { t.Run("Override", func(t *testing.T) { - testValue := []string{"flytesnacks", "demo"} + testValue := join_Config(defaultConfig.SeedProjects, ",") - cmdFlags.Set("seedProjects", join_Config(testValue, ",")) + cmdFlags.Set("seedProjects", testValue) if vStringSlice, err := cmdFlags.GetStringSlice("seedProjects"); err == nil { - testDecodeRaw_Config(t, vStringSlice, &actual.SeedProjects) + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.SeedProjects) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_triggerScheduler.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("triggerScheduler.enabled", testValue) + if vBool, err := cmdFlags.GetBool("triggerScheduler.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.TriggerScheduler.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_triggerScheduler.maxCatchupRunsPerLoop", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("triggerScheduler.maxCatchupRunsPerLoop", testValue) + if vInt, err := cmdFlags.GetInt("triggerScheduler.maxCatchupRunsPerLoop"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.TriggerScheduler.MaxCatchupRunsPerLoop) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_triggerScheduler.executionBurst", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("triggerScheduler.executionBurst", testValue) + if vInt, err := cmdFlags.GetInt("triggerScheduler.executionBurst"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.TriggerScheduler.ExecutionBurst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.externalAuthServerBaseUrl", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.externalAuthServerBaseUrl", testValue) + if vString, err := cmdFlags.GetString("authMetadata.externalAuthServerBaseUrl"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.ExternalAuthServerBaseURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.externalMetadataUrl", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.externalMetadataUrl", testValue) + if vString, err := cmdFlags.GetString("authMetadata.externalMetadataUrl"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.ExternalMetadataURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.retryAttempts", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.retryAttempts", testValue) + if vInt, err := cmdFlags.GetInt("authMetadata.retryAttempts"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AuthMetadata.RetryAttempts) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.retryDelay", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.AuthMetadata.RetryDelay.String() + + cmdFlags.Set("authMetadata.retryDelay", testValue) + if vString, err := cmdFlags.GetString("authMetadata.retryDelay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.RetryDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.authorizationMetadataKey", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.authorizationMetadataKey", testValue) + if vString, err := cmdFlags.GetString("authMetadata.authorizationMetadataKey"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.AuthorizationMetadataKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.clientId", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.flyteClient.clientId", testValue) + if vString, err := cmdFlags.GetString("authMetadata.flyteClient.clientId"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.FlyteClient.ClientID) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.redirectUri", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.flyteClient.redirectUri", testValue) + if vString, err := cmdFlags.GetString("authMetadata.flyteClient.redirectUri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.FlyteClient.RedirectURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.scopes", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config(defaultConfig.AuthMetadata.FlyteClient.Scopes, ",") + + cmdFlags.Set("authMetadata.flyteClient.scopes", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("authMetadata.flyteClient.scopes"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.AuthMetadata.FlyteClient.Scopes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.audience", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.flyteClient.audience", testValue) + if vString, err := cmdFlags.GetString("authMetadata.flyteClient.audience"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.FlyteClient.Audience) } else { assert.FailNow(t, err.Error()) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index f27fb450e1..e311fd7054 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -2,33 +2,302 @@ package service import ( "context" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" + "sync" + "time" "connectrpc.com/connect" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + "github.com/flyteorg/flyte/v2/runs/config" +) + +const ( + defaultOAuth2MetadataPath = ".well-known/oauth-authorization-server" + + // maxMetadataBodySize bounds the upstream metadata document read; RFC 8414 + // documents are a few KB, so 1 MiB is generous while preventing memory + // blowups from a misconfigured upstream. + maxMetadataBodySize = 1 << 20 + + // metadataCacheTTL is how long a successfully fetched metadata document is + // served from memory. The endpoint is public and the document is effectively + // static, so this avoids an outbound fetch (with retries) per discovery call. + metadataCacheTTL = 5 * time.Minute + + externalFetchTimeout = 10 * time.Second ) // AuthMetadataService implements the AuthMetadataServiceHandler interface. type AuthMetadataService struct { authconnect.UnimplementedAuthMetadataServiceHandler dataplaneDomain string + cfg config.AuthMetadataConfig + httpClient *http.Client + + mu sync.RWMutex + cachedMetadata *auth.GetOAuth2MetadataResponse + cacheExpiry time.Time } -// NewAuthMetadataService creates a new AuthMetadataService instance. -func NewAuthMetadataService(dataplaneDomain string) *AuthMetadataService { +// NewAuthMetadataService creates a new AuthMetadataService instance. When +// cfg.ExternalAuthServerBaseURL is set, GetOAuth2Metadata proxies that server's +// metadata. cfg.FlyteClient is advertised to SDKs via GetPublicClientConfig. +func NewAuthMetadataService(dataplaneDomain string, cfg config.AuthMetadataConfig) *AuthMetadataService { return &AuthMetadataService{ dataplaneDomain: dataplaneDomain, + cfg: cfg, + httpClient: &http.Client{Timeout: externalFetchTimeout}, } } var _ authconnect.AuthMetadataServiceHandler = (*AuthMetadataService)(nil) +// GetPublicClientConfig returns the public (CLI/SDK) OAuth2 client settings. func (s *AuthMetadataService) GetPublicClientConfig( - ctx context.Context, - req *connect.Request[auth.GetPublicClientConfigRequest], + _ context.Context, + _ *connect.Request[auth.GetPublicClientConfigRequest], ) (*connect.Response[auth.GetPublicClientConfigResponse], error) { + authMetadataKey := s.cfg.AuthorizationMetadataKey + if authMetadataKey == "" { + authMetadataKey = "authorization" + } return connect.NewResponse(&auth.GetPublicClientConfigResponse{ - DataplaneDomain: s.dataplaneDomain, + ClientId: s.cfg.FlyteClient.ClientID, + RedirectUri: s.cfg.FlyteClient.RedirectURI, + Scopes: s.cfg.FlyteClient.Scopes, + AuthorizationMetadataKey: authMetadataKey, + Audience: s.cfg.FlyteClient.Audience, + DataplaneDomain: s.dataplaneDomain, }), nil } + +// GetOAuth2Metadata proxies the configured external authorization server's +// metadata document (RFC 8414 / OAuth2 Authorization Server Metadata). This lets +// flyte clients that discover auth at this deployment obtain tokens from the +// external IdP (e.g. Okta) directly, so a single token satisfies both this +// deployment and any upstream (ALB) JWT validation keyed to the same issuer. +func (s *AuthMetadataService) GetOAuth2Metadata( + ctx context.Context, + _ *connect.Request[auth.GetOAuth2MetadataRequest], +) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + if s.cfg.ExternalAuthServerBaseURL == "" { + return nil, connect.NewError(connect.CodeUnimplemented, + errors.New("oauth2 metadata is not configured; set runs.authMetadata.externalAuthServerBaseUrl")) + } + + // Serve from cache while fresh; only successful fetches are cached. + s.mu.RLock() + if s.cachedMetadata != nil && time.Now().Before(s.cacheExpiry) { + cached := proto.Clone(s.cachedMetadata).(*auth.GetOAuth2MetadataResponse) + s.mu.RUnlock() + return connect.NewResponse(cached), nil + } + s.mu.RUnlock() + + baseURL, err := url.Parse(s.cfg.ExternalAuthServerBaseURL) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("invalid external auth server base URL %q: %w", s.cfg.ExternalAuthServerBaseURL, err)) + } + if baseURL.Scheme == "" || baseURL.Host == "" { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("external auth server base URL must be absolute (include scheme and host): %q", s.cfg.ExternalAuthServerBaseURL)) + } + + // Issuer URLs conventionally do not end with a '/', but metadata URLs are + // relative to them. Add a trailing '/' so ResolveReference behaves intuitively. + baseURL.Path = strings.TrimSuffix(baseURL.Path, "/") + "/" + + metadataPath := s.cfg.ExternalMetadataURL + if metadataPath == "" { + metadataPath = defaultOAuth2MetadataPath + } + relURL, err := url.Parse(metadataPath) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("invalid external metadata path %q: %w", metadataPath, err)) + } + // MetadataURL is expected to be a relative path resolved against BaseURL. + // Reject absolute, scheme-relative, or root-relative paths so BaseURL cannot be bypassed. + if relURL.IsAbs() || relURL.Host != "" || strings.HasPrefix(relURL.Path, "/") { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("external metadata path must be relative to externalAuthServerBaseUrl (no leading '/'), got %q", metadataPath)) + } + externalMetadataURL := baseURL.ResolveReference(relURL) + + retryAttempts := s.cfg.RetryAttempts + if retryAttempts <= 0 { + retryAttempts = 5 + } + retryDelay := s.cfg.RetryDelay.Duration + if retryDelay <= 0 { + retryDelay = time.Second + } + + response, err := sendAndRetryHTTPRequest(ctx, s.httpClient, externalMetadataURL.String(), retryAttempts, retryDelay) + if err != nil { + // Preserve pre-classified codes (e.g. Internal for non-retryable 4xx) + // instead of blanket-mapping everything to Unavailable. + var connectErr *connect.Error + switch { + case errors.As(err, &connectErr): + return nil, err + case errors.Is(err, context.Canceled): + return nil, connect.NewError(connect.CodeCanceled, err) + case errors.Is(err, context.DeadlineExceeded): + return nil, connect.NewError(connect.CodeDeadlineExceeded, err) + default: + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) + } + } + defer func() { _ = response.Body.Close() }() + + raw, err := io.ReadAll(io.LimitReader(response.Body, maxMetadataBodySize+1)) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read OAuth2 metadata response: %w", err)) + } + if len(raw) > maxMetadataBodySize { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("OAuth2 metadata response exceeds %d bytes", maxMetadataBodySize)) + } + + resp := &auth.GetOAuth2MetadataResponse{} + if err := unmarshalResp(response, raw, resp); err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) + } + + s.mu.Lock() + s.cachedMetadata = proto.Clone(resp).(*auth.GetOAuth2MetadataResponse) + s.cacheExpiry = time.Now().Add(metadataCacheTTL) + s.mu.Unlock() + + return connect.NewResponse(resp), nil +} + +// OAuth2MetadataHTTPHandler serves the OAuth2 authorization-server metadata +// document at the RFC 8414 well-known path. OAuth2/OIDC discovery clients +// (flyte-sdk) fetch this path directly rather than the Connect RPC. +func OAuth2MetadataHTTPHandler(svc *AuthMetadataService) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.Header().Set("Allow", http.MethodGet) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + resp, err := svc.GetOAuth2Metadata(r.Context(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + if err != nil { + http.Error(w, err.Error(), connectCodeToHTTPStatus(err)) + return + } + // Marshal with proto3 JSON (camelCase) to match what flyteadmin and + // other flyte clients expect from this endpoint. + body, marshalErr := protojson.Marshal(resp.Msg) + if marshalErr != nil { + http.Error(w, marshalErr.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + if _, writeErr := w.Write(body); writeErr != nil { + logger.Warnf(r.Context(), "failed to write oauth2 metadata response: %v", writeErr) + } + }) +} + +func connectCodeToHTTPStatus(err error) int { + var connectErr *connect.Error + if !errors.As(err, &connectErr) { + return http.StatusInternalServerError + } + switch connectErr.Code() { + case connect.CodeUnimplemented: + return http.StatusNotImplemented + case connect.CodeInvalidArgument: + return http.StatusBadRequest + case connect.CodeUnavailable: + return http.StatusServiceUnavailable + default: + return http.StatusInternalServerError + } +} + +// unmarshalResp unmarshals a JSON response body into a protobuf message. It uses +// protojson.Unmarshal which accepts both the camelCase form used by proto3 JSON +// serialization and the snake_case form matching proto field names. This matters +// because external authorization servers (including flyteadmin) emit camelCase +// keys while the Go proto struct tags are snake_case. +func unmarshalResp(r *http.Response, body []byte, v proto.Message) error { + // DiscardUnknown: real authorization servers (e.g. Okta) return many metadata + // fields beyond those modelled here (introspection_endpoint, claims_supported, + // …); without this, unmarshal fails on the first unknown field. + if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(body, v); err == nil { + return nil + } else { + ct := r.Header.Get("Content-Type") + mediaType, _, parseErr := mime.ParseMediaType(ct) + if parseErr == nil && mediaType == "application/json" { + return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %w", err) + } + return fmt.Errorf("expected Content-Type = application/json, got %q: %w", ct, err) + } +} + +// sendAndRetryHTTPRequest fetches the given URL with retry logic. +func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) (*http.Response, error) { + var lastErr error + var lastResp *http.Response + for i := 0; i < retryAttempts; i++ { + if i > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + lastErr = err + logger.Warnf(ctx, "Failed to fetch %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, err) + continue + } + + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + return resp, nil + } + + // Don't retry on 4xx: these are usually configuration/request errors. + // Classify as Internal (a misconfiguration of this service), not + // Unavailable, so clients see an accurate status. + if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError { + _ = resp.Body.Close() + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("non-retryable status code %d from %s", resp.StatusCode, targetURL)) + } + + _ = resp.Body.Close() + lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) + lastResp = resp + logger.Warnf(ctx, "Unexpected status code from %s (attempt %d/%d): %d", targetURL, i+1, retryAttempts, resp.StatusCode) + } + + if lastResp != nil { + return nil, fmt.Errorf("failed to get oauth metadata with status code %v: %w", lastResp.StatusCode, lastErr) + } + return nil, fmt.Errorf("all %d attempts failed for %s: %w", retryAttempts, targetURL, lastErr) +} diff --git a/runs/service/auth_metadata_service_test.go b/runs/service/auth_metadata_service_test.go new file mode 100644 index 0000000000..e9453a542d --- /dev/null +++ b/runs/service/auth_metadata_service_test.go @@ -0,0 +1,201 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + stdlibconfig "github.com/flyteorg/flyte/v2/flytestdlib/config" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/config" +) + +func TestGetOAuth2Metadata_NotConfigured(t *testing.T) { + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{}) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err)) +} + +func TestGetOAuth2Metadata_External(t *testing.T) { + // Fake external IdP returning RFC 8414 snake_case metadata with extra fields + // (as Okta does) to exercise DiscardUnknown. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/oauth2/default/.well-known/oauth-authorization-server", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "issuer": "https://idp.example.com/oauth2/default", + "authorization_endpoint": "https://idp.example.com/oauth2/default/v1/authorize", + "token_endpoint": "https://idp.example.com/oauth2/default/v1/token", + "jwks_uri": "https://idp.example.com/oauth2/default/v1/keys", + "introspection_endpoint": "https://idp.example.com/oauth2/default/v1/introspect", + "response_modes_supported": ["query", "fragment", "form_post"], + "claims_supported": ["iss", "sub", "aud"] + }`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL + "/oauth2/default", + }) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + assert.Equal(t, "https://idp.example.com/oauth2/default", resp.Msg.Issuer) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/authorize", resp.Msg.AuthorizationEndpoint) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/token", resp.Msg.TokenEndpoint) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/keys", resp.Msg.JwksUri) +} + +func TestGetOAuth2Metadata_ExternalCustomMetadataURL(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"https://idp.example.com"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL, + ExternalMetadataURL: "custom/metadata", + }) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + assert.Equal(t, "https://idp.example.com", resp.Msg.Issuer) +} + +func TestGetOAuth2Metadata_ExternalUnavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL, + RetryAttempts: 1, + RetryDelay: stdlibconfig.Duration{Duration: time.Millisecond}, + }) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err)) +} + +func TestOAuth2MetadataHTTPHandler(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"https://idp.example.com/oauth2/default","token_endpoint":"https://idp.example.com/oauth2/default/v1/token"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ExternalAuthServerBaseURL: srv.URL + "/oauth2/default"}) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + OAuth2MetadataHTTPHandler(svc).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Header().Get("Content-Type"), "application/json") + + var doc map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &doc)) + // proto3 JSON marshals to camelCase + assert.Equal(t, "https://idp.example.com/oauth2/default", doc["issuer"]) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/token", doc["tokenEndpoint"]) +} + +func TestOAuth2MetadataHTTPHandler_NotConfigured(t *testing.T) { + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{}) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + OAuth2MetadataHTTPHandler(svc).ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotImplemented, rec.Code) +} + +func TestGetPublicClientConfig(t *testing.T) { + svc := NewAuthMetadataService("dataplane.example.com", config.AuthMetadataConfig{ + AuthorizationMetadataKey: "flyte-authorization", + FlyteClient: config.FlyteClientConfig{ + ClientID: "flytectl", + RedirectURI: "http://localhost:53593/callback", + Scopes: []string{"offline_access", "profile"}, + Audience: "https://api.example.com", + }, + }) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + assert.Equal(t, "flytectl", resp.Msg.ClientId) + assert.Equal(t, "http://localhost:53593/callback", resp.Msg.RedirectUri) + assert.Equal(t, []string{"offline_access", "profile"}, resp.Msg.Scopes) + assert.Equal(t, "https://api.example.com", resp.Msg.Audience) + assert.Equal(t, "flyte-authorization", resp.Msg.AuthorizationMetadataKey) + assert.Equal(t, "dataplane.example.com", resp.Msg.DataplaneDomain) +} + +func TestGetPublicClientConfig_DefaultAuthMetadataKey(t *testing.T) { + // Empty AuthorizationMetadataKey defaults to the standard "authorization" + // header (which upstream JWT validators like ALB inspect). + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{}) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + assert.Equal(t, "authorization", resp.Msg.AuthorizationMetadataKey) + assert.Empty(t, resp.Msg.ClientId) +} + +func TestGetOAuth2Metadata_CachesSuccessfulFetch(t *testing.T) { + hits := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"https://idp.example.com"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ExternalAuthServerBaseURL: srv.URL}) + for i := 0; i < 3; i++ { + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + assert.Equal(t, "https://idp.example.com", resp.Msg.Issuer) + } + assert.Equal(t, 1, hits, "fresh cache should serve repeat calls without refetching") +} + +func TestGetOAuth2Metadata_BodySizeLimit(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"`)) + big := make([]byte, maxMetadataBodySize+1024) + for i := range big { + big[i] = 'a' + } + _, _ = w.Write(big) + _, _ = w.Write([]byte(`"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ExternalAuthServerBaseURL: srv.URL}) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) +} + +func TestGetOAuth2Metadata_Upstream4xxIsInternal(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL, + RetryAttempts: 1, + RetryDelay: stdlibconfig.Duration{Duration: time.Millisecond}, + }) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err), "non-retryable 4xx should be Internal, not Unavailable") +} diff --git a/runs/setup.go b/runs/setup.go index 665e8b036c..2a8efa70c7 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -113,11 +113,18 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(identityPath, identityHandler) logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) - authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL) + authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL, cfg.AuthMetadata) authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc, connect.WithInterceptors(otelInterceptor)) sc.Mux.Handle(authMetadataPath, authMetadataHandler) logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) + // Serve OAuth2 authorization-server metadata at the RFC 8414 well-known path + // so OAuth2/OIDC discovery clients (flyte-sdk) can find it. When + // runs.authMetadata.externalAuthServerBaseUrl is set, this proxies the + // external IdP's (e.g. Okta) metadata document. + sc.Mux.Handle("/.well-known/oauth-authorization-server", service.OAuth2MetadataHTTPHandler(authMetadataSvc)) + logger.Infof(ctx, "Mounted OAuth2 metadata at /.well-known/oauth-authorization-server") + triggerSvc := service.NewTriggerService(repo) triggerPath, triggerHandler := triggerconnect.NewTriggerServiceHandler(triggerSvc, connect.WithInterceptors(otelInterceptor)) sc.Mux.Handle(triggerPath, triggerHandler)