From 3288fb67e15db406af9569ba8d8dbbbbefd40d0f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 11 Jun 2026 15:12:41 -0700 Subject: [PATCH 01/17] feat(runs): serve external OAuth2 authorization-server metadata Implement GetOAuth2Metadata in the runs AuthMetadataService by proxying an external authorization server's (e.g. Okta) metadata document, and mount it at the RFC 8414 /.well-known/oauth-authorization-server path. Adds runs config (authMetadata.externalAuthServerBaseUrl, ...) and tests. External-fetch logic adapted from #6998. Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- runs/config/config.go | 28 +++ runs/service/auth_metadata_service.go | 207 ++++++++++++++++++++- runs/service/auth_metadata_service_test.go | 116 ++++++++++++ runs/setup.go | 14 +- 4 files changed, 362 insertions(+), 3 deletions(-) create mode 100644 runs/service/auth_metadata_service_test.go diff --git a/runs/config/config.go b/runs/config/config.go index c887d9df55..71c51e28f1 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -62,6 +62,34 @@ type Config struct { // TriggerScheduler configures the cron-based trigger scheduler worker. TriggerScheduler TriggerSchedulerConfig `json:"triggerScheduler"` + + // AuthMetadata configures the OAuth2 authorization-server metadata endpoint + // (the GetOAuth2Metadata RPC and /.well-known/oauth-authorization-server). + AuthMetadata AuthMetadataConfig `json:"authMetadata"` +} + +// AuthMetadataConfig controls how the runs service serves OAuth2 authorization +// server metadata. When ExternalAuthServerBaseURL is set, the service proxies +// the external authorization server's metadata document (e.g. Okta) so that +// clients discovering auth at this deployment are pointed at the external IdP +// and obtain externally-issued tokens. When empty, the endpoint is not served. +type AuthMetadataConfig struct { + // ExternalAuthServerBaseURL is the base URL of the external OAuth2 + // authorization server to proxy metadata from + // (e.g. "https://signin.example.com/oauth2/default"). Empty disables the + // endpoint (GetOAuth2Metadata returns Unimplemented). + ExternalAuthServerBaseURL string `json:"externalAuthServerBaseUrl" pflag:",Base URL of the external OAuth2 authorization server to proxy metadata from"` + + // ExternalMetadataURL optionally overrides the metadata path resolved + // against ExternalAuthServerBaseURL. Defaults to + // ".well-known/oauth-authorization-server". + ExternalMetadataURL string `json:"externalMetadataUrl" pflag:",Override for the external metadata path"` + + // RetryAttempts is how many times to try fetching external metadata (default 5). + RetryAttempts int `json:"retryAttempts" pflag:",Attempts to fetch external metadata"` + + // RetryDelay is the delay between fetch attempts (default 1s). + RetryDelay time.Duration `json:"retryDelay" pflag:",Delay between external metadata fetch attempts"` } // ServerConfig holds HTTP server configuration diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index f27fb450e1..c83af980b7 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -2,23 +2,54 @@ package service import ( "context" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" + "time" "connectrpc.com/connect" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" ) +const defaultOAuth2MetadataPath = ".well-known/oauth-authorization-server" + +// ExternalAuthServerConfig configures proxying of an external OAuth2 +// authorization server's metadata document (e.g. Okta). +type ExternalAuthServerConfig struct { + // BaseURL is the external authorization server's base URL. When empty, the + // OAuth2 metadata endpoint is not served. + BaseURL string + // MetadataURL overrides the metadata path resolved against BaseURL. Defaults + // to ".well-known/oauth-authorization-server". + MetadataURL string + // RetryAttempts is the number of fetch attempts (default 5). + RetryAttempts int + // RetryDelay is the delay between fetch attempts (default 1s). + RetryDelay time.Duration +} + // AuthMetadataService implements the AuthMetadataServiceHandler interface. type AuthMetadataService struct { authconnect.UnimplementedAuthMetadataServiceHandler dataplaneDomain string + external ExternalAuthServerConfig } -// NewAuthMetadataService creates a new AuthMetadataService instance. -func NewAuthMetadataService(dataplaneDomain string) *AuthMetadataService { +// NewAuthMetadataService creates a new AuthMetadataService instance. When +// external.BaseURL is set, GetOAuth2Metadata proxies that server's metadata. +func NewAuthMetadataService(dataplaneDomain string, external ExternalAuthServerConfig) *AuthMetadataService { return &AuthMetadataService{ dataplaneDomain: dataplaneDomain, + external: external, } } @@ -32,3 +63,175 @@ func (s *AuthMetadataService) GetPublicClientConfig( DataplaneDomain: s.dataplaneDomain, }), nil } + +// GetOAuth2Metadata proxies the configured external authorization server's +// metadata document (RFC 8414 / OAuth2 Authorization Server Metadata). This lets +// flyte clients that discover auth at this deployment obtain tokens from the +// external IdP (e.g. Okta) directly, so a single token satisfies both this +// deployment and any upstream (ALB) JWT validation keyed to the same issuer. +// +// The external-fetch logic is adapted from flyteorg/flyte#6998. +func (s *AuthMetadataService) GetOAuth2Metadata( + ctx context.Context, + _ *connect.Request[auth.GetOAuth2MetadataRequest], +) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + if s.external.BaseURL == "" { + return nil, connect.NewError(connect.CodeUnimplemented, + errors.New("oauth2 metadata is not configured; set runs.authMetadata.externalAuthServerBaseUrl")) + } + + baseURL, err := url.Parse(s.external.BaseURL) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("invalid external auth server base URL %q: %w", s.external.BaseURL, err)) + } + + // Issuer URLs conventionally do not end with a '/', but metadata URLs are + // relative to them. Add a trailing '/' so ResolveReference behaves intuitively. + baseURL.Path = strings.TrimSuffix(baseURL.Path, "/") + "/" + + metadataPath := s.external.MetadataURL + if metadataPath == "" { + metadataPath = defaultOAuth2MetadataPath + } + relURL, err := url.Parse(metadataPath) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("invalid external metadata path %q: %w", metadataPath, err)) + } + externalMetadataURL := baseURL.ResolveReference(relURL) + + retryAttempts := s.external.RetryAttempts + if retryAttempts <= 0 { + retryAttempts = 5 + } + retryDelay := s.external.RetryDelay + if retryDelay <= 0 { + retryDelay = time.Second + } + + response, err := sendAndRetryHTTPRequest(ctx, http.DefaultClient, externalMetadataURL.String(), retryAttempts, retryDelay) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) + } + defer func() { _ = response.Body.Close() }() + + raw, err := io.ReadAll(response.Body) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read OAuth2 metadata response: %w", err)) + } + + resp := &auth.GetOAuth2MetadataResponse{} + if err := unmarshalResp(response, raw, resp); err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) + } + + return connect.NewResponse(resp), nil +} + +// OAuth2MetadataHTTPHandler serves the OAuth2 authorization-server metadata +// document at the RFC 8414 well-known path. OAuth2/OIDC discovery clients +// (flytectl, pyflyte) fetch this path directly rather than the Connect RPC. +func OAuth2MetadataHTTPHandler(svc *AuthMetadataService) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp, err := svc.GetOAuth2Metadata(r.Context(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + if err != nil { + http.Error(w, err.Error(), connectCodeToHTTPStatus(err)) + return + } + // Marshal with proto3 JSON (camelCase) to match what flyteadmin and + // other flyte clients expect from this endpoint. + body, marshalErr := protojson.Marshal(resp.Msg) + if marshalErr != nil { + http.Error(w, marshalErr.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + if _, writeErr := w.Write(body); writeErr != nil { + logger.Warnf(r.Context(), "failed to write oauth2 metadata response: %v", writeErr) + } + }) +} + +func connectCodeToHTTPStatus(err error) int { + var connectErr *connect.Error + if !errors.As(err, &connectErr) { + return http.StatusInternalServerError + } + switch connectErr.Code() { + case connect.CodeUnimplemented: + return http.StatusNotImplemented + case connect.CodeInvalidArgument: + return http.StatusBadRequest + case connect.CodeUnavailable: + return http.StatusBadGateway + default: + return http.StatusInternalServerError + } +} + +// unmarshalResp unmarshals a JSON response body into a protobuf message. It uses +// protojson.Unmarshal which accepts both the camelCase form used by proto3 JSON +// serialization and the snake_case form matching proto field names. This matters +// because external authorization servers (including flyteadmin) emit camelCase +// keys while the Go proto struct tags are snake_case. +// +// Adapted from flyteorg/flyte#6998. +func unmarshalResp(r *http.Response, body []byte, v proto.Message) error { + // DiscardUnknown: real authorization servers (e.g. Okta) return many metadata + // fields beyond those modelled here (introspection_endpoint, claims_supported, + // …); without this, unmarshal fails on the first unknown field. + if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(body, v); err == nil { + return nil + } else { + ct := r.Header.Get("Content-Type") + mediaType, _, parseErr := mime.ParseMediaType(ct) + if parseErr == nil && mediaType == "application/json" { + return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %w", err) + } + return fmt.Errorf("expected Content-Type = application/json, got %q: %w", ct, err) + } +} + +// sendAndRetryHTTPRequest fetches the given URL with retry logic. +// +// Adapted from flyteorg/flyte#6998. +func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) (*http.Response, error) { + var lastErr error + var lastResp *http.Response + for i := 0; i < retryAttempts; i++ { + if i > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + lastErr = err + logger.Warnf(ctx, "Failed to fetch %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, err) + continue + } + + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + return resp, nil + } + + _ = resp.Body.Close() + lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) + lastResp = resp + logger.Warnf(ctx, "Unexpected status code from %s (attempt %d/%d): %d", targetURL, i+1, retryAttempts, resp.StatusCode) + } + + if lastResp != nil { + return nil, fmt.Errorf("failed to get oauth metadata with status code %v: %w", lastResp.StatusCode, lastErr) + } + return nil, fmt.Errorf("all %d attempts failed for %s: %w", retryAttempts, targetURL, lastErr) +} diff --git a/runs/service/auth_metadata_service_test.go b/runs/service/auth_metadata_service_test.go new file mode 100644 index 0000000000..204bd6d623 --- /dev/null +++ b/runs/service/auth_metadata_service_test.go @@ -0,0 +1,116 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" +) + +func TestGetOAuth2Metadata_NotConfigured(t *testing.T) { + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err)) +} + +func TestGetOAuth2Metadata_External(t *testing.T) { + // Fake external IdP returning RFC 8414 snake_case metadata with extra fields + // (as Okta does) to exercise DiscardUnknown. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/oauth2/default/.well-known/oauth-authorization-server", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "issuer": "https://idp.example.com/oauth2/default", + "authorization_endpoint": "https://idp.example.com/oauth2/default/v1/authorize", + "token_endpoint": "https://idp.example.com/oauth2/default/v1/token", + "jwks_uri": "https://idp.example.com/oauth2/default/v1/keys", + "introspection_endpoint": "https://idp.example.com/oauth2/default/v1/introspect", + "response_modes_supported": ["query", "fragment", "form_post"], + "claims_supported": ["iss", "sub", "aud"] + }`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ + BaseURL: srv.URL + "/oauth2/default", + }) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + assert.Equal(t, "https://idp.example.com/oauth2/default", resp.Msg.Issuer) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/authorize", resp.Msg.AuthorizationEndpoint) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/token", resp.Msg.TokenEndpoint) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/keys", resp.Msg.JwksUri) +} + +func TestGetOAuth2Metadata_ExternalCustomMetadataURL(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"https://idp.example.com"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ + BaseURL: srv.URL, + MetadataURL: "custom/metadata", + }) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + assert.Equal(t, "https://idp.example.com", resp.Msg.Issuer) +} + +func TestGetOAuth2Metadata_ExternalUnavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ + BaseURL: srv.URL, + RetryAttempts: 1, + RetryDelay: time.Millisecond, + }) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err)) +} + +func TestOAuth2MetadataHTTPHandler(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"https://idp.example.com/oauth2/default","token_endpoint":"https://idp.example.com/oauth2/default/v1/token"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{BaseURL: srv.URL + "/oauth2/default"}) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + OAuth2MetadataHTTPHandler(svc).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Header().Get("Content-Type"), "application/json") + + var doc map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &doc)) + // proto3 JSON marshals to camelCase + assert.Equal(t, "https://idp.example.com/oauth2/default", doc["issuer"]) + assert.Equal(t, "https://idp.example.com/oauth2/default/v1/token", doc["tokenEndpoint"]) +} + +func TestOAuth2MetadataHTTPHandler_NotConfigured(t *testing.T) { + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + OAuth2MetadataHTTPHandler(svc).ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotImplemented, rec.Code) +} diff --git a/runs/setup.go b/runs/setup.go index a2f53c25b3..9072844b6b 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -112,11 +112,23 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(identityPath, identityHandler) logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) - authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL) + authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL, service.ExternalAuthServerConfig{ + BaseURL: cfg.AuthMetadata.ExternalAuthServerBaseURL, + MetadataURL: cfg.AuthMetadata.ExternalMetadataURL, + RetryAttempts: cfg.AuthMetadata.RetryAttempts, + RetryDelay: cfg.AuthMetadata.RetryDelay, + }) authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc, connect.WithInterceptors(otelInterceptor)) sc.Mux.Handle(authMetadataPath, authMetadataHandler) logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) + // Serve OAuth2 authorization-server metadata at the RFC 8414 well-known path + // so OAuth2/OIDC discovery clients (flytectl, pyflyte) can find it. When + // runs.authMetadata.externalAuthServerBaseUrl is set, this proxies the + // external IdP's (e.g. Okta) metadata document. + sc.Mux.Handle("/.well-known/oauth-authorization-server", service.OAuth2MetadataHTTPHandler(authMetadataSvc)) + logger.Infof(ctx, "Mounted OAuth2 metadata at /.well-known/oauth-authorization-server") + triggerSvc := service.NewTriggerService(repo) triggerPath, triggerHandler := triggerconnect.NewTriggerServiceHandler(triggerSvc, connect.WithInterceptors(otelInterceptor)) sc.Mux.Handle(triggerPath, triggerHandler) From 4dd146fc0ae036d2a95364cf0e9a215d078e4522 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 11 Jun 2026 16:55:53 -0700 Subject: [PATCH 02/17] feat(runs): implement GetPublicClientConfig from config Port flyteadmin's OAuth2MetadataProvider.GetPublicClientConfig (auth/authzserver/metadata_provider.go): serve the public CLI/SDK OAuth2 client settings (clientId, redirectUri, scopes, audience, authorizationMetadataKey) from runs.authMetadata config instead of the previous stub that only returned dataplaneDomain. Empty authorizationMetadataKey defaults to the standard "authorization" header so upstream JWT validators (e.g. ALB) see the token. Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- runs/config/config.go | 26 ++++++++++++++ runs/service/auth_metadata_service.go | 37 ++++++++++++++++++-- runs/service/auth_metadata_service_test.go | 40 ++++++++++++++++++---- runs/setup.go | 6 ++++ 4 files changed, 101 insertions(+), 8 deletions(-) diff --git a/runs/config/config.go b/runs/config/config.go index 71c51e28f1..219d1ac92e 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -90,6 +90,32 @@ type AuthMetadataConfig struct { // RetryDelay is the delay between fetch attempts (default 1s). RetryDelay time.Duration `json:"retryDelay" pflag:",Delay between external metadata fetch attempts"` + + // AuthorizationMetadataKey is the header/metadata key clients should place + // tokens in, returned by GetPublicClientConfig (default "authorization"). + AuthorizationMetadataKey string `json:"authorizationMetadataKey" pflag:",Header key clients should use for tokens"` + + // FlyteClient is the public (CLI/SDK) OAuth2 client configuration returned + // by GetPublicClientConfig. + FlyteClient FlyteClientConfig `json:"flyteClient"` +} + +// FlyteClientConfig mirrors flyteadmin's appAuth.thirdPartyConfig.flyteClient: +// the public OAuth2 client (flytectl/pyflyte) settings advertised to SDKs via +// GetPublicClientConfig. +type FlyteClientConfig struct { + // ClientID is the public client id used by CLI/SDK login flows. + ClientID string `json:"clientId" pflag:",Public OAuth2 client id advertised to SDKs"` + + // RedirectURI is the callback the public client listens on during login. + RedirectURI string `json:"redirectUri" pflag:",Redirect URI for the public client login flow"` + + // Scopes are the OAuth2 scopes the public client should request. + Scopes []string `json:"scopes" pflag:",Scopes the public client should request"` + + // Audience is the intended audience for requested tokens (sent when the IdP + // requires it, e.g. Auth0/Okta custom authorization servers). + Audience string `json:"audience" pflag:",Audience for requested tokens"` } // ServerConfig holds HTTP server configuration diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index c83af980b7..0a1979cc98 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -37,30 +37,63 @@ type ExternalAuthServerConfig struct { RetryDelay time.Duration } +// PublicClientConfig is the public (CLI/SDK) OAuth2 client configuration +// advertised via GetPublicClientConfig. Mirrors flyteadmin's +// appAuth.thirdPartyConfig.flyteClient + grpcAuthorizationHeader. +type PublicClientConfig struct { + // ClientID is the public client id used by CLI/SDK login flows. + ClientID string + // RedirectURI is the callback the public client listens on during login. + RedirectURI string + // Scopes are the OAuth2 scopes the public client should request. + Scopes []string + // Audience is the intended audience for requested tokens. + Audience string + // AuthorizationMetadataKey is the header/metadata key clients should place + // tokens in (default "authorization"). + AuthorizationMetadataKey string +} + // AuthMetadataService implements the AuthMetadataServiceHandler interface. type AuthMetadataService struct { authconnect.UnimplementedAuthMetadataServiceHandler dataplaneDomain string external ExternalAuthServerConfig + publicClient PublicClientConfig } // NewAuthMetadataService creates a new AuthMetadataService instance. When // external.BaseURL is set, GetOAuth2Metadata proxies that server's metadata. -func NewAuthMetadataService(dataplaneDomain string, external ExternalAuthServerConfig) *AuthMetadataService { +// publicClient is advertised to SDKs via GetPublicClientConfig. +func NewAuthMetadataService(dataplaneDomain string, external ExternalAuthServerConfig, publicClient PublicClientConfig) *AuthMetadataService { return &AuthMetadataService{ dataplaneDomain: dataplaneDomain, external: external, + publicClient: publicClient, } } var _ authconnect.AuthMetadataServiceHandler = (*AuthMetadataService)(nil) +// GetPublicClientConfig returns the public (CLI/SDK) OAuth2 client settings. +// Mirrors flyteadmin's OAuth2MetadataProvider.GetPublicClientConfig +// (auth/authzserver/metadata_provider.go), which serves +// appAuth.thirdPartyConfig.flyteClient + grpcAuthorizationHeader. func (s *AuthMetadataService) GetPublicClientConfig( ctx context.Context, req *connect.Request[auth.GetPublicClientConfigRequest], ) (*connect.Response[auth.GetPublicClientConfigResponse], error) { + authMetadataKey := s.publicClient.AuthorizationMetadataKey + if authMetadataKey == "" { + authMetadataKey = "authorization" + } return connect.NewResponse(&auth.GetPublicClientConfigResponse{ - DataplaneDomain: s.dataplaneDomain, + ClientId: s.publicClient.ClientID, + RedirectUri: s.publicClient.RedirectURI, + Scopes: s.publicClient.Scopes, + AuthorizationMetadataKey: authMetadataKey, + Audience: s.publicClient.Audience, + DataplaneDomain: s.dataplaneDomain, }), nil } diff --git a/runs/service/auth_metadata_service_test.go b/runs/service/auth_metadata_service_test.go index 204bd6d623..f733a99dc7 100644 --- a/runs/service/auth_metadata_service_test.go +++ b/runs/service/auth_metadata_service_test.go @@ -16,7 +16,7 @@ import ( ) func TestGetOAuth2Metadata_NotConfigured(t *testing.T) { - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}) + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}, PublicClientConfig{}) _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.Error(t, err) assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err)) @@ -42,7 +42,7 @@ func TestGetOAuth2Metadata_External(t *testing.T) { svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ BaseURL: srv.URL + "/oauth2/default", - }) + }, PublicClientConfig{}) resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.NoError(t, err) assert.Equal(t, "https://idp.example.com/oauth2/default", resp.Msg.Issuer) @@ -62,7 +62,7 @@ func TestGetOAuth2Metadata_ExternalCustomMetadataURL(t *testing.T) { svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ BaseURL: srv.URL, MetadataURL: "custom/metadata", - }) + }, PublicClientConfig{}) resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.NoError(t, err) assert.Equal(t, "https://idp.example.com", resp.Msg.Issuer) @@ -78,7 +78,7 @@ func TestGetOAuth2Metadata_ExternalUnavailable(t *testing.T) { BaseURL: srv.URL, RetryAttempts: 1, RetryDelay: time.Millisecond, - }) + }, PublicClientConfig{}) _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.Error(t, err) assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err)) @@ -91,7 +91,7 @@ func TestOAuth2MetadataHTTPHandler(t *testing.T) { })) defer srv.Close() - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{BaseURL: srv.URL + "/oauth2/default"}) + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{BaseURL: srv.URL + "/oauth2/default"}, PublicClientConfig{}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) @@ -108,9 +108,37 @@ func TestOAuth2MetadataHTTPHandler(t *testing.T) { } func TestOAuth2MetadataHTTPHandler_NotConfigured(t *testing.T) { - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}) + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}, PublicClientConfig{}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) OAuth2MetadataHTTPHandler(svc).ServeHTTP(rec, req) assert.Equal(t, http.StatusNotImplemented, rec.Code) } + +func TestGetPublicClientConfig(t *testing.T) { + svc := NewAuthMetadataService("dataplane.example.com", ExternalAuthServerConfig{}, PublicClientConfig{ + ClientID: "flytectl", + RedirectURI: "http://localhost:53593/callback", + Scopes: []string{"offline_access", "profile"}, + Audience: "https://api.example.com", + AuthorizationMetadataKey: "flyte-authorization", + }) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + assert.Equal(t, "flytectl", resp.Msg.ClientId) + assert.Equal(t, "http://localhost:53593/callback", resp.Msg.RedirectUri) + assert.Equal(t, []string{"offline_access", "profile"}, resp.Msg.Scopes) + assert.Equal(t, "https://api.example.com", resp.Msg.Audience) + assert.Equal(t, "flyte-authorization", resp.Msg.AuthorizationMetadataKey) + assert.Equal(t, "dataplane.example.com", resp.Msg.DataplaneDomain) +} + +func TestGetPublicClientConfig_DefaultAuthMetadataKey(t *testing.T) { + // Empty AuthorizationMetadataKey defaults to the standard "authorization" + // header (which upstream JWT validators like ALB inspect). + svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}, PublicClientConfig{}) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + assert.Equal(t, "authorization", resp.Msg.AuthorizationMetadataKey) + assert.Empty(t, resp.Msg.ClientId) +} diff --git a/runs/setup.go b/runs/setup.go index 9072844b6b..f268589606 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -117,6 +117,12 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { MetadataURL: cfg.AuthMetadata.ExternalMetadataURL, RetryAttempts: cfg.AuthMetadata.RetryAttempts, RetryDelay: cfg.AuthMetadata.RetryDelay, + }, service.PublicClientConfig{ + ClientID: cfg.AuthMetadata.FlyteClient.ClientID, + RedirectURI: cfg.AuthMetadata.FlyteClient.RedirectURI, + Scopes: cfg.AuthMetadata.FlyteClient.Scopes, + Audience: cfg.AuthMetadata.FlyteClient.Audience, + AuthorizationMetadataKey: cfg.AuthMetadata.AuthorizationMetadataKey, }) authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc, connect.WithInterceptors(otelInterceptor)) sc.Mux.Handle(authMetadataPath, authMetadataHandler) From fe87589f0ff76150d7eb73b06f16e8280721bb82 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 12:41:42 -0700 Subject: [PATCH 03/17] nit Signed-off-by: Kevin Su --- runs/setup.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/setup.go b/runs/setup.go index f268589606..3a8e15e0a4 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -129,7 +129,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) // Serve OAuth2 authorization-server metadata at the RFC 8414 well-known path - // so OAuth2/OIDC discovery clients (flytectl, pyflyte) can find it. When + // so OAuth2/OIDC discovery clients (flyte-sdk) can find it. When // runs.authMetadata.externalAuthServerBaseUrl is set, this proxies the // external IdP's (e.g. Okta) metadata document. sc.Mux.Handle("/.well-known/oauth-authorization-server", service.OAuth2MetadataHTTPHandler(authMetadataSvc)) From 06c25e3ce0617edfccabf269bacb3ce1fed7beab Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 12:42:11 -0700 Subject: [PATCH 04/17] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kevin Su --- runs/config/config.go | 4 ++-- runs/service/auth_metadata_service.go | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/runs/config/config.go b/runs/config/config.go index 219d1ac92e..728495d416 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -72,8 +72,8 @@ type Config struct { // server metadata. When ExternalAuthServerBaseURL is set, the service proxies // the external authorization server's metadata document (e.g. Okta) so that // clients discovering auth at this deployment are pointed at the external IdP -// and obtain externally-issued tokens. When empty, the endpoint is not served. -type AuthMetadataConfig struct { +// and obtain externally-issued tokens. When empty, GetOAuth2Metadata returns +// Unimplemented (HTTP 501 for the well-known handler). // ExternalAuthServerBaseURL is the base URL of the external OAuth2 // authorization server to proxy metadata from // (e.g. "https://signin.example.com/oauth2/default"). Empty disables the diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 0a1979cc98..a9f4ffbda6 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -25,8 +25,8 @@ const defaultOAuth2MetadataPath = ".well-known/oauth-authorization-server" // ExternalAuthServerConfig configures proxying of an external OAuth2 // authorization server's metadata document (e.g. Okta). type ExternalAuthServerConfig struct { - // BaseURL is the external authorization server's base URL. When empty, the - // OAuth2 metadata endpoint is not served. + // BaseURL is the external authorization server's base URL. When empty, + // GetOAuth2Metadata returns Unimplemented (HTTP 501 for the well-known handler). BaseURL string // MetadataURL overrides the metadata path resolved against BaseURL. Defaults // to ".well-known/oauth-authorization-server". @@ -118,6 +118,10 @@ func (s *AuthMetadataService) GetOAuth2Metadata( return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("invalid external auth server base URL %q: %w", s.external.BaseURL, err)) } + if baseURL.Scheme == "" || baseURL.Host == "" { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("external auth server base URL must be absolute (include scheme and host): %q", s.external.BaseURL)) + } // Issuer URLs conventionally do not end with a '/', but metadata URLs are // relative to them. Add a trailing '/' so ResolveReference behaves intuitively. @@ -143,7 +147,8 @@ func (s *AuthMetadataService) GetOAuth2Metadata( retryDelay = time.Second } - response, err := sendAndRetryHTTPRequest(ctx, http.DefaultClient, externalMetadataURL.String(), retryAttempts, retryDelay) + client := &http.Client{Timeout: 10 * time.Second} + response, err := sendAndRetryHTTPRequest(ctx, client, externalMetadataURL.String(), retryAttempts, retryDelay) if err != nil { return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) } @@ -167,6 +172,11 @@ func (s *AuthMetadataService) GetOAuth2Metadata( // (flytectl, pyflyte) fetch this path directly rather than the Connect RPC. func OAuth2MetadataHTTPHandler(svc *AuthMetadataService) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.Header().Set("Allow", http.MethodGet) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } resp, err := svc.GetOAuth2Metadata(r.Context(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) if err != nil { http.Error(w, err.Error(), connectCodeToHTTPStatus(err)) From 77ba23b3a59f39eba4e05640c32cff9c90a7a85d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 12:52:08 -0700 Subject: [PATCH 05/17] fix(runs): restore AuthMetadataConfig struct declaration lost in review suggestion Signed-off-by: Kevin Su --- runs/config/config.go | 1 + 1 file changed, 1 insertion(+) diff --git a/runs/config/config.go b/runs/config/config.go index 728495d416..27506a0cb9 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -74,6 +74,7 @@ type Config struct { // clients discovering auth at this deployment are pointed at the external IdP // and obtain externally-issued tokens. When empty, GetOAuth2Metadata returns // Unimplemented (HTTP 501 for the well-known handler). +type AuthMetadataConfig struct { // ExternalAuthServerBaseURL is the base URL of the external OAuth2 // authorization server to proxy metadata from // (e.g. "https://signin.example.com/oauth2/default"). Empty disables the From 33c56c02fa1bc08cbb4fab9b365a0520d5984edb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 12:53:23 -0700 Subject: [PATCH 06/17] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index a9f4ffbda6..1350d07fec 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -136,6 +136,12 @@ func (s *AuthMetadataService) GetOAuth2Metadata( return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("invalid external metadata path %q: %w", metadataPath, err)) } + // MetadataURL is expected to be a relative path resolved against BaseURL. + // Reject absolute or scheme-relative URLs so BaseURL cannot be bypassed. + if relURL.IsAbs() || relURL.Host != "" { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("external metadata path must be relative to externalAuthServerBaseUrl, got %q", metadataPath)) + } externalMetadataURL := baseURL.ResolveReference(relURL) retryAttempts := s.external.RetryAttempts From c8a3f943a4a97b923c5491bfabdd38229939faa7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 13:00:18 -0700 Subject: [PATCH 07/17] refactor(runs): inject config.AuthMetadataConfig directly into AuthMetadataService Drop the service-local ExternalAuthServerConfig/PublicClientConfig mirror structs and pass cfg.AuthMetadata wholesale, removing the field-by-field copy in setup.go. Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 72 ++++++---------------- runs/service/auth_metadata_service_test.go | 45 +++++++------- runs/setup.go | 13 +--- 3 files changed, 43 insertions(+), 87 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 1350d07fec..a57f903180 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -18,81 +18,45 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + "github.com/flyteorg/flyte/v2/runs/config" ) const defaultOAuth2MetadataPath = ".well-known/oauth-authorization-server" -// ExternalAuthServerConfig configures proxying of an external OAuth2 -// authorization server's metadata document (e.g. Okta). -type ExternalAuthServerConfig struct { - // BaseURL is the external authorization server's base URL. When empty, - // GetOAuth2Metadata returns Unimplemented (HTTP 501 for the well-known handler). - BaseURL string - // MetadataURL overrides the metadata path resolved against BaseURL. Defaults - // to ".well-known/oauth-authorization-server". - MetadataURL string - // RetryAttempts is the number of fetch attempts (default 5). - RetryAttempts int - // RetryDelay is the delay between fetch attempts (default 1s). - RetryDelay time.Duration -} - -// PublicClientConfig is the public (CLI/SDK) OAuth2 client configuration -// advertised via GetPublicClientConfig. Mirrors flyteadmin's -// appAuth.thirdPartyConfig.flyteClient + grpcAuthorizationHeader. -type PublicClientConfig struct { - // ClientID is the public client id used by CLI/SDK login flows. - ClientID string - // RedirectURI is the callback the public client listens on during login. - RedirectURI string - // Scopes are the OAuth2 scopes the public client should request. - Scopes []string - // Audience is the intended audience for requested tokens. - Audience string - // AuthorizationMetadataKey is the header/metadata key clients should place - // tokens in (default "authorization"). - AuthorizationMetadataKey string -} - // AuthMetadataService implements the AuthMetadataServiceHandler interface. type AuthMetadataService struct { authconnect.UnimplementedAuthMetadataServiceHandler dataplaneDomain string - external ExternalAuthServerConfig - publicClient PublicClientConfig + cfg config.AuthMetadataConfig } // NewAuthMetadataService creates a new AuthMetadataService instance. When -// external.BaseURL is set, GetOAuth2Metadata proxies that server's metadata. -// publicClient is advertised to SDKs via GetPublicClientConfig. -func NewAuthMetadataService(dataplaneDomain string, external ExternalAuthServerConfig, publicClient PublicClientConfig) *AuthMetadataService { +// cfg.ExternalAuthServerBaseURL is set, GetOAuth2Metadata proxies that server's +// metadata. cfg.FlyteClient is advertised to SDKs via GetPublicClientConfig. +func NewAuthMetadataService(dataplaneDomain string, cfg config.AuthMetadataConfig) *AuthMetadataService { return &AuthMetadataService{ dataplaneDomain: dataplaneDomain, - external: external, - publicClient: publicClient, + cfg: cfg, } } var _ authconnect.AuthMetadataServiceHandler = (*AuthMetadataService)(nil) // GetPublicClientConfig returns the public (CLI/SDK) OAuth2 client settings. -// Mirrors flyteadmin's OAuth2MetadataProvider.GetPublicClientConfig -// (auth/authzserver/metadata_provider.go), which serves -// appAuth.thirdPartyConfig.flyteClient + grpcAuthorizationHeader. func (s *AuthMetadataService) GetPublicClientConfig( ctx context.Context, req *connect.Request[auth.GetPublicClientConfigRequest], ) (*connect.Response[auth.GetPublicClientConfigResponse], error) { - authMetadataKey := s.publicClient.AuthorizationMetadataKey + authMetadataKey := s.cfg.AuthorizationMetadataKey if authMetadataKey == "" { authMetadataKey = "authorization" } return connect.NewResponse(&auth.GetPublicClientConfigResponse{ - ClientId: s.publicClient.ClientID, - RedirectUri: s.publicClient.RedirectURI, - Scopes: s.publicClient.Scopes, + ClientId: s.cfg.FlyteClient.ClientID, + RedirectUri: s.cfg.FlyteClient.RedirectURI, + Scopes: s.cfg.FlyteClient.Scopes, AuthorizationMetadataKey: authMetadataKey, - Audience: s.publicClient.Audience, + Audience: s.cfg.FlyteClient.Audience, DataplaneDomain: s.dataplaneDomain, }), nil } @@ -108,26 +72,26 @@ func (s *AuthMetadataService) GetOAuth2Metadata( ctx context.Context, _ *connect.Request[auth.GetOAuth2MetadataRequest], ) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { - if s.external.BaseURL == "" { + if s.cfg.ExternalAuthServerBaseURL == "" { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("oauth2 metadata is not configured; set runs.authMetadata.externalAuthServerBaseUrl")) } - baseURL, err := url.Parse(s.external.BaseURL) + baseURL, err := url.Parse(s.cfg.ExternalAuthServerBaseURL) if err != nil { return nil, connect.NewError(connect.CodeInternal, - fmt.Errorf("invalid external auth server base URL %q: %w", s.external.BaseURL, err)) + fmt.Errorf("invalid external auth server base URL %q: %w", s.cfg.ExternalAuthServerBaseURL, err)) } if baseURL.Scheme == "" || baseURL.Host == "" { return nil, connect.NewError(connect.CodeInternal, - fmt.Errorf("external auth server base URL must be absolute (include scheme and host): %q", s.external.BaseURL)) + fmt.Errorf("external auth server base URL must be absolute (include scheme and host): %q", s.cfg.ExternalAuthServerBaseURL)) } // Issuer URLs conventionally do not end with a '/', but metadata URLs are // relative to them. Add a trailing '/' so ResolveReference behaves intuitively. baseURL.Path = strings.TrimSuffix(baseURL.Path, "/") + "/" - metadataPath := s.external.MetadataURL + metadataPath := s.cfg.ExternalMetadataURL if metadataPath == "" { metadataPath = defaultOAuth2MetadataPath } @@ -144,11 +108,11 @@ func (s *AuthMetadataService) GetOAuth2Metadata( } externalMetadataURL := baseURL.ResolveReference(relURL) - retryAttempts := s.external.RetryAttempts + retryAttempts := s.cfg.RetryAttempts if retryAttempts <= 0 { retryAttempts = 5 } - retryDelay := s.external.RetryDelay + retryDelay := s.cfg.RetryDelay if retryDelay <= 0 { retryDelay = time.Second } diff --git a/runs/service/auth_metadata_service_test.go b/runs/service/auth_metadata_service_test.go index f733a99dc7..49d2023782 100644 --- a/runs/service/auth_metadata_service_test.go +++ b/runs/service/auth_metadata_service_test.go @@ -13,10 +13,11 @@ import ( "github.com/stretchr/testify/require" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/config" ) func TestGetOAuth2Metadata_NotConfigured(t *testing.T) { - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{}) _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.Error(t, err) assert.Equal(t, connect.CodeUnimplemented, connect.CodeOf(err)) @@ -40,9 +41,9 @@ func TestGetOAuth2Metadata_External(t *testing.T) { })) defer srv.Close() - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ - BaseURL: srv.URL + "/oauth2/default", - }, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL + "/oauth2/default", + }) resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.NoError(t, err) assert.Equal(t, "https://idp.example.com/oauth2/default", resp.Msg.Issuer) @@ -59,10 +60,10 @@ func TestGetOAuth2Metadata_ExternalCustomMetadataURL(t *testing.T) { })) defer srv.Close() - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ - BaseURL: srv.URL, - MetadataURL: "custom/metadata", - }, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL, + ExternalMetadataURL: "custom/metadata", + }) resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.NoError(t, err) assert.Equal(t, "https://idp.example.com", resp.Msg.Issuer) @@ -74,11 +75,11 @@ func TestGetOAuth2Metadata_ExternalUnavailable(t *testing.T) { })) defer srv.Close() - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{ - BaseURL: srv.URL, - RetryAttempts: 1, - RetryDelay: time.Millisecond, - }, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL, + RetryAttempts: 1, + RetryDelay: time.Millisecond, + }) _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.Error(t, err) assert.Equal(t, connect.CodeUnavailable, connect.CodeOf(err)) @@ -91,7 +92,7 @@ func TestOAuth2MetadataHTTPHandler(t *testing.T) { })) defer srv.Close() - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{BaseURL: srv.URL + "/oauth2/default"}, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ExternalAuthServerBaseURL: srv.URL + "/oauth2/default"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) @@ -108,7 +109,7 @@ func TestOAuth2MetadataHTTPHandler(t *testing.T) { } func TestOAuth2MetadataHTTPHandler_NotConfigured(t *testing.T) { - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) OAuth2MetadataHTTPHandler(svc).ServeHTTP(rec, req) @@ -116,12 +117,14 @@ func TestOAuth2MetadataHTTPHandler_NotConfigured(t *testing.T) { } func TestGetPublicClientConfig(t *testing.T) { - svc := NewAuthMetadataService("dataplane.example.com", ExternalAuthServerConfig{}, PublicClientConfig{ - ClientID: "flytectl", - RedirectURI: "http://localhost:53593/callback", - Scopes: []string{"offline_access", "profile"}, - Audience: "https://api.example.com", + svc := NewAuthMetadataService("dataplane.example.com", config.AuthMetadataConfig{ AuthorizationMetadataKey: "flyte-authorization", + FlyteClient: config.FlyteClientConfig{ + ClientID: "flytectl", + RedirectURI: "http://localhost:53593/callback", + Scopes: []string{"offline_access", "profile"}, + Audience: "https://api.example.com", + }, }) resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) require.NoError(t, err) @@ -136,7 +139,7 @@ func TestGetPublicClientConfig(t *testing.T) { func TestGetPublicClientConfig_DefaultAuthMetadataKey(t *testing.T) { // Empty AuthorizationMetadataKey defaults to the standard "authorization" // header (which upstream JWT validators like ALB inspect). - svc := NewAuthMetadataService("example.com", ExternalAuthServerConfig{}, PublicClientConfig{}) + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{}) resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) require.NoError(t, err) assert.Equal(t, "authorization", resp.Msg.AuthorizationMetadataKey) diff --git a/runs/setup.go b/runs/setup.go index 3a8e15e0a4..6b6aa3954b 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -112,18 +112,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(identityPath, identityHandler) logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) - authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL, service.ExternalAuthServerConfig{ - BaseURL: cfg.AuthMetadata.ExternalAuthServerBaseURL, - MetadataURL: cfg.AuthMetadata.ExternalMetadataURL, - RetryAttempts: cfg.AuthMetadata.RetryAttempts, - RetryDelay: cfg.AuthMetadata.RetryDelay, - }, service.PublicClientConfig{ - ClientID: cfg.AuthMetadata.FlyteClient.ClientID, - RedirectURI: cfg.AuthMetadata.FlyteClient.RedirectURI, - Scopes: cfg.AuthMetadata.FlyteClient.Scopes, - Audience: cfg.AuthMetadata.FlyteClient.Audience, - AuthorizationMetadataKey: cfg.AuthMetadata.AuthorizationMetadataKey, - }) + authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL, cfg.AuthMetadata) authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc, connect.WithInterceptors(otelInterceptor)) sc.Mux.Handle(authMetadataPath, authMetadataHandler) logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) From cd992d851d84816e72dd8f0751d61a23f4d7d3ed Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 13:05:44 -0700 Subject: [PATCH 08/17] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index a57f903180..86e9b629c1 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -177,7 +177,7 @@ func connectCodeToHTTPStatus(err error) int { case connect.CodeInvalidArgument: return http.StatusBadRequest case connect.CodeUnavailable: - return http.StatusBadGateway + return http.StatusServiceUnavailable default: return http.StatusInternalServerError } From aefcca8c77950e93652d1c7518f904acfc4fb74e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 13:17:13 -0700 Subject: [PATCH 09/17] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 86e9b629c1..4f44803925 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -101,10 +101,10 @@ func (s *AuthMetadataService) GetOAuth2Metadata( fmt.Errorf("invalid external metadata path %q: %w", metadataPath, err)) } // MetadataURL is expected to be a relative path resolved against BaseURL. - // Reject absolute or scheme-relative URLs so BaseURL cannot be bypassed. - if relURL.IsAbs() || relURL.Host != "" { + // Reject absolute, scheme-relative, or root-relative paths so BaseURL cannot be bypassed. + if relURL.IsAbs() || relURL.Host != "" || strings.HasPrefix(relURL.Path, "/") { return nil, connect.NewError(connect.CodeInternal, - fmt.Errorf("external metadata path must be relative to externalAuthServerBaseUrl, got %q", metadataPath)) + fmt.Errorf("external metadata path must be relative to externalAuthServerBaseUrl (no leading '/'), got %q", metadataPath)) } externalMetadataURL := baseURL.ResolveReference(relURL) From 5b21be19a48f957abd44ec1f3dc631eb18bd2734 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 13:17:38 -0700 Subject: [PATCH 10/17] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 4f44803925..1c58e37442 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -237,6 +237,12 @@ func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL return resp, nil } + // Don't retry on 4xx: these are usually configuration/request errors. + if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError { + _ = resp.Body.Close() + return nil, fmt.Errorf("non-retryable status code %d from %s", resp.StatusCode, targetURL) + } + _ = resp.Body.Close() lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) lastResp = resp From 84ec9d46ca2ec825ee0089554d7e8369cc79e4e9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 13:28:52 -0700 Subject: [PATCH 11/17] fix(runs): address review feedback on OAuth2 metadata proxy - Reuse a single http.Client (10s timeout) instead of per-request construction - Bound the upstream metadata read to 1 MiB - Cache successfully fetched metadata for 5 minutes (the endpoint is public and the document is effectively static; avoids an outbound fetch with retries per discovery call) - Classify non-retryable upstream 4xx as Internal and propagate pre-classified connect codes instead of blanket Unavailable - Fix discovery-clients wording in setup.go Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 57 ++++++++++++++++++++-- runs/service/auth_metadata_service_test.go | 53 ++++++++++++++++++++ runs/setup.go | 2 +- 3 files changed, 106 insertions(+), 6 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 1c58e37442..93bb8724e1 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "connectrpc.com/connect" @@ -21,13 +22,32 @@ import ( "github.com/flyteorg/flyte/v2/runs/config" ) -const defaultOAuth2MetadataPath = ".well-known/oauth-authorization-server" +const ( + defaultOAuth2MetadataPath = ".well-known/oauth-authorization-server" + + // maxMetadataBodySize bounds the upstream metadata document read; RFC 8414 + // documents are a few KB, so 1 MiB is generous while preventing memory + // blowups from a misconfigured upstream. + maxMetadataBodySize = 1 << 20 + + // metadataCacheTTL is how long a successfully fetched metadata document is + // served from memory. The endpoint is public and the document is effectively + // static, so this avoids an outbound fetch (with retries) per discovery call. + metadataCacheTTL = 5 * time.Minute + + externalFetchTimeout = 10 * time.Second +) // AuthMetadataService implements the AuthMetadataServiceHandler interface. type AuthMetadataService struct { authconnect.UnimplementedAuthMetadataServiceHandler dataplaneDomain string cfg config.AuthMetadataConfig + httpClient *http.Client + + mu sync.Mutex + cachedMetadata *auth.GetOAuth2MetadataResponse + cacheExpiry time.Time } // NewAuthMetadataService creates a new AuthMetadataService instance. When @@ -37,6 +57,7 @@ func NewAuthMetadataService(dataplaneDomain string, cfg config.AuthMetadataConfi return &AuthMetadataService{ dataplaneDomain: dataplaneDomain, cfg: cfg, + httpClient: &http.Client{Timeout: externalFetchTimeout}, } } @@ -77,6 +98,15 @@ func (s *AuthMetadataService) GetOAuth2Metadata( errors.New("oauth2 metadata is not configured; set runs.authMetadata.externalAuthServerBaseUrl")) } + // Serve from cache while fresh; only successful fetches are cached. + s.mu.Lock() + if s.cachedMetadata != nil && time.Now().Before(s.cacheExpiry) { + cached := proto.Clone(s.cachedMetadata).(*auth.GetOAuth2MetadataResponse) + s.mu.Unlock() + return connect.NewResponse(cached), nil + } + s.mu.Unlock() + baseURL, err := url.Parse(s.cfg.ExternalAuthServerBaseURL) if err != nil { return nil, connect.NewError(connect.CodeInternal, @@ -117,23 +147,37 @@ func (s *AuthMetadataService) GetOAuth2Metadata( retryDelay = time.Second } - client := &http.Client{Timeout: 10 * time.Second} - response, err := sendAndRetryHTTPRequest(ctx, client, externalMetadataURL.String(), retryAttempts, retryDelay) + response, err := sendAndRetryHTTPRequest(ctx, s.httpClient, externalMetadataURL.String(), retryAttempts, retryDelay) if err != nil { + // Preserve pre-classified codes (e.g. Internal for non-retryable 4xx) + // instead of blanket-mapping everything to Unavailable. + var connectErr *connect.Error + if errors.As(err, &connectErr) { + return nil, err + } return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) } defer func() { _ = response.Body.Close() }() - raw, err := io.ReadAll(response.Body) + raw, err := io.ReadAll(io.LimitReader(response.Body, maxMetadataBodySize+1)) if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read OAuth2 metadata response: %w", err)) } + if len(raw) > maxMetadataBodySize { + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("OAuth2 metadata response exceeds %d bytes", maxMetadataBodySize)) + } resp := &auth.GetOAuth2MetadataResponse{} if err := unmarshalResp(response, raw, resp); err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) } + s.mu.Lock() + s.cachedMetadata = proto.Clone(resp).(*auth.GetOAuth2MetadataResponse) + s.cacheExpiry = time.Now().Add(metadataCacheTTL) + s.mu.Unlock() + return connect.NewResponse(resp), nil } @@ -238,9 +282,12 @@ func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL } // Don't retry on 4xx: these are usually configuration/request errors. + // Classify as Internal (a misconfiguration of this service), not + // Unavailable, so clients see an accurate status. if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode < http.StatusInternalServerError { _ = resp.Body.Close() - return nil, fmt.Errorf("non-retryable status code %d from %s", resp.StatusCode, targetURL) + return nil, connect.NewError(connect.CodeInternal, + fmt.Errorf("non-retryable status code %d from %s", resp.StatusCode, targetURL)) } _ = resp.Body.Close() diff --git a/runs/service/auth_metadata_service_test.go b/runs/service/auth_metadata_service_test.go index 49d2023782..577e6dfdfc 100644 --- a/runs/service/auth_metadata_service_test.go +++ b/runs/service/auth_metadata_service_test.go @@ -145,3 +145,56 @@ func TestGetPublicClientConfig_DefaultAuthMetadataKey(t *testing.T) { assert.Equal(t, "authorization", resp.Msg.AuthorizationMetadataKey) assert.Empty(t, resp.Msg.ClientId) } + +func TestGetOAuth2Metadata_CachesSuccessfulFetch(t *testing.T) { + hits := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"https://idp.example.com"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ExternalAuthServerBaseURL: srv.URL}) + for i := 0; i < 3; i++ { + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + assert.Equal(t, "https://idp.example.com", resp.Msg.Issuer) + } + assert.Equal(t, 1, hits, "fresh cache should serve repeat calls without refetching") +} + +func TestGetOAuth2Metadata_BodySizeLimit(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer":"`)) + big := make([]byte, maxMetadataBodySize+1024) + for i := range big { + big[i] = 'a' + } + _, _ = w.Write(big) + _, _ = w.Write([]byte(`"}`)) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ExternalAuthServerBaseURL: srv.URL}) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) +} + +func TestGetOAuth2Metadata_Upstream4xxIsInternal(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ + ExternalAuthServerBaseURL: srv.URL, + RetryAttempts: 1, + RetryDelay: time.Millisecond, + }) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err), "non-retryable 4xx should be Internal, not Unavailable") +} diff --git a/runs/setup.go b/runs/setup.go index 6b6aa3954b..1e30524c90 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -118,7 +118,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) // Serve OAuth2 authorization-server metadata at the RFC 8414 well-known path - // so OAuth2/OIDC discovery clients (flyte-sdk) can find it. When + // so OAuth2/OIDC discovery clients (flytectl, pyflyte) can find it. When // runs.authMetadata.externalAuthServerBaseUrl is set, this proxies the // external IdP's (e.g. Okta) metadata document. sc.Mux.Handle("/.well-known/oauth-authorization-server", service.OAuth2MetadataHTTPHandler(authMetadataSvc)) From 9cf415f5a5dcd2306209fd5f1418ad07954790e6 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 13:36:13 -0700 Subject: [PATCH 12/17] nit Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 2 +- runs/setup.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 93bb8724e1..d5a8d1d2e1 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -183,7 +183,7 @@ func (s *AuthMetadataService) GetOAuth2Metadata( // OAuth2MetadataHTTPHandler serves the OAuth2 authorization-server metadata // document at the RFC 8414 well-known path. OAuth2/OIDC discovery clients -// (flytectl, pyflyte) fetch this path directly rather than the Connect RPC. +// (flyte-sdk) fetch this path directly rather than the Connect RPC. func OAuth2MetadataHTTPHandler(svc *AuthMetadataService) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { diff --git a/runs/setup.go b/runs/setup.go index 1e30524c90..6b6aa3954b 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -118,7 +118,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) // Serve OAuth2 authorization-server metadata at the RFC 8414 well-known path - // so OAuth2/OIDC discovery clients (flytectl, pyflyte) can find it. When + // so OAuth2/OIDC discovery clients (flyte-sdk) can find it. When // runs.authMetadata.externalAuthServerBaseUrl is set, this proxies the // external IdP's (e.g. Okta) metadata document. sc.Mux.Handle("/.well-known/oauth-authorization-server", service.OAuth2MetadataHTTPHandler(authMetadataSvc)) From 62fce0c0989c86361284221c8f3f7bb42ecbbbe4 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 14:14:21 -0700 Subject: [PATCH 13/17] chore(runs): regenerate pflags for authMetadata config go generate ./runs/config was failing on types the pflags generator does not support, so config_flags.go had gone stale (it also predated the triggerScheduler fields). Unblock generation and regenerate: - RetryDelay: time.Duration -> flytestdlib config.Duration (idiomatic, keeps the flag) - Domains, triggerScheduler.resyncInterval, triggerScheduler.executionQps: pflag:"-" (slice-of-struct / raw duration / float64 are unsupported by the generator; config-file only) Generated flags now include runs.authMetadata.* (incl. flyteClient.*) and the previously missing triggerScheduler.* entries. Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- runs/config/config.go | 14 +- runs/config/config_flags.go | 12 ++ runs/config/config_flags_test.go | 176 ++++++++++++++++++++- runs/service/auth_metadata_service.go | 2 +- runs/service/auth_metadata_service_test.go | 5 +- 5 files changed, 198 insertions(+), 11 deletions(-) diff --git a/runs/config/config.go b/runs/config/config.go index 27506a0cb9..c2b7ab3c53 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -58,7 +58,9 @@ type Config struct { SeedProjects []string `json:"seedProjects" pflag:",Projects to create by default at startup"` // Domains are injected into project responses (not stored per project row). - Domains []DomainConfig `json:"domains"` + // Excluded from pflags (slices of structs are unsupported by the generator); + // configure via the config file only. + Domains []DomainConfig `json:"domains" pflag:"-"` // TriggerScheduler configures the cron-based trigger scheduler worker. TriggerScheduler TriggerSchedulerConfig `json:"triggerScheduler"` @@ -90,7 +92,7 @@ type AuthMetadataConfig struct { RetryAttempts int `json:"retryAttempts" pflag:",Attempts to fetch external metadata"` // RetryDelay is the delay between fetch attempts (default 1s). - RetryDelay time.Duration `json:"retryDelay" pflag:",Delay between external metadata fetch attempts"` + RetryDelay config.Duration `json:"retryDelay" pflag:",Delay between external metadata fetch attempts"` // AuthorizationMetadataKey is the header/metadata key clients should place // tokens in, returned by GetPublicClientConfig (default "authorization"). @@ -137,13 +139,17 @@ type TriggerSchedulerConfig struct { Enabled bool `json:"enabled" pflag:",Enable the trigger scheduler worker"` // ResyncInterval is how often the scheduler re-reads active triggers from the DB. - ResyncInterval time.Duration `json:"resyncInterval" pflag:",How often to resync active triggers from the database"` + // Excluded from pflags (raw time.Duration is unsupported by the generator); + // configure via the config file only. + ResyncInterval time.Duration `json:"resyncInterval" pflag:"-"` // MaxCatchupRunsPerLoop caps how many catchup runs are fired per resync loop. MaxCatchupRunsPerLoop int `json:"maxCatchupRunsPerLoop" pflag:",Maximum catchup runs fired per resync loop"` // ExecutionQPS is the token-bucket rate for CreateRun calls (tokens/second). - ExecutionQPS float64 `json:"executionQps" pflag:",Rate limit for CreateRun calls (requests per second)"` + // Excluded from pflags (float64 is unsupported by the generator); configure + // via the config file only. + ExecutionQPS float64 `json:"executionQps" pflag:"-"` // ExecutionBurst is the token-bucket burst size. ExecutionBurst int `json:"executionBurst" pflag:",Burst size for CreateRun rate limiter"` diff --git a/runs/config/config_flags.go b/runs/config/config_flags.go index 7a5bbc271b..a32e66083d 100755 --- a/runs/config/config_flags.go +++ b/runs/config/config_flags.go @@ -69,5 +69,17 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "actionsServiceUrl"), defaultConfig.ActionsServiceURL, "URL of the actions service") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "storagePrefix"), defaultConfig.StoragePrefix, "Base URI prefix for storing run inputs and outputs") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "seedProjects"), defaultConfig.SeedProjects, "Projects to create by default at startup") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "triggerScheduler.enabled"), defaultConfig.TriggerScheduler.Enabled, "Enable the trigger scheduler worker") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "triggerScheduler.maxCatchupRunsPerLoop"), defaultConfig.TriggerScheduler.MaxCatchupRunsPerLoop, "Maximum catchup runs fired per resync loop") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "triggerScheduler.executionBurst"), defaultConfig.TriggerScheduler.ExecutionBurst, "Burst size for CreateRun rate limiter") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.externalAuthServerBaseUrl"), defaultConfig.AuthMetadata.ExternalAuthServerBaseURL, "Base URL of the external OAuth2 authorization server to proxy metadata from") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.externalMetadataUrl"), defaultConfig.AuthMetadata.ExternalMetadataURL, "Override for the external metadata path") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "authMetadata.retryAttempts"), defaultConfig.AuthMetadata.RetryAttempts, "Attempts to fetch external metadata") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.retryDelay"), defaultConfig.AuthMetadata.RetryDelay.String(), "Delay between external metadata fetch attempts") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.authorizationMetadataKey"), defaultConfig.AuthMetadata.AuthorizationMetadataKey, "Header key clients should use for tokens") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.clientId"), defaultConfig.AuthMetadata.FlyteClient.ClientID, "Public OAuth2 client id advertised to SDKs") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.redirectUri"), defaultConfig.AuthMetadata.FlyteClient.RedirectURI, "Redirect URI for the public client login flow") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.scopes"), defaultConfig.AuthMetadata.FlyteClient.Scopes, "Scopes the public client should request") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authMetadata.flyteClient.audience"), defaultConfig.AuthMetadata.FlyteClient.Audience, "Audience for requested tokens") return cmdFlags } diff --git a/runs/config/config_flags_test.go b/runs/config/config_flags_test.go index c9450dbb1d..936e24a753 100755 --- a/runs/config/config_flags_test.go +++ b/runs/config/config_flags_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - "github.com/mitchellh/mapstructure" + "github.com/go-viper/mapstructure/v2" "github.com/stretchr/testify/assert" ) @@ -354,11 +354,179 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Test_seedProjects", func(t *testing.T) { t.Run("Override", func(t *testing.T) { - testValue := []string{"flytesnacks", "demo"} + testValue := join_Config(defaultConfig.SeedProjects, ",") - cmdFlags.Set("seedProjects", join_Config(testValue, ",")) + cmdFlags.Set("seedProjects", testValue) if vStringSlice, err := cmdFlags.GetStringSlice("seedProjects"); err == nil { - testDecodeRaw_Config(t, vStringSlice, &actual.SeedProjects) + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.SeedProjects) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_triggerScheduler.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("triggerScheduler.enabled", testValue) + if vBool, err := cmdFlags.GetBool("triggerScheduler.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.TriggerScheduler.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_triggerScheduler.maxCatchupRunsPerLoop", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("triggerScheduler.maxCatchupRunsPerLoop", testValue) + if vInt, err := cmdFlags.GetInt("triggerScheduler.maxCatchupRunsPerLoop"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.TriggerScheduler.MaxCatchupRunsPerLoop) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_triggerScheduler.executionBurst", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("triggerScheduler.executionBurst", testValue) + if vInt, err := cmdFlags.GetInt("triggerScheduler.executionBurst"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.TriggerScheduler.ExecutionBurst) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.externalAuthServerBaseUrl", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.externalAuthServerBaseUrl", testValue) + if vString, err := cmdFlags.GetString("authMetadata.externalAuthServerBaseUrl"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.ExternalAuthServerBaseURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.externalMetadataUrl", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.externalMetadataUrl", testValue) + if vString, err := cmdFlags.GetString("authMetadata.externalMetadataUrl"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.ExternalMetadataURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.retryAttempts", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.retryAttempts", testValue) + if vInt, err := cmdFlags.GetInt("authMetadata.retryAttempts"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AuthMetadata.RetryAttempts) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.retryDelay", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.AuthMetadata.RetryDelay.String() + + cmdFlags.Set("authMetadata.retryDelay", testValue) + if vString, err := cmdFlags.GetString("authMetadata.retryDelay"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.RetryDelay) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.authorizationMetadataKey", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.authorizationMetadataKey", testValue) + if vString, err := cmdFlags.GetString("authMetadata.authorizationMetadataKey"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.AuthorizationMetadataKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.clientId", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.flyteClient.clientId", testValue) + if vString, err := cmdFlags.GetString("authMetadata.flyteClient.clientId"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.FlyteClient.ClientID) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.redirectUri", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.flyteClient.redirectUri", testValue) + if vString, err := cmdFlags.GetString("authMetadata.flyteClient.redirectUri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.FlyteClient.RedirectURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.scopes", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config(defaultConfig.AuthMetadata.FlyteClient.Scopes, ",") + + cmdFlags.Set("authMetadata.flyteClient.scopes", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("authMetadata.flyteClient.scopes"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.AuthMetadata.FlyteClient.Scopes) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_authMetadata.flyteClient.audience", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("authMetadata.flyteClient.audience", testValue) + if vString, err := cmdFlags.GetString("authMetadata.flyteClient.audience"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthMetadata.FlyteClient.Audience) } else { assert.FailNow(t, err.Error()) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index d5a8d1d2e1..d3994ba8f5 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -142,7 +142,7 @@ func (s *AuthMetadataService) GetOAuth2Metadata( if retryAttempts <= 0 { retryAttempts = 5 } - retryDelay := s.cfg.RetryDelay + retryDelay := s.cfg.RetryDelay.Duration if retryDelay <= 0 { retryDelay = time.Second } diff --git a/runs/service/auth_metadata_service_test.go b/runs/service/auth_metadata_service_test.go index 577e6dfdfc..e9453a542d 100644 --- a/runs/service/auth_metadata_service_test.go +++ b/runs/service/auth_metadata_service_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + stdlibconfig "github.com/flyteorg/flyte/v2/flytestdlib/config" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" "github.com/flyteorg/flyte/v2/runs/config" ) @@ -78,7 +79,7 @@ func TestGetOAuth2Metadata_ExternalUnavailable(t *testing.T) { svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ ExternalAuthServerBaseURL: srv.URL, RetryAttempts: 1, - RetryDelay: time.Millisecond, + RetryDelay: stdlibconfig.Duration{Duration: time.Millisecond}, }) _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.Error(t, err) @@ -192,7 +193,7 @@ func TestGetOAuth2Metadata_Upstream4xxIsInternal(t *testing.T) { svc := NewAuthMetadataService("example.com", config.AuthMetadataConfig{ ExternalAuthServerBaseURL: srv.URL, RetryAttempts: 1, - RetryDelay: time.Millisecond, + RetryDelay: stdlibconfig.Duration{Duration: time.Millisecond}, }) _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) require.Error(t, err) From 40d582a8dd14c690b9004ea1a566ac8b187071f2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 14:27:03 -0700 Subject: [PATCH 14/17] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index d3994ba8f5..5ba5efe442 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -152,10 +152,16 @@ func (s *AuthMetadataService) GetOAuth2Metadata( // Preserve pre-classified codes (e.g. Internal for non-retryable 4xx) // instead of blanket-mapping everything to Unavailable. var connectErr *connect.Error - if errors.As(err, &connectErr) { + switch { + case errors.As(err, &connectErr): return nil, err + case errors.Is(err, context.Canceled): + return nil, connect.NewError(connect.CodeCanceled, err) + case errors.Is(err, context.DeadlineExceeded): + return nil, connect.NewError(connect.CodeDeadlineExceeded, err) + default: + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) } - return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) } defer func() { _ = response.Body.Close() }() From efa8428c61971eda293c6ae1a6762058f7cd01cd Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 Jun 2026 14:34:55 -0700 Subject: [PATCH 15/17] nit Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 5ba5efe442..16f6247475 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -87,8 +87,6 @@ func (s *AuthMetadataService) GetPublicClientConfig( // flyte clients that discover auth at this deployment obtain tokens from the // external IdP (e.g. Okta) directly, so a single token satisfies both this // deployment and any upstream (ALB) JWT validation keyed to the same issuer. -// -// The external-fetch logic is adapted from flyteorg/flyte#6998. func (s *AuthMetadataService) GetOAuth2Metadata( ctx context.Context, _ *connect.Request[auth.GetOAuth2MetadataRequest], @@ -238,8 +236,6 @@ func connectCodeToHTTPStatus(err error) int { // serialization and the snake_case form matching proto field names. This matters // because external authorization servers (including flyteadmin) emit camelCase // keys while the Go proto struct tags are snake_case. -// -// Adapted from flyteorg/flyte#6998. func unmarshalResp(r *http.Response, body []byte, v proto.Message) error { // DiscardUnknown: real authorization servers (e.g. Okta) return many metadata // fields beyond those modelled here (introspection_endpoint, claims_supported, @@ -257,8 +253,6 @@ func unmarshalResp(r *http.Response, body []byte, v proto.Message) error { } // sendAndRetryHTTPRequest fetches the given URL with retry logic. -// -// Adapted from flyteorg/flyte#6998. func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) (*http.Response, error) { var lastErr error var lastResp *http.Response From 66d99f8987b19fb4b95a85a9e6ad9a5734b92b66 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 16 Jun 2026 18:20:47 -0700 Subject: [PATCH 16/17] nit Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index 16f6247475..cca437a2a2 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -65,8 +65,8 @@ var _ authconnect.AuthMetadataServiceHandler = (*AuthMetadataService)(nil) // GetPublicClientConfig returns the public (CLI/SDK) OAuth2 client settings. func (s *AuthMetadataService) GetPublicClientConfig( - ctx context.Context, - req *connect.Request[auth.GetPublicClientConfigRequest], + _ context.Context, + _ *connect.Request[auth.GetPublicClientConfigRequest], ) (*connect.Response[auth.GetPublicClientConfigResponse], error) { authMetadataKey := s.cfg.AuthorizationMetadataKey if authMetadataKey == "" { From 4a387c17aa5e2fe2df7c80bf5884ad01e7ea3297 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 16 Jun 2026 18:26:01 -0700 Subject: [PATCH 17/17] rwlock Signed-off-by: Kevin Su --- runs/service/auth_metadata_service.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/runs/service/auth_metadata_service.go b/runs/service/auth_metadata_service.go index cca437a2a2..e311fd7054 100644 --- a/runs/service/auth_metadata_service.go +++ b/runs/service/auth_metadata_service.go @@ -45,7 +45,7 @@ type AuthMetadataService struct { cfg config.AuthMetadataConfig httpClient *http.Client - mu sync.Mutex + mu sync.RWMutex cachedMetadata *auth.GetOAuth2MetadataResponse cacheExpiry time.Time } @@ -97,13 +97,13 @@ func (s *AuthMetadataService) GetOAuth2Metadata( } // Serve from cache while fresh; only successful fetches are cached. - s.mu.Lock() + s.mu.RLock() if s.cachedMetadata != nil && time.Now().Before(s.cacheExpiry) { cached := proto.Clone(s.cachedMetadata).(*auth.GetOAuth2MetadataResponse) - s.mu.Unlock() + s.mu.RUnlock() return connect.NewResponse(cached), nil } - s.mu.Unlock() + s.mu.RUnlock() baseURL, err := url.Parse(s.cfg.ExternalAuthServerBaseURL) if err != nil {