Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog/8272-aws-iam-authentication-strategy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
type: Added
description: Added aws_iam authentication strategy for SaaS connectors, supporting AWS Signature V4 signing with static credentials and STS AssumeRole with automatic credential caching
pr: 8272
labels: []
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def get_saas_schema(self) -> Type[SaaSSchema]:
json_schema_extra=extra,
),
)
if connector_param.default_value
if connector_param.default_value or connector_param.optional
else (
param_type,
FieldInfo(
Expand Down
1 change: 1 addition & 0 deletions src/fides/api/schemas/saas/saas_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class ConnectorParam(BaseModel):
multiselect: Optional[bool] = False
description: Optional[str] = None
sensitive: Optional[bool] = False
optional: bool = False
type: Optional[str] = None
allowed_values: Optional[List[str]] = None
# type="endpoint" marks this param as a URL endpoint/domain param.
Expand Down
26 changes: 26 additions & 0 deletions src/fides/api/schemas/saas/strategy_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,29 @@ class GoogleCloudServiceAccountConfiguration(StrategyConfiguration):
"'https://www.googleapis.com/auth/devstorage.read_write' for Cloud Storage, "
),
)


class AWSIAMAuthenticationConfiguration(StrategyConfiguration):
"""
Configuration for AWS IAM (Signature V4) authentication.

Signs HTTP requests using AWS credentials so they can be sent to
IAM-protected endpoints such as API Gateway with IAM authorization.
Supports both static credentials and STS AssumeRole.
"""

region: Optional[str] = Field(
default=None,
description=(
"AWS region for signing requests (e.g. 'us-east-1'). "
"If not specified, the region is resolved from the connector secrets "
"('aws_region') or inferred from the API Gateway endpoint hostname."
),
)
service: str = Field(
default="execute-api",
description=(
"The AWS service name used for Signature V4 signing. "
"Defaults to 'execute-api' for API Gateway."
),
)
1 change: 1 addition & 0 deletions src/fides/api/service/authentication/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fides.api.service.authentication import (
authentication_strategy_api_key,
authentication_strategy_aws_iam,
authentication_strategy_basic,
authentication_strategy_bearer,
authentication_strategy_google_cloud_service_account,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, NoReturn, Optional
from urllib.parse import urlparse

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from botocore.exceptions import ClientError, NoCredentialsError
from loguru import logger
from requests import PreparedRequest
from sqlalchemy.orm import Session

from fides.api.common_exceptions import FidesopsException
from fides.api.models.connectionconfig import ConnectionConfig
from fides.api.schemas.saas.strategy_configuration import (
AWSIAMAuthenticationConfiguration,
StrategyConfiguration,
)
from fides.api.service.authentication.authentication_strategy import (
AuthenticationStrategy,
)
from fides.api.util.logger import Pii

TOKEN_REFRESH_BUFFER_SECONDS = 300


class AWSIAMAuthenticationStrategy(AuthenticationStrategy):
"""
Authenticates HTTP requests using AWS IAM (Signature V4).

Supports two modes:
- AssumeRole: Customer provides an IAM Role ARN. Fides assumes the role
via STS to get temporary credentials, then signs requests with SigV4.
- Static keys: Customer provides AWS access key ID and secret access key
directly.

Designed for authenticating against AWS API Gateway endpoints protected
by IAM authorization.
"""

name = "aws_iam"
configuration_model = AWSIAMAuthenticationConfiguration

def __init__(self, configuration: AWSIAMAuthenticationConfiguration):
self.aws_region = configuration.region
self.service = configuration.service

def add_authentication(
self, request: PreparedRequest, connection_config: ConnectionConfig
) -> PreparedRequest:
credentials = self._get_credentials(connection_config)
region = self._resolve_region(request.url, connection_config)

aws_request = AWSRequest(
method=request.method,
url=request.url,
headers=dict(request.headers) if request.headers else {},
data=request.body or "",
)

SigV4Auth(credentials, self.service, region).add_auth(aws_request)

request.headers.update(dict(aws_request.headers))
return request

def _get_credentials(self, connection_config: ConnectionConfig) -> Credentials:
secrets = connection_config.secrets
if not secrets:
raise FidesopsException(
"Secrets are not configured for this connector. "
"AWS IAM authentication requires either an assume_role_arn "
"or aws_access_key_id and aws_secret_access_key."
)

assume_role_arn = secrets.get("aws_assume_role_arn")
if assume_role_arn:
return self._get_assumed_role_credentials(secrets, connection_config)

access_key_id = secrets.get("aws_access_key_id")
secret_access_key = secrets.get("aws_secret_access_key")
if not access_key_id or not secret_access_key:
raise FidesopsException(
"AWS IAM authentication requires either 'aws_assume_role_arn' "
"or both 'aws_access_key_id' and 'aws_secret_access_key'."
)
session_token = secrets.get("aws_session_token")
return Credentials(access_key_id, secret_access_key, session_token)

def _get_assumed_role_credentials(
self,
secrets: Dict[str, Any],
connection_config: ConnectionConfig,
) -> Credentials:
cached_key = secrets.get("aws_iam_access_key_id")
cached_secret = secrets.get("aws_iam_secret_access_key")
cached_token = secrets.get("aws_iam_session_token")
cached_expiry = secrets.get("aws_iam_credentials_expire_at")

if cached_key and cached_secret and cached_token and cached_expiry:
if not self._is_close_to_expiration(cached_expiry):
return Credentials(cached_key, cached_secret, cached_token)
elif any([cached_key, cached_secret, cached_token, cached_expiry]):
logger.debug(
"Partial cached credentials found for {}; refreshing.",
connection_config.key,
)

return self._refresh_assumed_role_credentials(secrets, connection_config)

def _refresh_assumed_role_credentials(
self,
secrets: Dict[str, Any],
connection_config: ConnectionConfig,
) -> Credentials:
# Imported lazily to avoid a hard dependency on boto3 at module load time.
import boto3

assume_role_arn = secrets["aws_assume_role_arn"]

logger.info(
"Assuming AWS IAM role for {}",
connection_config.key,
)

try:
access_key_id = secrets.get("aws_access_key_id")
secret_access_key = secrets.get("aws_secret_access_key")

if access_key_id and secret_access_key:
session = boto3.Session(
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=secrets.get("aws_session_token"),
)
else:
session = boto3.Session()

sts_client = session.client("sts")
response = sts_client.assume_role(
RoleArn=assume_role_arn,
RoleSessionName="FidesSaaSConnectorSession",
)

temp_creds = response["Credentials"]
access_key = temp_creds["AccessKeyId"]
secret_key = temp_creds["SecretAccessKey"]
session_token = temp_creds["SessionToken"]
expiration = temp_creds["Expiration"]

expires_at = int(expiration.timestamp())
self._store_credentials(
connection_config, access_key, secret_key, session_token, expires_at
)

logger.info(
"Successfully assumed AWS IAM role for {}",
connection_config.key,
)

return Credentials(access_key, secret_key, session_token)

except (ClientError, NoCredentialsError) as exc:
self._handle_credential_error(exc, connection_config)
raise # unreachable; _handle_credential_error is NoReturn

def _resolve_region(
self, url: Optional[str], connection_config: ConnectionConfig
) -> str:
if self.aws_region:
return self.aws_region

secrets = connection_config.secrets or {}
region_from_secrets = secrets.get("aws_region")
if region_from_secrets:
return region_from_secrets

if url:
parsed = urlparse(url)
hostname = parsed.hostname or ""
parts = hostname.split(".")
if len(parts) >= 4 and parts[-2] == "amazonaws" and parts[-1] == "com":
return parts[-3]

logger.warning(
"Could not infer AWS region from URL or secrets for connector {}; "
"defaulting to us-east-1. Set aws_region in the connector secrets or "
"authentication configuration to avoid this.",
connection_config.key,
)
return "us-east-1"

def _is_close_to_expiration(self, expires_at: int) -> bool:
buffer_time = datetime.now(timezone.utc) + timedelta(
seconds=TOKEN_REFRESH_BUFFER_SECONDS
)
return expires_at < buffer_time.timestamp()

def _store_credentials(
self,
connection_config: ConnectionConfig,
access_key_id: str,
secret_access_key: str,
session_token: str,
expires_at: int,
) -> None:
db: Optional[Session] = Session.object_session(connection_config)
if db is None:
logger.warning(
"Unable to cache AWS IAM credentials for {} - no database session available",
connection_config.key,
)
return

updated_secrets = {
**(connection_config.secrets or {}),
"aws_iam_access_key_id": access_key_id,
"aws_iam_secret_access_key": secret_access_key,
"aws_iam_session_token": session_token,
"aws_iam_credentials_expire_at": expires_at,
}
connection_config.update(db, data={"secrets": updated_secrets})
logger.debug(
"Cached AWS IAM credentials for {} (expires at {})",
connection_config.key,
datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
)

def _handle_credential_error(
self, exc: Exception, connection_config: ConnectionConfig
) -> NoReturn:
error_msg = str(exc)
logger.error(
"Error assuming AWS IAM role for {}: {}",
connection_config.key,
Pii(error_msg),
)

if isinstance(exc, NoCredentialsError):
user_message = (
"No base AWS credentials found to authenticate the STS AssumeRole call. "
"Providing an IAM Role ARN alone is not sufficient. AWS requires credentials "
"to call sts:AssumeRole. Either provide aws_access_key_id and "
"aws_secret_access_key alongside the ARN, or ensure the Fides environment "
"has ambient AWS credentials configured (e.g. instance profile, environment "
"variables, or ~/.aws/credentials)."
)
elif isinstance(exc, ClientError):
error_code = exc.response.get("Error", {}).get("Code", "")
if error_code == "AccessDenied":
user_message = (
"Access denied when assuming the IAM role. Verify that "
"the role's trust policy allows Fides to assume it and "
"that the provided credentials have sts:AssumeRole permission."
)
elif error_code in ("MalformedPolicyDocument", "PackedPolicyTooLarge"):
user_message = f"IAM role configuration error: {error_code}. Check the role ARN and trust policy."
else:
user_message = f"AWS STS error ({error_code}): {error_msg}"
else:
user_message = f"Failed to assume AWS IAM role: {error_msg}"

raise FidesopsException(user_message) from exc

@staticmethod
def get_configuration_model() -> StrategyConfiguration:
return AWSIAMAuthenticationConfiguration # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from fides.api.service.authentication.authentication_strategy_oauth2_client_credentials import (
OAuth2ClientCredentialsAuthenticationStrategy,
)
from fides.api.service.authentication.authentication_strategy_aws_iam import (
AWSIAMAuthenticationStrategy,
)
from fides.api.service.authentication.authentication_strategy_query_param import (
QueryParamAuthenticationStrategy,
)
Expand All @@ -40,6 +43,7 @@ class SupportedAuthenticationStrategies(Enum):
oauth2_authorization_code = OAuth2AuthorizationCodeAuthenticationStrategy
oauth2_client_credentials = OAuth2ClientCredentialsAuthenticationStrategy
google_cloud_service_account = GoogleCloudServiceAccountAuthenticationStrategy
aws_iam = AWSIAMAuthenticationStrategy

@classmethod
def __contains__(cls, item: str) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from fides.api.schemas.saas.saas_config import (
ConnectorParam,
ExternalDatasetReference,
SaaSConfig,
)
from fides.config.security_settings import DomainValidationMode
Expand Down Expand Up @@ -44,23 +43,24 @@ def test_missing_fields(self, saas_config: SaaSConfig):
with pytest.raises(ValidationError) as exc:
schema.model_validate(config)

required_fields = [
connector_param.name
for connector_param in (
saas_config.connector_params + saas_config.external_references
)
if isinstance(
connector_param, ExternalDatasetReference
) # external refs are required
or not connector_param.default_value
]

errors = exc._excinfo[1].errors()
assert (
errors[0]["msg"]
== "Value error, custom_schema must be supplied all of: [username, api_key, api_version, page_size, account_types, customer_id]."
)

def test_optional_param_not_required(self, saas_config: SaaSConfig):
saas_config.connector_params = [
ConnectorParam(name="required_key"),
ConnectorParam(name="optional_key", optional=True),
]
saas_config.external_references = []
schema = SaaSSchemaFactory(saas_config).get_saas_schema()
# optional_key absent — should not raise
schema.model_validate({"required_key": "value"})
# optional_key present — should also work
schema.model_validate({"required_key": "value", "optional_key": "val"})

def test_extra_fields(
self, saas_config: SaaSConfig, saas_example_secrets: Dict[str, Any]
):
Expand Down
Loading
Loading