diff --git a/go.mod b/go.mod index bbe83556..44513f5a 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,11 @@ require ( github.com/aws/aws-lambda-go v1.54.0 github.com/aws/aws-sdk-go-v2 v1.41.7 github.com/aws/aws-sdk-go-v2/config v1.32.17 + github.com/aws/aws-sdk-go-v2/credentials v1.19.16 github.com/aws/aws-sdk-go-v2/service/s3 v1.100.1 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.7 github.com/aws/aws-sdk-go-v2/service/ssm v1.68.6 + github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 github.com/aws/smithy-go v1.25.1 github.com/go-kit/log v0.2.1 github.com/gogo/protobuf v1.3.2 @@ -40,7 +42,6 @@ require ( github.com/armon/go-metrics v0.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect @@ -52,7 +53,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500 // indirect diff --git a/lambda-promtail.yaml b/lambda-promtail.yaml index 4c1eddf8..7f0cab0a 100644 --- a/lambda-promtail.yaml +++ b/lambda-promtail.yaml @@ -20,6 +20,23 @@ Parameters: Type: String Default: "" NoEcho: true + TenantID: + Description: The Loki tenant ID, sent as the X-Scope-OrgID header. Required when WifAudience is set. 
+ Type: String + Default: "" + WifAudience: + Description: > + The audience of the AWS STS web identity token used for workload identity federation (WIF) with + Grafana Cloud. When set, a short-lived JWT is fetched from AWS STS and sent as + 'Authorization: Bearer <TenantID>:<token>'. Mutually exclusive with Username/Password. + Type: String + Default: "" + WifRoleArn: + Description: > + Optional IAM role ARN to assume before requesting the STS web identity token. Leave empty to use + the Lambda execution role directly. Only used when WifAudience is set. + Type: String + Default: "" KeepStream: Description: Determines whether to keep the CloudWatch Log Stream value as a Loki label when writing logs from lambda-promtail. Type: String @@ -50,6 +67,9 @@ Metadata: - WriteAddress - Username - Password + - TenantID + - WifAudience + - WifRoleArn - Label: default: "Lambda function configuration" Parameters: @@ -63,6 +83,10 @@ Metadata: - KeepStream - ExtraLabels +Conditions: + UseWif: !Not [!Equals [!Ref WifAudience, ""]] + HasWifRole: !Not [!Equals [!Ref WifRoleArn, ""]] + Resources: LambdaPromtailRole: @@ -89,12 +113,35 @@ Resources: - logs:CreateLogStream - logs:PutLogEvents Resource: arn:aws:logs:*:*:* - RoleName: GrafanaLabsCloudWatchLogsIntegration + RoleName: !Sub "GrafanaLabsCloudWatchLogsIntegration-${AWS::StackName}" + + # Grants the STS permissions needed for workload identity federation. Only created when + # WifAudience is set. When WifRoleArn is provided the Lambda role is allowed to assume it; + # otherwise the Lambda role fetches the web identity token directly. 
+ LambdaPromtailWifPolicy: + Type: AWS::IAM::Policy + Condition: UseWif + Properties: + PolicyName: wif-sts + Roles: + - !Ref LambdaPromtailRole + PolicyDocument: + Version: '2012-10-17' + Statement: + - Effect: Allow + Action: sts:GetWebIdentityToken + Resource: "*" + - !If + - HasWifRole + - Effect: Allow + Action: sts:AssumeRole + Resource: !Ref WifRoleArn + - !Ref AWS::NoValue LambdaPromtailFunction: Type: AWS::Lambda::Function Properties: - FunctionName: GrafanaCloudLambdaPromtail + FunctionName: !Sub "GrafanaCloudLambdaPromtail-${AWS::StackName}" Code: S3Bucket: !Ref S3BucketName S3Key: !Ref S3KeyName @@ -109,6 +156,9 @@ Resources: WRITE_ADDRESS: !Ref WriteAddress USERNAME: !Ref Username PASSWORD: !Ref Password + TENANT_ID: !Ref TenantID + WIF_AUDIENCE: !Ref WifAudience + WIF_ROLE_ARN: !Ref WifRoleArn KEEP_STREAM: !Ref KeepStream EXTRA_LABELS: !Ref ExtraLabels diff --git a/pkg/auth_sts.go b/pkg/auth_sts.go new file mode 100644 index 00000000..a5ecb857 --- /dev/null +++ b/pkg/auth_sts.go @@ -0,0 +1,105 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +const ( + // stsSigningAlg is the signing algorithm requested from STS for the web identity token. + stsSigningAlg = "ES384" + + // stsTokenRefreshMargin is how long before its expiry a cached token is refreshed. + stsTokenRefreshMargin = 1 * time.Minute +) + +// stsWebIdentityTokenClient is the subset of the STS client used to fetch web identity tokens. +// It is an interface so the option can be unit tested without calling AWS. 
+type stsWebIdentityTokenClient interface {
+	GetWebIdentityToken(ctx context.Context, params *sts.GetWebIdentityTokenInput, optFns ...func(*sts.Options)) (*sts.GetWebIdentityTokenOutput, error)
+}
+
+// stsWebIdentityOption is an auth option that authenticates with a web identity JWT issued by
+// AWS STS. The JWT is requested for the audience wifAudience (e.g.
+// https://grafana-dev.com/v1/workload-identities/dev-eu-west-2:7161:alloy-ec2) and Apply sends
+// it prefixed with the tenant ID:
+//
+//	Authorization: Bearer <tenantID>:<token>
+type stsWebIdentityOption struct {
+	client      stsWebIdentityTokenClient
+	tenantID    string
+	wifAudience string
+
+	mu          sync.Mutex // guards cachedToken and expiresAt
+	cachedToken string
+	expiresAt   time.Time
+}
+
+// newSTSWebIdentityOption returns an stsWebIdentityOption whose STS client is built from the
+// default AWS configuration. When roleARN is non-empty that role is assumed before the token is
+// requested, mirroring the behaviour of Alloy's gcomawsauth.
+func newSTSWebIdentityOption(ctx context.Context, tenantID, wifAudience, roleARN string) (*stsWebIdentityOption, error) {
+	cfg, err := config.LoadDefaultConfig(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load AWS config: %w", err)
+	}
+
+	// Swap in cached assume-role credentials so the client built below acts as roleARN.
+	if roleARN != "" {
+		provider := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), roleARN)
+		cfg.Credentials = aws.NewCredentialsCache(provider)
+	}
+
+	client := sts.NewFromConfig(cfg)
+	return &stsWebIdentityOption{client: client, tenantID: tenantID, wifAudience: wifAudience}, nil
+}
+
+// Apply sets the Authorization header of req to "Bearer <tenantID>:<token>".
+func (o *stsWebIdentityOption) Apply(ctx context.Context, req *http.Request) error {
+	jwt, err := o.token(ctx)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s:%s", o.tenantID, jwt))
+	return nil
+}
+
+// token returns a cached web identity token, fetching a fresh one from STS when none is cached
+// or the cached one is close to expiry.
+func (o *stsWebIdentityOption) token(ctx context.Context) (string, error) {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+
+	if o.cachedToken != "" && time.Now().Before(o.expiresAt.Add(-stsTokenRefreshMargin)) {
+		return o.cachedToken, nil
+	}
+
+	alg := stsSigningAlg
+	output, err := o.client.GetWebIdentityToken(ctx, &sts.GetWebIdentityTokenInput{
+		Audience:         []string{o.wifAudience},
+		SigningAlgorithm: &alg,
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to get JWT from AWS STS: %w", err)
+	}
+	// An empty token is as unusable as a missing one: it would be sent as "Bearer <tenantID>:" and never cached.
+	if output == nil || output.WebIdentityToken == nil || *output.WebIdentityToken == "" {
+		return "", fmt.Errorf("AWS STS returned an empty web identity token")
+	}
+
+	o.cachedToken = *output.WebIdentityToken
+	o.expiresAt = time.Time{} // no expiry reported: stay zero so the next call fetches again
+	if output.Expiration != nil {
+		o.expiresAt = *output.Expiration
+	}
+
+	return o.cachedToken, nil
+}
diff --git a/pkg/main.go b/pkg/main.go index a76e590c..778eb239 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -33,17 +33,18 @@ const ( ) var ( - writeAddress *url.URL - username, password, extraLabelsRaw, dropLabelsRaw, tenantID, bearerToken string - keepStream bool - batchSize int - pipelineTimeout time.Duration - s3Clients map[string]*s3.Client - extraLabels model.LabelSet - dropLabels []model.LabelName - skipTLSVerify bool - printLogLine bool - relabelConfigs []*relabel.Config + writeAddress *url.URL + extraLabelsRaw, dropLabelsRaw, tenantID string + authOptions []AuthOption + keepStream bool + batchSize int + pipelineTimeout time.Duration + s3Clients map[string]*s3.Client + extraLabels model.LabelSet + dropLabels []model.LabelName + skipTLSVerify bool + printLogLine bool + relabelConfigs []*relabel.Config ) func setupArguments(ctx context.Context, secretFetcher secretFetcher) { @@ -72,11 +73,11 @@ func setupArguments(ctx context.Context, secretFetcher secretFetcher) { panic(err) } - username, err = loadSensitiveEnv(ctx, secretFetcher, "USERNAME") + username, err := loadSensitiveEnv(ctx, secretFetcher, "USERNAME") if err != nil { panic(err) } - password, err = 
loadSensitiveEnv(ctx, secretFetcher, "PASSWORD") + password, err := loadSensitiveEnv(ctx, secretFetcher, "PASSWORD") if err != nil { panic(err) } @@ -85,7 +86,7 @@ func setupArguments(ctx context.Context, secretFetcher secretFetcher) { panic("both username and password must be set if either one is set") } - bearerToken, err = loadSensitiveEnv(ctx, secretFetcher, "BEARER_TOKEN") + bearerToken, err := loadSensitiveEnv(ctx, secretFetcher, "BEARER_TOKEN") if err != nil { panic(err) } @@ -94,14 +95,45 @@ func setupArguments(ctx context.Context, secretFetcher secretFetcher) { panic("both username and bearerToken are not allowed") } + tenantID = os.Getenv("TENANT_ID") + + // Workload identity federation: when WIF_AUDIENCE is set we fetch a short-lived + // web identity JWT from AWS STS and send it as `Authorization: Bearer :`. + wifAudience := os.Getenv("WIF_AUDIENCE") + wifRoleARN := os.Getenv("WIF_ROLE_ARN") + if wifAudience != "" { + if username != "" || bearerToken != "" { + panic("WIF_AUDIENCE cannot be combined with username/password or bearer token auth") + } + if tenantID == "" { + panic("TENANT_ID must be set when WIF_AUDIENCE is used") + } + } + + authOptions = nil + if tenantID != "" { + authOptions = append(authOptions, tenantIDOption{tenantID: tenantID}) + } + if username != "" && password != "" { + authOptions = append(authOptions, basicAuthOption{username: username, password: password}) + } + if bearerToken != "" { + authOptions = append(authOptions, bearerTokenOption{token: bearerToken}) + } + if wifAudience != "" { + stsOption, err := newSTSWebIdentityOption(ctx, tenantID, wifAudience, wifRoleARN) + if err != nil { + panic(err) + } + authOptions = append(authOptions, stsOption) + } + skipTLS := os.Getenv("SKIP_TLS_VERIFY") // Anything other than case-insensitive 'true' is treated as 'false'. 
if strings.EqualFold(skipTLS, "true") { skipTLSVerify = true } - tenantID = os.Getenv("TENANT_ID") - keep := os.Getenv("KEEP_STREAM") // Anything other than case-insensitive 'true' is treated as 'false'. if strings.EqualFold(keep, "true") { @@ -279,6 +311,7 @@ func handler(ctx context.Context, ev map[string]interface{}) error { timeout: timeout, skipTLSVerify: skipTLSVerify, }, + auth: authOptions, }, log) lokiStageConfigs, err := ParsePipelineConfigs(os.Getenv("LOKI_STAGE_CONFIGS"), *log, metrics) diff --git a/pkg/promtail.go b/pkg/promtail.go index 122ff87a..c65094be 100644 --- a/pkg/promtail.go +++ b/pkg/promtail.go @@ -207,16 +207,10 @@ func (c *promtailClient) send(ctx context.Context, buf []byte) (int, error) { req.Header.Set("Content-Type", contentType) req.Header.Set("User-Agent", userAgent) - if tenantID != "" { - req.Header.Set("X-Scope-OrgID", tenantID) - } - - if username != "" && password != "" { - req.SetBasicAuth(username, password) - } - - if bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+bearerToken) + for _, opt := range c.config.auth { + if err := opt.Apply(ctx, req); err != nil { + return -1, err + } } resp, err := c.http.Do(req.WithContext(ctx)) diff --git a/pkg/promtail_client.go b/pkg/promtail_client.go index 3de20335..e70e0ebf 100644 --- a/pkg/promtail_client.go +++ b/pkg/promtail_client.go @@ -14,6 +14,43 @@ type Client interface { sendToPromtail(ctx context.Context, b *batch) error } +// AuthOption mutates an outgoing request to add authentication information. +// Options are applied in the order they are passed to the client. +type AuthOption interface { + Apply(ctx context.Context, req *http.Request) error +} + +// basicAuthOption sets HTTP basic auth credentials on the request. 
+type basicAuthOption struct { + username string + password string +} + +func (o basicAuthOption) Apply(_ context.Context, req *http.Request) error { + req.SetBasicAuth(o.username, o.password) + return nil +} + +// bearerTokenOption sets a static bearer token Authorization header on the request. +type bearerTokenOption struct { + token string +} + +func (o bearerTokenOption) Apply(_ context.Context, req *http.Request) error { + req.Header.Set("Authorization", "Bearer "+o.token) + return nil +} + +// tenantIDOption sets the X-Scope-OrgID header used by Loki for multi-tenancy. +type tenantIDOption struct { + tenantID string +} + +func (o tenantIDOption) Apply(_ context.Context, req *http.Request) error { + req.Header.Set("X-Scope-OrgID", o.tenantID) + return nil +} + // Implements Client type promtailClient struct { config *promtailClientConfig @@ -24,6 +61,7 @@ type promtailClient struct { type promtailClientConfig struct { backoff *backoff.Config http *httpClientConfig + auth []AuthOption } type httpClientConfig struct { diff --git a/pkg/promtail_client_test.go b/pkg/promtail_client_test.go new file mode 100644 index 00000000..11751b76 --- /dev/null +++ b/pkg/promtail_client_test.go @@ -0,0 +1,104 @@ +package main + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestRequest(t *testing.T) *http.Request { + t.Helper() + req, err := http.NewRequest("POST", "https://example.com/loki/api/v1/push", nil) + require.NoError(t, err) + return req +} + +func Test_basicAuthOption(t *testing.T) { + req := newTestRequest(t) + require.NoError(t, basicAuthOption{username: "user", password: "pass"}.Apply(context.Background(), req)) + + user, pass, ok := req.BasicAuth() + assert.True(t, ok) + assert.Equal(t, "user", user) + assert.Equal(t, "pass", pass) +} + +func Test_bearerTokenOption(t *testing.T) { + req := newTestRequest(t) + 
require.NoError(t, bearerTokenOption{token: "tok"}.Apply(context.Background(), req)) + assert.Equal(t, "Bearer tok", req.Header.Get("Authorization")) +} + +func Test_tenantIDOption(t *testing.T) { + req := newTestRequest(t) + require.NoError(t, tenantIDOption{tenantID: "tenant-1"}.Apply(context.Background(), req)) + assert.Equal(t, "tenant-1", req.Header.Get("X-Scope-OrgID")) +} + +// fakeSTSClient is a stub implementation of stsWebIdentityTokenClient. +type fakeSTSClient struct { + calls int + token string + expiration *time.Time + err error + gotInput *sts.GetWebIdentityTokenInput +} + +func (f *fakeSTSClient) GetWebIdentityToken(_ context.Context, params *sts.GetWebIdentityTokenInput, _ ...func(*sts.Options)) (*sts.GetWebIdentityTokenOutput, error) { + f.calls++ + f.gotInput = params + if f.err != nil { + return nil, f.err + } + tok := f.token + return &sts.GetWebIdentityTokenOutput{WebIdentityToken: &tok, Expiration: f.expiration}, nil +} + +func Test_stsWebIdentityOption_Apply(t *testing.T) { + exp := time.Now().Add(time.Hour) + fake := &fakeSTSClient{token: "jwt-123", expiration: &exp} + opt := &stsWebIdentityOption{ + client: fake, + tenantID: "tenant-1", + wifAudience: "https://grafana-dev.com/v1/workload-identities/dev-eu-west-2:7161:alloy-ec2", + } + + req := newTestRequest(t) + require.NoError(t, opt.Apply(context.Background(), req)) + + // Header has the form `Bearer <tenantID>:<token>`. + assert.Equal(t, "Bearer tenant-1:jwt-123", req.Header.Get("Authorization")) + // The audience sent to STS is wifAudience, not the tenant ID. 
+ require.NotNil(t, fake.gotInput) + assert.Equal(t, []string{opt.wifAudience}, fake.gotInput.Audience) + require.NotNil(t, fake.gotInput.SigningAlgorithm) + assert.Equal(t, stsSigningAlg, *fake.gotInput.SigningAlgorithm) +} + +func Test_stsWebIdentityOption_CachesToken(t *testing.T) { + exp := time.Now().Add(time.Hour) + fake := &fakeSTSClient{token: "jwt-123", expiration: &exp} + opt := &stsWebIdentityOption{client: fake, tenantID: "t", wifAudience: "aud"} + + require.NoError(t, opt.Apply(context.Background(), newTestRequest(t))) + require.NoError(t, opt.Apply(context.Background(), newTestRequest(t))) + + assert.Equal(t, 1, fake.calls, "token should be fetched once and then cached") +} + +func Test_stsWebIdentityOption_RefreshesNearExpiry(t *testing.T) { + // Token already within the refresh margin, so each Apply must re-fetch. + exp := time.Now().Add(stsTokenRefreshMargin / 2) + fake := &fakeSTSClient{token: "jwt-123", expiration: &exp} + opt := &stsWebIdentityOption{client: fake, tenantID: "t", wifAudience: "aud"} + + require.NoError(t, opt.Apply(context.Background(), newTestRequest(t))) + require.NoError(t, opt.Apply(context.Background(), newTestRequest(t))) + + assert.Equal(t, 2, fake.calls, "near-expiry token should be refreshed on each use") +}