diff --git a/CHANGELOG.md b/CHANGELOG.md index 0307c165c7e..da6fe849ccc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,19 @@ Changes can also be flagged with a GitHub label for tracking purposes. The URL o - https://github.com/ethyca/fides/labels/high-risk: to indicate that a change is a "high-risk" change that could potentially lead to unanticipated regressions or degradations - https://github.com/ethyca/fides/labels/db-migration: to indicate that a given change includes a DB migration -## [Unreleased](https://github.com/ethyca/fides/compare/2.85.0..main) +## [Unreleased](https://github.com/ethyca/fides/compare/2.85.1..main) + +## [2.85.1](https://github.com/ethyca/fides/compare/2.85.0..2.85.1) + +### Added +- Added SecretProvider abstraction and AWS Secrets Manager provider [#8051](https://github.com/ethyca/fides/pull/8051) +- Add DBCredentialProvider for dynamic database credential resolution via AWS Secrets Manager [#8175](https://github.com/ethyca/fides/pull/8175) +- Added configurable pool_recycle setting for database connections [#8209](https://github.com/ethyca/fides/pull/8209) + +### Changed +- Route all database connections through DBCredentialProvider for dynamic credential resolution [#8176](https://github.com/ethyca/fides/pull/8176) https://github.com/ethyca/fides/labels/high-risk +- Refactored database engines to use SQLAlchemy creator pattern for per-connection credential resolution [#8148](https://github.com/ethyca/fides/pull/8148) +- Changed the label on API client comments to make it more obvious that they are from the API client and not from the user. [#8220](https://github.com/ethyca/fides/pull/8220) ## [2.85.0](https://github.com/ethyca/fides/compare/2.84.3..2.85.0) diff --git a/clients/admin-ui/src/features/privacy-requests/events-and-logs/ActivityTimelineEntry.tsx b/clients/admin-ui/src/features/privacy-requests/events-and-logs/ActivityTimelineEntry.tsx index 28a31836b0a..2f7acfc5f39 100644 --- a/clients/admin-ui/src/features/privacy-requests/events-and-logs/ActivityTimelineEntry.tsx +++ b/clients/admin-ui/src/features/privacy-requests/events-and-logs/ActivityTimelineEntry.tsx @@ -111,7 +111,8 @@ const ActivityTimelineEntry = ({ item }: ActivityTimelineEntryProps) => { [styles["itemButton--polling"]]: isPolling, [styles["itemButton--clickable"]]: isClickable, [styles["itemButton--comment"]]: - type === ActivityTimelineItemTypeEnum.INTERNAL_COMMENT, + type === ActivityTimelineItemTypeEnum.INTERNAL_COMMENT || + type === ActivityTimelineItemTypeEnum.INTERNAL_AUTOMATION_COMMENT, [styles["itemButton--manual-task"]]: type === ActivityTimelineItemTypeEnum.MANUAL_TASK, }), diff --git a/clients/admin-ui/src/features/privacy-requests/events-and-logs/hooks/usePrivacyRequestComments.ts b/clients/admin-ui/src/features/privacy-requests/events-and-logs/hooks/usePrivacyRequestComments.ts index 94c85705732..ddd1690a17b 100644 --- a/clients/admin-ui/src/features/privacy-requests/events-and-logs/hooks/usePrivacyRequestComments.ts +++ b/clients/admin-ui/src/features/privacy-requests/events-and-logs/hooks/usePrivacyRequestComments.ts @@ -51,7 +51,10 @@ export const usePrivacyRequestComments = (privacyRequestId: string) => { return { author, date: new Date(comment.created_at), - type: ActivityTimelineItemTypeEnum.INTERNAL_COMMENT, + type: + comment.user_id === null && comment.username !== "root_user" + ? ActivityTimelineItemTypeEnum.INTERNAL_AUTOMATION_COMMENT + : ActivityTimelineItemTypeEnum.INTERNAL_COMMENT, showViewLog: false, description: comment.comment_text, isError: false, diff --git a/clients/admin-ui/src/features/privacy-requests/types.ts b/clients/admin-ui/src/features/privacy-requests/types.ts index 4b40df9115a..b1bb5483546 100644 --- a/clients/admin-ui/src/features/privacy-requests/types.ts +++ b/clients/admin-ui/src/features/privacy-requests/types.ts @@ -239,6 +239,7 @@ export interface ConfigMessagingSecretsRequest { export enum ActivityTimelineItemTypeEnum { REQUEST_UPDATE = "Request update", INTERNAL_COMMENT = "Internal comment", + INTERNAL_AUTOMATION_COMMENT = "Internal automation comment", MANUAL_TASK = "Manual task", } @@ -248,6 +249,8 @@ export const TimelineItemColorMap: Record< > = { [ActivityTimelineItemTypeEnum.REQUEST_UPDATE]: CUSTOM_TAG_COLOR.DEFAULT, [ActivityTimelineItemTypeEnum.INTERNAL_COMMENT]: CUSTOM_TAG_COLOR.MARBLE, + [ActivityTimelineItemTypeEnum.INTERNAL_AUTOMATION_COMMENT]: + CUSTOM_TAG_COLOR.MARBLE, [ActivityTimelineItemTypeEnum.MANUAL_TASK]: CUSTOM_TAG_COLOR.NECTAR, }; diff --git a/design-docs/dynamic-database-credentials.md b/design-docs/dynamic-database-credentials.md index 737162bd9ab..f8b4623c9e0 100644 --- a/design-docs/dynamic-database-credentials.md +++ b/design-docs/dynamic-database-credentials.md @@ -107,7 +107,7 @@ This depends on a SQLAlchemy internal (`greenlet_spawn`), which is acceptable be - The planned SQLAlchemy 2.0 upgrade will replace this with the public `async_creator` API. - The code should include a clear TODO and comments explaining this constraint. -The module-level engines in `ctl_session.py` need to be refactored into lazy factories (similar to how `session_management.py` already works) so the `creator` can be injected at construction time. +The module-level engines in `ctl_session.py` remain as module-level singletons. The `creator` closure captures a provider reference, not credentials themselves — credentials are resolved inside the closure body on every call. This means the engine can be constructed at any time (including module import) and credential rotation still works correctly. ### 4. Automatic Retry on Auth Failure diff --git a/noxfiles/ci_nox.py b/noxfiles/ci_nox.py index ad983e094fe..f58b8d46256 100644 --- a/noxfiles/ci_nox.py +++ b/noxfiles/ci_nox.py @@ -549,6 +549,7 @@ def pytest_redis_cluster_docker(session: nox.Session) -> None: TEST_DIRECTORY_COVERAGE = { "tests/api/": ["api"], "tests/common/": ["misc-unit"], + "tests/config/": ["misc-unit"], "tests/ctl/": ["ctl-unit", "ctl-not-external", "ctl-integration", "ctl-external"], "tests/lib/": ["lib"], "tests/ops/": [ diff --git a/noxfiles/setup_tests_nox.py b/noxfiles/setup_tests_nox.py index e118a264269..956eda7af39 100644 --- a/noxfiles/setup_tests_nox.py +++ b/noxfiles/setup_tests_nox.py @@ -498,6 +498,7 @@ def pytest_misc_unit(session: Session, pytest_config: PytestConfig) -> None: "pytest", *pytest_config.args, "tests/common/", + "tests/config/", "tests/service/", "tests/system_integration_link/", "tests/task/", diff --git a/pyproject.toml b/pyproject.toml index d9f6c296d85..e2c9cddc776 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,7 +146,7 @@ dev = [ "Faker==14.1.0", "freezegun==1.5.5", "GitPython==3.1.41", - "moto[s3]==5.1.22", + "moto[s3,secretsmanager]==5.1.22", "mypy==1.10.0", "nox>=2025.11", "pre-commit==2.20.0", diff --git a/src/fides/api/alembic/migrations/env.py b/src/fides/api/alembic/migrations/env.py index 78ed5db15b2..f52ed30e79a 100644 --- a/src/fides/api/alembic/migrations/env.py +++ b/src/fides/api/alembic/migrations/env.py @@ -2,10 +2,11 @@ from alembic import context from loguru import logger as log -from sqlalchemy import engine_from_config, pool, text +from sqlalchemy import create_engine, pool, text from fides.api.db.database import include_object from fides.api.util.logger import setup as setup_fidesapi_logger +from fides.common.engine_creators import SYNC_DIALECT_URL, db_cred_provider, make_sync_creator from fides.config import CONFIG # this is the Alembic Config object, which provides @@ -44,7 +45,7 @@ def run_migrations_offline(): script output. """ - url = fides_config.database.sync_database_uri + url = db_cred_provider.get_database_url() context.configure( url=url, target_metadata=target_metadata, @@ -71,11 +72,9 @@ def run_migrations_online(): and associate a connection with the context. """ - configuration = alembic_config.get_section(alembic_config.config_ini_section) - configuration["sqlalchemy.url"] = fides_config.database.sync_database_uri - connectable = engine_from_config( - configuration, - prefix="sqlalchemy.", + connectable = create_engine( + SYNC_DIALECT_URL, + creator=make_sync_creator(), poolclass=pool.NullPool, ) diff --git a/src/fides/api/app_setup.py b/src/fides/api/app_setup.py index 15a5ca07ec4..c49efea380b 100644 --- a/src/fides/api/app_setup.py +++ b/src/fides/api/app_setup.py @@ -59,6 +59,7 @@ ExceptionHandlers, response_validation_error_handler, ) +from fides.common.engine_creators import db_cred_provider from fides.common.session_management import get_api_session, get_autoclose_db_session from fides.config import CONFIG from fides.config.config_proxy import ConfigProxy @@ -205,12 +206,13 @@ async def run_database_startup(app: FastAPI) -> None: application webserver. """ - if not CONFIG.database.sync_database_uri: + database_url = db_cred_provider.get_database_url() + if not database_url: raise FidesError("No database uri provided") if CONFIG.database.automigrate: try: - configure_db(CONFIG.database.sync_database_uri) + configure_db(database_url) if not CONFIG.test_mode: with get_autoclose_db_session() as session: seed_db(session) diff --git a/src/fides/api/db/ctl_session.py b/src/fides/api/db/ctl_session.py index 0737c06f3be..48763ad57d3 100644 --- a/src/fides/api/db/ctl_session.py +++ b/src/fides/api/db/ctl_session.py @@ -1,4 +1,3 @@ -import ssl from asyncio import Lock, gather from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from typing import Any, AsyncGenerator, Callable, Dict @@ -11,24 +10,22 @@ from fides.api.db.session import ExtendedSession from fides.api.db.util import custom_json_deserializer, custom_json_serializer +from fides.common.engine_creators import ( + ASYNC_DIALECT_URL, + SYNC_DIALECT_URL, + make_async_creator, + make_sync_creator, +) from fides.config import CONFIG # asyncio lock and flag for warming up the async pool ASYNC_READONLY_POOL_LOCK = Lock() ASYNC_READONLY_POOL_WARMED = False -# Associated with a workaround in fides.core.config.database_settings -# ref: https://github.com/sqlalchemy/sqlalchemy/discussions/5975 -connect_args: Dict[str, Any] = {} -if CONFIG.database.params.get("sslrootcert"): - ssl_ctx = ssl.create_default_context(cafile=CONFIG.database.params["sslrootcert"]) - ssl_ctx.verify_mode = ssl.CERT_REQUIRED - connect_args["ssl"] = ssl_ctx - -# Parameters are hidden for security +# Primary async engine — credentials resolved per-connection via creator async_engine = create_async_engine( - CONFIG.database.async_database_uri, - connect_args=connect_args, + ASYNC_DIALECT_URL, + creator=make_async_creator(), echo=False, hide_parameters=not CONFIG.dev_mode, logging_name="AsyncEngine", @@ -37,6 +34,9 @@ pool_size=CONFIG.database.api_async_engine_pool_size, max_overflow=CONFIG.database.api_async_engine_max_overflow, pool_pre_ping=CONFIG.database.api_async_engine_pool_pre_ping, + pool_recycle=CONFIG.database.pool_recycle + if CONFIG.database.pool_recycle is not None + else -1, # -1 is SQLAlchemy's default (no recycling) ) async_session_factory = sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False @@ -49,21 +49,12 @@ if CONFIG.database.async_readonly_database_uri: logger.info("Creating read-only async engine and session factory") - # Build connect_args for readonly (similar to primary) - readonly_connect_args: Dict[str, Any] = {} - readonly_params = CONFIG.database.readonly_params or {} - - if readonly_params.get("sslrootcert"): - ssl_ctx = ssl.create_default_context(cafile=readonly_params["sslrootcert"]) - ssl_ctx.verify_mode = ssl.CERT_REQUIRED - readonly_connect_args["ssl"] = ssl_ctx - logger.info( f"Read-only async settings: max-overflow: {CONFIG.database.api_async_engine_max_overflow}, pool-size: {CONFIG.database.async_readonly_database_pool_size}, pre-warm = {CONFIG.database.async_readonly_database_prewarm}, autocommit = {CONFIG.database.async_readonly_database_autocommit}, skip rollback = {CONFIG.database.async_readonly_database_pool_skip_rollback}" ) readonly_async_engine = create_async_engine( - CONFIG.database.async_readonly_database_uri, - connect_args=readonly_connect_args, + ASYNC_DIALECT_URL, + creator=make_async_creator(readonly=True), echo=False, hide_parameters=not CONFIG.dev_mode, logging_name="ReadOnlyAsyncEngine", @@ -72,6 +63,9 @@ pool_size=CONFIG.database.async_readonly_database_pool_size, max_overflow=CONFIG.database.async_readonly_database_max_overflow, pool_pre_ping=CONFIG.database.async_readonly_database_pre_ping, + pool_recycle=CONFIG.database.pool_recycle + if CONFIG.database.pool_recycle is not None + else -1, # -1 is SQLAlchemy's default (no recycling) # Don't rollback before returning a connection to the pool - this improves performance dramatically; # can be turned off via config but the default is to not reset on return pool_reset_on_return=( @@ -92,7 +86,8 @@ # and they do not respect engine settings like pool_size, max_overflow, etc. # these should be removed, and we should standardize on what's provided in `session.py` sync_engine = create_engine( - CONFIG.database.sync_database_uri, + SYNC_DIALECT_URL, + creator=make_sync_creator(), echo=False, hide_parameters=not CONFIG.dev_mode, logging_name="SyncEngine", diff --git a/src/fides/api/db/session.py b/src/fides/api/db/session.py index f16eb8684af..dab4ddf7633 100644 --- a/src/fides/api/db/session.py +++ b/src/fides/api/db/session.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Callable, Dict, Optional from loguru import logger from sqlalchemy import create_engine @@ -11,6 +11,7 @@ from fides.api.common_exceptions import MissingConfig from fides.api.db.util import custom_json_deserializer, custom_json_serializer +from fides.common.engine_creators import SYNC_DIALECT_URL, make_sync_creator from fides.config import FidesConfig @@ -18,46 +19,71 @@ def get_db_engine( *, config: FidesConfig | None = None, database_uri: str | URL | None = None, + creator: Callable[[], Any] | None = None, pool_size: int = 50, max_overflow: int = 50, keepalives_idle: int | None = None, keepalives_interval: int | None = None, keepalives_count: int | None = None, pool_pre_ping: bool = True, + pool_recycle: Optional[int] = None, disable_pooling: bool = False, ) -> Engine: """Return a database engine. + When *creator* is provided, it is called by the pool to open each new + connection — credentials and connect_args are handled inside the creator. + A dialect-only URL is used for engine construction. + + When *database_uri* or *config* is provided, the engine uses a fixed + connection URI (existing behavior). + If the TESTING environment var is set the database engine returned will be connected to the test DB. """ - if not config and not database_uri: - raise ValueError("Either a config or database_uri is required") - - if not database_uri and config: - # Don't override any database_uri explicitly passed in - if config.test_mode: - database_uri = config.database.sqlalchemy_test_database_uri - else: - database_uri = config.database.sqlalchemy_database_uri - engine_args: Dict[str, Any] = { "json_serializer": custom_json_serializer, "json_deserializer": custom_json_deserializer, } - # keepalives settings - connect_args = {} - if keepalives_idle: - connect_args["keepalives_idle"] = keepalives_idle - if keepalives_interval: - connect_args["keepalives_interval"] = keepalives_interval - if keepalives_count: - connect_args["keepalives_count"] = keepalives_count - - if connect_args: - connect_args["keepalives"] = 1 - engine_args["connect_args"] = connect_args + if creator: + # Creator handles credentials and connect_args internally, + # so creator needs to set keepalives settings. + if database_uri or config: + raise ValueError( + "database_uri/config cannot be used with creator — " + "the creator handles connection construction" + ) + if keepalives_idle or keepalives_interval or keepalives_count: + raise ValueError( + "keepalives_idle/interval/count cannot be used with creator — " + "pass them as connect_args to the creator instead" + ) + engine_args["creator"] = creator + database_uri = SYNC_DIALECT_URL + else: + # URI-based path. + if not config and not database_uri: + raise ValueError("Either a config, database_uri, or creator is required") + + if not database_uri and config: + if config.test_mode: + database_uri = config.database.sqlalchemy_test_database_uri + else: + database_uri = config.database.sqlalchemy_database_uri + + # keepalives settings (only for URI path; creator handles its own) + connect_args = {} + if keepalives_idle: + connect_args["keepalives_idle"] = keepalives_idle + if keepalives_interval: + connect_args["keepalives_interval"] = keepalives_interval + if keepalives_count: + connect_args["keepalives_count"] = keepalives_count + + if connect_args: + connect_args["keepalives"] = 1 + engine_args["connect_args"] = connect_args if disable_pooling: engine_args["poolclass"] = NullPool @@ -65,24 +91,26 @@ def get_db_engine( engine_args["pool_pre_ping"] = pool_pre_ping engine_args["pool_size"] = pool_size engine_args["max_overflow"] = max_overflow + if pool_recycle is not None: + engine_args["pool_recycle"] = pool_recycle return create_engine(database_uri, **engine_args) def get_db_session( - config: FidesConfig, + config: FidesConfig, # TODO: remove — no longer used, all callers pass CONFIG autocommit: bool = False, autoflush: bool = False, engine: Engine | None = None, ) -> sessionmaker: """Return a database SessionLocal.""" - if not config.database.sqlalchemy_database_uri: - raise MissingConfig("No database uri available in the config") + if engine is None: + engine = get_db_engine(creator=make_sync_creator()) return sessionmaker( autocommit=autocommit, autoflush=autoflush, - bind=engine or get_db_engine(config=config), + bind=engine, class_=ExtendedSession, ) diff --git a/src/fides/api/tasks/__init__.py b/src/fides/api/tasks/__init__.py index cf113747a02..9951b615874 100644 --- a/src/fides/api/tasks/__init__.py +++ b/src/fides/api/tasks/__init__.py @@ -19,6 +19,7 @@ from fides.api.request_context import get_request_id, set_request_id from fides.api.tasks import celery_healthcheck from fides.api.util.logger import setup as setup_logging +from fides.common.engine_creators import make_sync_creator from fides.config import CONFIG, FidesConfig MESSAGING_QUEUE_NAME = "fidesops.messaging" @@ -77,13 +78,18 @@ def get_new_session(self) -> ContextManager[Session]: # once per celery process. if self._task_engine is None: self._task_engine = get_db_engine( - config=CONFIG, + creator=make_sync_creator( + connect_args={ + "keepalives": 1, + "keepalives_idle": CONFIG.database.task_engine_keepalives_idle, + "keepalives_interval": CONFIG.database.task_engine_keepalives_interval, + "keepalives_count": CONFIG.database.task_engine_keepalives_count, + }, + ), pool_size=CONFIG.database.task_engine_pool_size, max_overflow=CONFIG.database.task_engine_max_overflow, - keepalives_idle=CONFIG.database.task_engine_keepalives_idle, - keepalives_interval=CONFIG.database.task_engine_keepalives_interval, - keepalives_count=CONFIG.database.task_engine_keepalives_count, pool_pre_ping=CONFIG.database.task_engine_pool_pre_ping, + pool_recycle=CONFIG.database.pool_recycle, ) # same for the sessionmaker diff --git a/src/fides/api/v1/endpoints/admin.py b/src/fides/api/v1/endpoints/admin.py index b15cb55b7fe..830cc65a264 100644 --- a/src/fides/api/v1/endpoints/admin.py +++ b/src/fides/api/v1/endpoints/admin.py @@ -27,6 +27,7 @@ ) from fides.api.v1.endpoints import API_PREFIX from fides.common import scope_registry +from fides.common.engine_creators import db_cred_provider from fides.common.scope_registry import BACKFILL_EXEC, HEAP_DUMP_EXEC from fides.config import CONFIG @@ -58,10 +59,12 @@ def db_action(action: DBActions, revision: Optional[str] = "head") -> Dict: explicit guidance from Ethyca support. """ + database_url = db_cred_provider.get_database_url() + if action == DBActions.downgrade: try: migrate_db( - database_url=CONFIG.database.sync_database_uri, + database_url=database_url, revision=revision, # type: ignore[arg-type] downgrade=True, ) @@ -87,12 +90,12 @@ def db_action(action: DBActions, revision: Optional[str] = "head") -> Dict: detail="Resetting the application database outside of dev_mode is not supported.", ) - reset_db(CONFIG.database.sync_database_uri) + reset_db(database_url) action_text = "reset" try: logger.info("Database being configured...") - configure_db(CONFIG.database.sync_database_uri, revision=revision) + configure_db(database_url, revision=revision) except Exception as e: logger.exception("Database configuration failed: {e}") raise HTTPException( diff --git a/src/fides/api/v1/endpoints/health.py b/src/fides/api/v1/endpoints/health.py index 832348c63c6..8f7cf1d0451 100644 --- a/src/fides/api/v1/endpoints/health.py +++ b/src/fides/api/v1/endpoints/health.py @@ -25,6 +25,7 @@ from fides.api.util.api_router import APIRouter from fides.api.util.cache import get_cache, get_queue_counts from fides.api.util.logger import Pii +from fides.common.engine_creators import db_cred_provider from fides.common.session_management import get_readonly_api_session from fides.config import CONFIG @@ -154,7 +155,7 @@ async def database_health(db: Session = Depends(get_db)) -> Dict: async_readonly_pool_prewarmed: Optional[bool] = None migration_health, current_revision = get_db_health( - CONFIG.database.sync_database_uri, db=db + db_cred_provider.get_database_url(), db=db ) # Primary sync pool (already checked out by dependency-injected session). diff --git a/src/fides/common/db_credential_provider.py b/src/fides/common/db_credential_provider.py new file mode 100644 index 00000000000..6ec0cffaaac --- /dev/null +++ b/src/fides/common/db_credential_provider.py @@ -0,0 +1,248 @@ +"""Database credential resolution with retry-on-auth-failure and exception sanitization. + +DBCredentialProvider sits between SQLAlchemy engine creators and the SecretProvider +abstraction. It resolves credentials (from static config or a secret store), wraps +connection attempts with a single retry on authentication failure during credential +rotation, and sanitizes all connection exceptions to prevent credential leakage. + +Driver-agnostic: does not import psycopg2 or asyncpg. +""" + +from __future__ import annotations + +import time +from typing import Any, Callable, Dict, Optional, TypeVar +from urllib.parse import quote, quote_plus, urlencode + +from loguru import logger as log +from psycopg2 import ( # type: ignore[import-untyped] + OperationalError as Psycopg2OperationalError, +) + +from fides.config import CONFIG +from fides.config.secrets import StaticSecretProvider, get_secret_provider +from fides.config.secrets.static_provider import ( + DATABASE_CREDENTIALS_KEY, + DATABASE_READONLY_CREDENTIALS_KEY, +) + +__all__ = ["DBCredentialProvider", "SanitizedConnectionError"] + +T = TypeVar("T") + +_AUTH_SQLSTATES = frozenset({"28P01", "28000"}) +_AUTH_RETRY_DELAY = 1.5 # seconds — propagation window for AWS rotation + + +class SanitizedConnectionError(Exception): + """Connection error with credentials stripped. Safe to log and report.""" + + __slots__ = ("sqlstate",) + + def __init__(self, message: str, sqlstate: Optional[str] = None) -> None: + super().__init__(message) + self.sqlstate = sqlstate + + +class DBCredentialProvider: + """Resolves database credentials and wraps connection attempts + with retry-on-auth-failure and exception sanitization. + """ + + def __init__(self) -> None: + self._provider = get_secret_provider() + + # ------------------------------------------------------------------ + # Credential resolution + # ------------------------------------------------------------------ + + @property + def is_dynamic(self) -> bool: + """True when using a non-static provider (credentials can rotate).""" + return not isinstance(self._provider, StaticSecretProvider) + + def _get_secret_id(self, readonly: bool) -> str: + """Resolve which secret ID to use. + + For dynamic providers: uses configured secret IDs with readonly -> primary fallback. + For static provider: uses the well-known default keys. + """ + if self.is_dynamic: + credential_secret_name = CONFIG.database.credential_secret_name + if credential_secret_name is None: + raise ValueError( + "secrets.provider is not 'static' but " + "database.credential_secret_name is not set." + ) + if readonly: + return ( + CONFIG.database.readonly_credential_secret_name + or credential_secret_name + ) + return credential_secret_name + else: + if readonly and CONFIG.database.readonly_server: + return DATABASE_READONLY_CREDENTIALS_KEY + return DATABASE_CREDENTIALS_KEY + + def get_credentials(self, readonly: bool = False) -> Dict[str, Any]: + """Return ``{host, port, user, password, dbname}`` for a database connection.""" + db = CONFIG.database + + # Host, port, dbname always from config + creds = { + "host": db.server, + "port": int(db.port), + "dbname": db.test_db if CONFIG.test_mode else db.db, + } + if readonly and db.readonly_server: + creds["host"] = db.readonly_server + creds["port"] = int(db.readonly_port or db.port) + creds["dbname"] = db.readonly_db or db.db + + # Get credentials (user/password) from the provider + secret = self._provider.get_secret(self._get_secret_id(readonly)) + creds["user"] = secret["username"] + creds["password"] = secret["password"] + + return creds + + def get_database_url( + self, driver: str = "postgresql+psycopg2", readonly: bool = False + ) -> str: + """Build a SQLAlchemy database URL with credentials from the provider. + + Includes connection params (SSL, keepalives, etc.) from CONFIG.database.params + as query parameters, matching the old sync_database_uri behavior. + """ + creds = self.get_credentials(readonly=readonly) + user = quote_plus(creds["user"]) + password = quote_plus(creds["password"]) + url = f"{driver}://{user}:{password}@{creds['host']}:{creds['port']}/{creds['dbname']}" + + params = CONFIG.database.params + if params: + url += "?" + urlencode(params, quote_via=quote, safe="/") + return url + + # ------------------------------------------------------------------ + # Connection with retry + # ------------------------------------------------------------------ + + def connect_with_retry( + self, + connect_fn: Callable[..., T], + connect_kwargs: Dict[str, Any], + readonly: bool = False, + ) -> T: + """Attempt a connection; on auth failure with dynamic credentials, retry once. + + Args: + connect_fn: Driver connect callable (e.g. ``psycopg2.connect``). + connect_kwargs: Non-credential connection kwargs (keepalives, SSL). + Credentials are merged in from ``get_credentials()``. + readonly: Whether to use readonly credentials. + + Returns: + The raw connection object. + + Raises: + SanitizedConnectionError: On any connection failure, with + connection parameters stripped from the exception message. + """ + creds = self.get_credentials(readonly=readonly) + kwargs = {**connect_kwargs, **creds} + + try: + return connect_fn(**kwargs) + except Exception as exc: + if self.is_dynamic and self._is_auth_error(exc): + return self._retry_with_fresh_credentials( + connect_fn, connect_kwargs, readonly, exc + ) + raise self._sanitize_exception(exc) from None + + def _retry_with_fresh_credentials( + self, + connect_fn: Callable[..., T], + connect_kwargs: Dict[str, Any], + readonly: bool, + original_exc: Exception, + ) -> T: + """Invalidate the cached secret, wait for propagation, retry once.""" + secret_id = self._get_secret_id(readonly) + + log.warning( + "Connection failure ({}: SQLSTATE {}), invalidating secret {!r} and retrying", + type(original_exc).__name__, + self._extract_sqlstate(original_exc), + secret_id, + ) + self._provider.invalidate(secret_id) + + # Safe in both sync and async paths: async engine creators run inside + # SQLAlchemy's greenlet bridge, so time.sleep blocks the greenlet, not + # the event loop (same mechanism as await_only(asyncpg.connect(...))). + time.sleep(_AUTH_RETRY_DELAY) + + fresh_creds = self.get_credentials(readonly=readonly) + kwargs = {**connect_kwargs, **fresh_creds} + + try: + return connect_fn(**kwargs) + except Exception as exc: + log.error( + "Retry also failed (SQLSTATE {}), credentials may be wrong", + self._extract_sqlstate(exc), + ) + raise self._sanitize_exception(exc) from None + + # ------------------------------------------------------------------ + # Error detection and sanitization + # ------------------------------------------------------------------ + + @staticmethod + def _is_auth_error(exc: Exception) -> bool: + """Detect connection failures that may indicate credential rotation. + + Checks SQLSTATE codes first (asyncpg always provides these). + Falls back to message matching for psycopg2, which does not + populate pgcode on connection-time errors. The fallback string + comes from PostgreSQL's auth handshake, which is always English + (sent before any locale is configured). + + Also matches any psycopg2 OperationalError as a broad fallback, + because RDS Proxy and other managed PostgreSQL services may return + non-standard error messages on auth failure that don't match the + specific patterns above. The cost of retrying + on a non-auth OperationalError is one extra Secrets Manager call + and a 1.5s delay, which is acceptable given the alternative is + 15 minutes of 500s. + """ + if DBCredentialProvider._extract_sqlstate(exc) in _AUTH_SQLSTATES: + return True + if "password authentication failed" in str(exc).lower(): + return True + return isinstance(exc, Psycopg2OperationalError) + + @staticmethod + def _extract_sqlstate(exc: Exception) -> Optional[str]: + """Extract SQLSTATE from a driver exception for logging.""" + return getattr(exc, "pgcode", None) or getattr(exc, "sqlstate", None) + + @staticmethod + def _sanitize_exception(exc: Exception) -> SanitizedConnectionError: + """Replace a driver exception with a sanitized version. + + Constructs a new exception containing only the exception type name + and SQLSTATE code. Raised with ``from None`` by callers to break + the exception chain and prevent credential leakage through error + reporters that serialize ``__cause__``. + """ + sqlstate = DBCredentialProvider._extract_sqlstate(exc) + exc_type = type(exc).__name__ + if sqlstate: + msg = f"Database connection failed: {exc_type} (SQLSTATE {sqlstate})" + else: + msg = f"Database connection failed: {exc_type}" + return SanitizedConnectionError(msg, sqlstate=sqlstate) diff --git a/src/fides/common/engine_creators.py b/src/fides/common/engine_creators.py new file mode 100644 index 00000000000..5d8bbba6828 --- /dev/null +++ b/src/fides/common/engine_creators.py @@ -0,0 +1,167 @@ +""" +SQLAlchemy engine ``creator`` callables for dynamic credential resolution. + +The ``creator`` pattern passes a callable to ``create_engine`` / +``create_async_engine`` instead of a connection URI. SQLAlchemy calls the +creator every time the pool needs a new connection, so credentials are +resolved at **connection time** rather than engine construction time. + +Credential resolution, auth-failure retry, and exception sanitization are +handled by ``DBCredentialProvider``. + +Because creators run on every new pool connection, they must stay +lightweight — avoid expensive I/O, network calls, or heavy computation. +Credential lookups should return cached values in the common case. +""" + +from __future__ import annotations + +import ssl +from copy import deepcopy +from typing import Any, Callable, Dict, Optional + +import asyncpg # type: ignore[import-untyped] +import psycopg2 # type: ignore[import-untyped] +from sqlalchemy.dialects.postgresql.asyncpg import ( + AsyncAdapt_asyncpg_connection, + AsyncAdapt_asyncpg_dbapi, +) +from sqlalchemy.util.concurrency import await_only # type: ignore[import-untyped] + +from fides.common.db_credential_provider import DBCredentialProvider +from fides.config import CONFIG + +# Dialect-only URLs for the creator pattern — no credentials, just driver selection. +SYNC_DIALECT_URL = "postgresql+psycopg2://" +ASYNC_DIALECT_URL = "postgresql+asyncpg://" + +# Shared dbapi instance for async creators — reused across connections. +_asyncpg_dbapi = AsyncAdapt_asyncpg_dbapi(asyncpg) + +# Module-level provider — all engines share one instance (and one secret cache). +db_cred_provider = DBCredentialProvider() + + +# --------------------------------------------------------------------------- +# Sync creators (psycopg2) +# --------------------------------------------------------------------------- + + +def make_sync_creator( + connect_args: Optional[Dict[str, Any]] = None, + readonly: bool = False, +) -> Callable[[], Any]: + """Return a creator callable for psycopg2 engines. + + The factory captures per-engine config (keepalives, SSL) in the closure. + Credentials are resolved via ``DBCredentialProvider`` on every call, + with automatic retry on auth failure for dynamic credentials. + + When using ``creator``, SQLAlchemy ignores ``connect_args`` passed to + ``create_engine``, so all connection parameters must be baked in here. + """ + extra_kwargs = dict(connect_args) if connect_args else {} + + def creator() -> Any: + return db_cred_provider.connect_with_retry( + connect_fn=psycopg2.connect, + connect_kwargs=extra_kwargs, + readonly=readonly, + ) + + return creator + + +# --------------------------------------------------------------------------- +# Async creators (asyncpg) +# --------------------------------------------------------------------------- + + +def make_async_creator( + readonly: bool = False, +) -> Callable[[], Any]: + """Return a creator callable for asyncpg engines (SA 1.4.27). + + The factory builds the SSL context and asyncpg-compatible params from + CONFIG, capturing them in the closure. Credentials are resolved via + ``DBCredentialProvider`` on every call. + + The creator replaces ``dialect.connect()`` in SQLAlchemy's pool. For + async engines the pool runs inside a greenlet bridge, so ``await_only`` + is valid. Must return ``AsyncAdapt_asyncpg_connection`` (SA's sync + wrapper) since the pool operates in sync mode through greenlets. + + TODO: Replace with ``async_creator`` API after SQLAlchemy 2.0 upgrade. + """ + db_params = ( + (CONFIG.database.readonly_params or CONFIG.database.params) + if readonly + else CONFIG.database.params + ) + ssl_context = _build_ssl_context(db_params) + async_params = _convert_asyncpg_params(db_params) + + # When we have a full SSLContext (from sslrootcert), it takes priority + # over the raw ssl string (from sslmode). Otherwise kw.update(async_params) + # would overwrite the SSLContext with e.g. "require", losing cert verification. + if ssl_context: + async_params.pop("ssl", None) + + extra_kwargs: Dict[str, Any] = {} + if ssl_context: + extra_kwargs["ssl"] = ssl_context + if async_params: + extra_kwargs.update(async_params) + + def _connect_asyncpg(**kwargs: Any) -> AsyncAdapt_asyncpg_connection: + kw = { + "host": kwargs.pop("host"), + "port": kwargs.pop("port"), + "user": kwargs.pop("user"), + "password": kwargs.pop("password"), + "database": kwargs.pop("dbname"), + } + kw.update(kwargs) + raw_conn = await_only(asyncpg.connect(**kw)) + return AsyncAdapt_asyncpg_connection(_asyncpg_dbapi, raw_conn) + + def creator() -> Any: + return db_cred_provider.connect_with_retry( + connect_fn=_connect_asyncpg, + connect_kwargs=extra_kwargs, + readonly=readonly, + ) + + return creator + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _build_ssl_context(params: Dict[str, Any]) -> Optional[ssl.SSLContext]: + """Build an ``ssl.SSLContext`` from DB params if ``sslrootcert`` is set.""" + sslrootcert = params.get("sslrootcert") + if not sslrootcert: + return None + ctx = ssl.create_default_context(cafile=sslrootcert) + ctx.verify_mode = ssl.CERT_REQUIRED + return ctx + + +def _convert_asyncpg_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Convert DB params dict for asyncpg compatibility. + + asyncpg uses ``ssl`` instead of ``sslmode`` and does not accept + ``sslrootcert`` as a connection parameter (it's handled via + ``ssl.SSLContext`` passed separately). + + See: https://github.com/MagicStack/asyncpg/issues/737 + ref: https://github.com/sqlalchemy/sqlalchemy/discussions/5975 + """ + converted = deepcopy(params) + if "sslmode" in converted: + converted["ssl"] = converted.pop("sslmode") + converted.pop("sslrootcert", None) + return converted diff --git a/src/fides/common/session_management.py b/src/fides/common/session_management.py index ac0844195bc..21c17af14e4 100644 --- a/src/fides/common/session_management.py +++ b/src/fides/common/session_management.py @@ -16,6 +16,7 @@ from fides.api.db.ctl_session import async_session from fides.api.db.session import get_db_engine, get_db_session +from fides.common.engine_creators import make_sync_creator from fides.config import CONFIG T = TypeVar("T") @@ -28,13 +29,18 @@ def get_api_session() -> Session: global _engine # pylint: disable=W0603 if not _engine: _engine = get_db_engine( - config=CONFIG, + creator=make_sync_creator( + connect_args={ + "keepalives": 1, + "keepalives_idle": CONFIG.database.api_engine_keepalives_idle, + "keepalives_interval": CONFIG.database.api_engine_keepalives_interval, + "keepalives_count": CONFIG.database.api_engine_keepalives_count, + }, + ), pool_size=CONFIG.database.api_engine_pool_size, max_overflow=CONFIG.database.api_engine_max_overflow, - keepalives_idle=CONFIG.database.api_engine_keepalives_idle, - keepalives_interval=CONFIG.database.api_engine_keepalives_interval, - keepalives_count=CONFIG.database.api_engine_keepalives_count, pool_pre_ping=CONFIG.database.api_engine_pool_pre_ping, + pool_recycle=CONFIG.database.pool_recycle, ) SessionLocal = get_db_session(CONFIG, engine=_engine) return SessionLocal() @@ -140,13 +146,19 @@ def get_readonly_api_session() -> Session: global _readonly_engine # pylint: disable=W0603 if not _readonly_engine: _readonly_engine = get_db_engine( - database_uri=CONFIG.database.sqlalchemy_readonly_database_uri, + creator=make_sync_creator( + connect_args={ + "keepalives": 1, + "keepalives_idle": CONFIG.database.api_engine_keepalives_idle, + "keepalives_interval": CONFIG.database.api_engine_keepalives_interval, + "keepalives_count": CONFIG.database.api_engine_keepalives_count, + }, + readonly=True, + ), pool_size=CONFIG.database.api_engine_pool_size, max_overflow=CONFIG.database.api_engine_max_overflow, - keepalives_idle=CONFIG.database.api_engine_keepalives_idle, - keepalives_interval=CONFIG.database.api_engine_keepalives_interval, - keepalives_count=CONFIG.database.api_engine_keepalives_count, pool_pre_ping=CONFIG.database.api_engine_pool_pre_ping, + pool_recycle=CONFIG.database.pool_recycle, ) SessionLocal = get_db_session(CONFIG, engine=_readonly_engine) return SessionLocal() diff --git a/src/fides/config/__init__.py b/src/fides/config/__init__.py index 1d5a65a51d6..97c94416fb7 100644 --- a/src/fides/config/__init__.py +++ b/src/fides/config/__init__.py @@ -9,7 +9,7 @@ import toml from loguru import logger as log -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -32,6 +32,7 @@ from .notification_settings import NotificationSettings from .privacy_center_settings import PrivacyCenterSettings from .redis_settings import RedisSettings +from .secrets_settings import SecretsSettings from .security_settings import SecuritySettings from .user_settings import UserSettings from .utils import ( @@ -86,11 +87,37 @@ class FidesConfig(FidesSettings): notifications: NotificationSettings redis: RedisSettings privacy_center: PrivacyCenterSettings + secrets: SecretsSettings security: SecuritySettings user: UserSettings model_config = SettingsConfigDict(case_sensitive=True) + @model_validator(mode="after") + def _validate_database_credential_secret_names(self) -> "FidesConfig": + """Validate database credential secret IDs against the secrets provider.""" + if self.secrets.provider == "static": + if self.database.credential_secret_name: + raise ValueError( + f"database.credential_secret_name is set ({self.database.credential_secret_name!r}) " + "but secrets.provider is 'static'. Either remove the secret ID " + "or set secrets.provider to 'aws_secrets_manager'." + ) + if self.database.readonly_credential_secret_name: + raise ValueError( + f"database.readonly_credential_secret_name is set ({self.database.readonly_credential_secret_name!r}) " + "but secrets.provider is 'static'. Either remove the secret ID " + "or set secrets.provider to 'aws_secrets_manager'." + ) + else: + if not self.database.credential_secret_name: + raise ValueError( + f"secrets.provider is {self.secrets.provider!r} but " + "database.credential_secret_name is not set. " + "Provide the secret name/ARN containing database credentials." + ) + return self + @classmethod def settings_customise_sources( cls, @@ -175,6 +202,7 @@ def build_config(config_dict: Dict[str, Any]) -> FidesConfig: "notifications": NotificationSettings, "privacy_center": PrivacyCenterSettings, "redis": RedisSettings, + "secrets": SecretsSettings, "security": SecuritySettings, "user": UserSettings, } diff --git a/src/fides/config/create.py b/src/fides/config/create.py index 7a01948661d..b2598bdc9bd 100644 --- a/src/fides/config/create.py +++ b/src/fides/config/create.py @@ -71,7 +71,9 @@ def build_field_documentation(field_name: str, field_info: Dict) -> Optional[str # Union field types are under "anyOf" any_of: List[Dict[str, str]] = field_info.get("anyOf") or [] for type_annotation in any_of: - if type_annotation["type"] != "null": + if "$ref" in type_annotation: + continue + if type_annotation.get("type") != "null": # Getting first non-null field_type = type_annotation["type"] break diff --git a/src/fides/config/database_settings.py b/src/fides/config/database_settings.py index 53144552453..7390dc90f6e 100644 --- a/src/fides/config/database_settings.py +++ b/src/fides/config/database_settings.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import Dict, Optional, cast -from urllib.parse import quote, quote_plus, urlencode +from urllib.parse import quote, quote_plus, unquote_plus, urlencode from pydantic import ( Field, @@ -55,6 +55,18 @@ class DatabaseSettings(FidesSettings): description="If true, the engine will pre-ping connections to ensure they are still valid before using them.", ) + # Pool Recycle (applies to all engines) + pool_recycle: Optional[int] = Field( + default=None, + gt=0, + description=( + "Number of seconds after which a database connection is automatically " + "recycled (closed and replaced). Useful when a connection proxy or " + "firewall imposes an idle connection timeout. Set this to a value lower " + "than the proxy/DB timeout. When unset (None), connections are never recycled." + ), + ) + # Async Engine Settings # Note: We purposely do not include async engine equivalents of the sync engine's # keepalives_* settings as they are not supported by asyncpg. @@ -118,6 +130,15 @@ class DatabaseSettings(FidesSettings): description="Additional connection parameters for read-only database connections. If not provided and readonly_server is set, uses 'params'.", ) + credential_secret_name: Optional[str] = Field( + default=None, + description="Secrets Manager secret name or ARN containing DB credentials (JSON with 'username' and 'password' keys). Used when secrets.provider is 'aws_secrets_manager'.", + ) + readonly_credential_secret_name: Optional[str] = Field( + default=None, + description="Secrets Manager secret name or ARN for read-only DB credentials. Falls back to credential_secret_name if not set.", + ) + task_engine_pool_size: int = Field( default=50, description="Number of concurrent database connections Fides will use for executing privacy request tasks, either locally or on each worker. Note that the pool begins with no connections, but as they are requested the connections are maintained and reused up to this limit.", @@ -275,6 +296,18 @@ def escape_password(cls, value: Optional[str]) -> Optional[str]: return quote_plus(value) return value + @property + def raw_password(self) -> str: + """Return password unescaped for direct driver use (psycopg2/asyncpg).""" + return unquote_plus(self.password) + + @property + def raw_readonly_password(self) -> Optional[str]: + """Return readonly password unescaped for direct driver use.""" + if self.readonly_password: + return unquote_plus(self.readonly_password) + return None + @field_validator("sync_database_uri", mode="before") @classmethod def assemble_sync_database_uri( diff --git a/src/fides/config/secrets/__init__.py b/src/fides/config/secrets/__init__.py new file mode 100644 index 00000000000..398a9c822e8 --- /dev/null +++ b/src/fides/config/secrets/__init__.py @@ -0,0 +1,46 @@ +"""Secret provider abstraction for dynamically-resolved credentials.""" + +from typing import Optional + +from fides.config import CONFIG +from fides.config.secrets.aws_secrets_manager_provider import ( + AWSSecretsManagerProvider, +) +from fides.config.secrets.base import SecretProvider, SecretProviderError, SecretValue +from fides.config.secrets.factory import create_secret_provider +from fides.config.secrets.static_provider import StaticSecretProvider + +_provider: Optional[SecretProvider] = None + + +def get_secret_provider() -> SecretProvider: + """Return the application-wide SecretProvider singleton. + + Created lazily on first access from ``CONFIG.secrets``. All consumers + share one instance so the credential cache is coherent. + """ + global _provider + if _provider is None: + _provider = create_secret_provider(CONFIG.secrets) + return _provider + + +def reset_secret_provider() -> None: + """Reset the singleton to ``None``, forcing re-creation on next access. + + Intended for testing only. + """ + global _provider + _provider = None + + +__all__ = [ + "AWSSecretsManagerProvider", + "SecretProvider", + "SecretProviderError", + "SecretValue", + "StaticSecretProvider", + "create_secret_provider", + "get_secret_provider", + "reset_secret_provider", +] diff --git a/src/fides/config/secrets/aws_secrets_manager_provider.py b/src/fides/config/secrets/aws_secrets_manager_provider.py new file mode 100644 index 00000000000..88a86cd7664 --- /dev/null +++ b/src/fides/config/secrets/aws_secrets_manager_provider.py @@ -0,0 +1,221 @@ +"""AWS Secrets Manager provider with caching, stale-while-revalidate, +thundering-herd protection, and circuit breaker.""" + +import json +import threading +import time +from dataclasses import dataclass, field +from typing import Dict, Optional + +import boto3 +from botocore.exceptions import ClientError +from loguru import logger as log + +from fides.config.secrets.base import SecretProvider, SecretProviderError, SecretValue + +# Error codes indicating the secret is intentionally inaccessible. +# Serving stale credentials would mask a deliberate revocation, so these +# must fail immediately and clear the cache. +_PERMANENT_ERROR_CODES = { + "ResourceNotFoundException", # secret deleted + "AccessDeniedException", # IAM permissions revoked + "DecryptionFailureException", # KMS key disabled or deleted + "InvalidRequestException", # secret scheduled for deletion +} + + +@dataclass +class _CacheEntry: + """Per-secret cache state.""" + + value: Optional[SecretValue] + fetched_at: float + last_failed_at: float = 0.0 + lock: threading.Lock = field(default_factory=threading.Lock) + + +class AWSSecretsManagerProvider(SecretProvider): + """Fetches secrets from AWS Secrets Manager with local caching. + + Features: + - TTL-based cache with stale-while-revalidate fallback + - Per-secret locking for thundering-herd protection + - Circuit breaker to prevent retry amplification on bad credentials + """ + + def __init__( + self, + region_name: Optional[str] = None, + cache_ttl_seconds: float = 300.0, + cache_stale_ttl_seconds: float = 1800.0, + circuit_breaker_cooldown_seconds: float = 30.0, + endpoint_url: Optional[str] = None, + ) -> None: + self._cache_ttl = cache_ttl_seconds + self._cache_stale_ttl = cache_stale_ttl_seconds + self._circuit_breaker_cooldown = circuit_breaker_cooldown_seconds + + session = boto3.Session(region_name=region_name) + self._client = session.client( + "secretsmanager", + endpoint_url=endpoint_url, + ) + + self._cache: Dict[str, _CacheEntry] = {} + self._cache_lock = threading.Lock() + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def get_secret(self, secret_id: str) -> SecretValue: + entry = self._get_or_create_entry(secret_id) + + with entry.lock: + now = time.monotonic() + + # Cache hit — return without network call + if ( + entry.value is not None + and entry.fetched_at > 0 + and (now - entry.fetched_at) < self._cache_ttl + ): + return entry.value + + # Circuit breaker: if we recently failed, serve cached value + # rather than hitting Secrets Manager again + if ( + entry.last_failed_at > 0 + and (now - entry.last_failed_at) < self._circuit_breaker_cooldown + and entry.value is not None + ): + return entry.value + + return self._fetch_and_update(secret_id, entry) + + def invalidate(self, secret_id: str) -> None: + with self._cache_lock: + entry = self._cache.get(secret_id) + + if entry is None: + return + + with entry.lock: + now = time.monotonic() + if ( + entry.last_failed_at > 0 + and (now - entry.last_failed_at) < self._circuit_breaker_cooldown + ): + log.debug( + "Circuit breaker active for {!r}, skipping invalidation", + secret_id, + ) + return + + entry.fetched_at = 0.0 + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_or_create_entry(self, secret_id: str) -> _CacheEntry: + with self._cache_lock: + entry = self._cache.get(secret_id) + if entry is None: + # Placeholder entry — fetched_at=0 forces a fetch on first access + entry = _CacheEntry( + value=None, + fetched_at=0.0, + ) + self._cache[secret_id] = entry + return entry + + def _fetch_and_update(self, secret_id: str, entry: _CacheEntry) -> SecretValue: + """Fetch from Secrets Manager, update cache, handle failures.""" + try: + new_value = self._fetch(secret_id) + except SecretProviderError: + raise # Don't wrap our own errors (e.g. binary secret, invalid JSON) + except Exception as exc: + return self._handle_fetch_failure(secret_id, entry, exc) + + entry.value = new_value + entry.fetched_at = time.monotonic() + entry.last_failed_at = 0.0 + log.info("Successfully fetched and cached secret {!r}", secret_id) + return new_value + + def _fetch(self, secret_id: str) -> SecretValue: + """Call AWS Secrets Manager and parse the response.""" + log.debug("Fetching secret {!r} from AWS Secrets Manager", secret_id) + response = self._client.get_secret_value( + SecretId=secret_id, + VersionStage="AWSCURRENT", + ) + if "SecretBinary" in response: + raise SecretProviderError( + f"Secret {secret_id!r} is stored as binary; " + f"only SecretString secrets are supported" + ) + secret_string = response["SecretString"] + try: + data = json.loads(secret_string) + except json.JSONDecodeError as exc: + # Don't chain the original exception — its .doc attribute + # contains the raw secret string, which could leak credentials + # if the exception is logged or inspected upstream. + raise SecretProviderError( + f"Secret {secret_id!r} is not valid JSON " + f"(parse error at line {exc.lineno}, column {exc.colno})" + ) from None + return SecretValue(data) + + def _handle_fetch_failure( + self, secret_id: str, entry: _CacheEntry, exc: Exception + ) -> SecretValue: + """Serve stale value if within grace period, otherwise raise. + + Permanent errors (secret deleted, IAM revoked, KMS key disabled) + clear the cache and raise immediately — serving stale credentials + would mask a deliberate revocation. + """ + entry.last_failed_at = time.monotonic() + + # Permanent errors: clear cache so subsequent calls also fail + if isinstance(exc, ClientError): + error_code = exc.response.get("Error", {}).get("Code", "") + if error_code in _PERMANENT_ERROR_CODES: + entry.value = None + raise SecretProviderError( + f"Secret {secret_id!r} is permanently inaccessible ({error_code})" + ) from exc + + cached_value = entry.value + if cached_value is None: + raise SecretProviderError( + f"Failed to fetch secret {secret_id!r} and no cached value available" + ) from exc + + now = time.monotonic() + # fetched_at may be 0 if invalidated, meaning we lost the original + # fetch timestamp and cannot bound the age. In that case we + # unconditionally prefer serving stale data over raising, since the + # caller (e.g. a connection creator) can still function with the + # old credentials until a successful refresh occurs. + age = now - entry.fetched_at if entry.fetched_at > 0 else self._cache_stale_ttl + if age < self._cache_ttl + self._cache_stale_ttl: + exc_summary = ( + f"{type(exc).__name__}({exc.response['Error']['Code']})" + if hasattr(exc, "response") + else type(exc).__name__ + ) + log.warning( + "Failed to refresh secret {!r}, serving stale value ({})", + secret_id, + exc_summary, + ) + return cached_value + + raise SecretProviderError( + f"Failed to fetch secret {secret_id!r} and stale cache has expired" + ) from exc diff --git a/src/fides/config/secrets/base.py b/src/fides/config/secrets/base.py new file mode 100644 index 00000000000..1bdee30da8d --- /dev/null +++ b/src/fides/config/secrets/base.py @@ -0,0 +1,67 @@ +"""Base classes for the secret provider abstraction.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class SecretProviderError(Exception): + """Raised when a secret provider operation fails.""" + + +class SecretValue: + """Wrapper around a secret dict that prevents accidental credential leakage. + + Supports subscript access (``secret["username"]``) but overrides string + coercion so credentials never appear in logs, tracebacks, or debug output. + + Uses ``__slots__`` to prevent ``vars()`` / ``__dict__`` access, which + blocks error reporters (Sentry, Datadog APM) from capturing the secret + data when serializing local variables on exception frames. + """ + + __slots__ = ("_data",) + + def __init__(self, data: Dict[str, Any]) -> None: + self._data = data + + def __reduce__(self) -> None: # type: ignore[override] + raise TypeError("SecretValue cannot be pickled") + + def __getstate__(self) -> None: + raise TypeError("SecretValue cannot be serialized") + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __contains__(self, key: object) -> bool: + return key in self._data + + def __eq__(self, other: object) -> bool: + if isinstance(other, SecretValue): + return self._data == other._data + return NotImplemented + + __hash__ = None # type: ignore[assignment] # unhashable by design + + def __iter__(self) -> None: + raise TypeError( + "SecretValue cannot be iterated — use 'key in sv' to check fields" + ) + + def __repr__(self) -> str: + return "" + + def __str__(self) -> str: + return "" + + +class SecretProvider(ABC): + """Abstract base class for secret providers.""" + + @abstractmethod + def get_secret(self, secret_id: str) -> SecretValue: + """Return the current value of a named secret.""" + + @abstractmethod + def invalidate(self, secret_id: str) -> None: + """Mark a cached secret as stale, forcing the next fetch to refresh.""" diff --git a/src/fides/config/secrets/factory.py b/src/fides/config/secrets/factory.py new file mode 100644 index 00000000000..812c2c16503 --- /dev/null +++ b/src/fides/config/secrets/factory.py @@ -0,0 +1,55 @@ +"""Factory for creating a SecretProvider from config settings.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from loguru import logger as log + +from fides.config.secrets.aws_secrets_manager_provider import ( + AWSSecretsManagerProvider, +) +from fides.config.secrets.base import SecretProvider, SecretProviderError +from fides.config.secrets.static_provider import StaticSecretProvider + +if TYPE_CHECKING: + from fides.config.secrets_settings import SecretsSettings + + +def create_secret_provider( + secrets_settings: SecretsSettings, +) -> SecretProvider: + """Instantiate the configured secret provider. + + Args: + secrets_settings: The ``[secrets]`` config section. + """ + provider_type = secrets_settings.provider + + if provider_type == "static": + log.info("Using static secret provider") + return StaticSecretProvider() + + if provider_type == "aws_secrets_manager": + aws = secrets_settings.aws_secrets_manager + if aws is None: + raise SecretProviderError( + "secrets.provider is 'aws_secrets_manager' but " + "secrets.aws_secrets_manager is not configured." + ) + log.info( + "Using AWS Secrets Manager provider (region={})", + aws.region, + ) + return AWSSecretsManagerProvider( + region_name=aws.region, + cache_ttl_seconds=aws.cache_ttl_seconds, + cache_stale_ttl_seconds=aws.cache_stale_ttl_seconds, + circuit_breaker_cooldown_seconds=aws.circuit_breaker_cooldown_seconds, + endpoint_url=aws.endpoint_url, + ) + + raise SecretProviderError( + f"Unknown secrets provider: {provider_type!r}. " + f"Must be 'static' or 'aws_secrets_manager'." + ) diff --git a/src/fides/config/secrets/static_provider.py b/src/fides/config/secrets/static_provider.py new file mode 100644 index 00000000000..56701d58ac8 --- /dev/null +++ b/src/fides/config/secrets/static_provider.py @@ -0,0 +1,47 @@ +"""Static secret provider — caches credentials from config at construction time.""" + +from typing import Dict + +from loguru import logger as log + +from fides.config import CONFIG +from fides.config.secrets.base import SecretProvider, SecretProviderError, SecretValue + +DATABASE_CREDENTIALS_KEY = "_db_credentials" +DATABASE_READONLY_CREDENTIALS_KEY = "_db_readonly_credentials" + + +class StaticSecretProvider(SecretProvider): + """Returns credentials read from config at construction time. + + Used for static credentials (environment variables / TOML config) + where rotation is not needed. Reads ``CONFIG.database`` once and + caches the values as ``SecretValue`` objects under well-known keys. + """ + + def __init__(self) -> None: + db = CONFIG.database + self._secrets: Dict[str, SecretValue] = { + DATABASE_CREDENTIALS_KEY: SecretValue( + {"username": db.user, "password": db.raw_password} + ), + } + if db.readonly_server: + self._secrets[DATABASE_READONLY_CREDENTIALS_KEY] = SecretValue( + { + "username": db.readonly_user, + "password": db.raw_readonly_password, + } + ) + + def get_secret(self, secret_id: str) -> SecretValue: + try: + return self._secrets[secret_id] + except KeyError: + raise SecretProviderError(f"Unknown secret_id: {secret_id!r}") from None + + def invalidate(self, secret_id: str) -> None: + log.debug( + "invalidate() called on StaticSecretProvider for {!r} (no-op)", + secret_id, + ) diff --git a/src/fides/config/secrets_settings.py b/src/fides/config/secrets_settings.py new file mode 100644 index 00000000000..6b7b7400feb --- /dev/null +++ b/src/fides/config/secrets_settings.py @@ -0,0 +1,72 @@ +"""Configuration settings for the secret provider subsystem.""" + +from typing import Literal, Optional + +from pydantic import Field, model_validator +from pydantic_settings import SettingsConfigDict + +from .fides_settings import FidesSettings + +ENV_PREFIX = "FIDES__SECRETS__" + + +class AWSSecretsManagerSettings(FidesSettings): + """Configuration for the AWS Secrets Manager provider.""" + + region: Optional[str] = Field( + default=None, + description="AWS region for Secrets Manager. If not set, uses the standard boto3 resolution chain (AWS_DEFAULT_REGION env var, ~/.aws/config profile, or EC2/EKS instance metadata).", + ) + cache_ttl_seconds: float = Field( + default=900.0, + description="TTL for cached secret values.", + ) + cache_stale_ttl_seconds: float = Field( + default=1800.0, + description="Grace period for serving last-known-good credentials when Secrets Manager is unreachable.", + ) + circuit_breaker_cooldown_seconds: float = Field( + default=30.0, + description="Cooldown window after a failed fetch before allowing another retry.", + ) + endpoint_url: Optional[str] = Field( + default=None, + description="Optional custom endpoint URL (e.g. LocalStack for local dev/CI).", + ) + + model_config = SettingsConfigDict( + env_prefix=f"{ENV_PREFIX}AWS_SECRETS_MANAGER__", + ) + + +class SecretsSettings(FidesSettings): + """Top-level configuration for the secrets provider.""" + + provider: Literal["static", "aws_secrets_manager"] = Field( + default="static", + description="Which secret provider to use: 'static' or 'aws_secrets_manager'.", + ) + aws_secrets_manager: Optional[AWSSecretsManagerSettings] = Field( + default=None, + description="AWS Secrets Manager configuration. Required when provider is 'aws_secrets_manager'.", + ) + + model_config = SettingsConfigDict(env_prefix=ENV_PREFIX) + + @model_validator(mode="before") + @classmethod + def _build_aws_settings_if_needed(cls, values: dict) -> dict: + """Construct AWS settings from env vars when provider is aws but no config was provided.""" + if values.get("provider") == "aws_secrets_manager" and not values.get( + "aws_secrets_manager" + ): + try: + values["aws_secrets_manager"] = AWSSecretsManagerSettings() + except Exception as exc: + raise ValueError( + "secrets.provider is 'aws_secrets_manager' but " + "secrets.aws_secrets_manager is not configured. " + "Provide the configuration via TOML or environment variables " + "(e.g. FIDES__SECRETS__AWS_SECRETS_MANAGER__REGION)." + ) from exc + return values diff --git a/tests/common/test_db_credential_provider.py b/tests/common/test_db_credential_provider.py new file mode 100644 index 00000000000..1a2e4a425d1 --- /dev/null +++ b/tests/common/test_db_credential_provider.py @@ -0,0 +1,485 @@ +"""Tests for DBCredentialProvider. + +Covers: +- Auth error detection (_is_auth_error / _extract_sqlstate) +- Exception sanitization (credential leakage prevention) +- Credential resolution (static, readonly, dynamic) +- Connection retry on auth failure (dynamic only) +""" + +from unittest.mock import MagicMock, patch + +import asyncpg +import psycopg2 +import pytest +from sqlalchemy.dialects.postgresql.asyncpg import ( + AsyncAdapt_asyncpg_connection, + AsyncAdapt_asyncpg_dbapi, +) +from sqlalchemy.util.concurrency import await_only + +from fides.common.db_credential_provider import ( + _AUTH_RETRY_DELAY, + DBCredentialProvider, + SanitizedConnectionError, +) +from fides.config import CONFIG +from fides.config.database_settings import DatabaseSettings +from fides.config.secrets.base import SecretValue +from fides.config.secrets.static_provider import StaticSecretProvider + +# --- Helpers --- + + +def _make_auth_error(pgcode=None, sqlstate=None): + """Create a mock exception with pgcode/sqlstate attributes.""" + exc = Exception("connection failed") + if pgcode is not None: + exc.pgcode = pgcode + if sqlstate is not None: + exc.sqlstate = sqlstate + return exc + + +def _make_password_leaking_error(password): + """Create an exception whose message contains the password.""" + exc = Exception( + f'FATAL: password authentication failed for user "myuser" ' + f"(password={password}, host=db.example.com)" + ) + exc.pgcode = "28P01" + return exc + + +# --- Fixtures --- + + +@pytest.fixture() +def static_provider(): + """DBCredentialProvider using the real CONFIG with a static provider.""" + yield DBCredentialProvider() + + +@pytest.fixture() +def dynamic_provider(): + """DBCredentialProvider backed by a mock SecretProvider.""" + with ( + patch("fides.common.db_credential_provider.get_secret_provider") as mock_get, + patch("fides.common.db_credential_provider.CONFIG") as mock_config, + patch("fides.common.db_credential_provider.time") as mock_time, + ): + mock_secret_provider = MagicMock() + mock_secret_provider.get_secret.return_value = SecretValue( + {"username": "secret_user", "password": "secret_pass"} + ) + mock_get.return_value = mock_secret_provider + mock_config.database = CONFIG.database + mock_config.database.credential_secret_name = "db-creds" + mock_config.database.readonly_credential_secret_name = None + mock_config.test_mode = CONFIG.test_mode + yield DBCredentialProvider(), mock_config, mock_secret_provider, mock_time + + +# --- Auth error detection --- + + +class TestIsAuthError: + @pytest.mark.parametrize( + "exc", + [ + _make_auth_error(pgcode="28P01"), + _make_auth_error(pgcode="28000"), + _make_auth_error(sqlstate="28P01"), + _make_auth_error(sqlstate="28000"), + ], + ids=["pgcode-28P01", "pgcode-28000", "sqlstate-28P01", "sqlstate-28000"], + ) + def test_detects_auth_sqlstates(self, exc): + assert DBCredentialProvider._is_auth_error(exc) + + @pytest.mark.parametrize( + "message", + [ + 'FATAL: password authentication failed for user "postgres"', + 'connection to server failed: FATAL: Password authentication failed for user "app"', + ], + ids=["standard-pg", "mixed-case"], + ) + def test_string_fallback_detects_auth_message(self, message): + """psycopg2 doesn't set pgcode on connection-time errors, + so _is_auth_error falls back to message matching.""" + exc = Exception(message) + assert DBCredentialProvider._is_auth_error(exc) + + @pytest.mark.parametrize( + "message", + [ + "connection refused", + "could not connect to server: Connection timed out", + 'FATAL: database "nope" does not exist', + ], + ids=["refused", "timeout", "db-not-found"], + ) + def test_string_fallback_rejects_non_auth_messages(self, message): + exc = Exception(message) + assert not DBCredentialProvider._is_auth_error(exc) + + def test_detects_psycopg2_operational_error(self): + """OperationalError from RDS Proxy may not have a standard auth message.""" + exc = psycopg2.OperationalError("proxy connection error") + assert DBCredentialProvider._is_auth_error(exc) + + @pytest.mark.parametrize( + "exc", + [ + _make_auth_error(pgcode="42P01"), + Exception("generic error"), + ], + ids=["non-auth-pgcode", "no-code-attributes"], + ) + def test_rejects_non_auth_errors(self, exc): + assert not DBCredentialProvider._is_auth_error(exc) + + +# --- Exception sanitization --- + + +class TestSanitizeException: + def test_includes_exception_type_and_sqlstate(self): + exc = _make_auth_error(pgcode="28P01") + sanitized = DBCredentialProvider._sanitize_exception(exc) + assert isinstance(sanitized, SanitizedConnectionError) + assert "Exception" in str(sanitized) + assert "28P01" in str(sanitized) + assert sanitized.sqlstate == "28P01" + + def test_without_sqlstate(self): + sanitized = DBCredentialProvider._sanitize_exception(RuntimeError("boom")) + assert "RuntimeError" in str(sanitized) + assert sanitized.sqlstate is None + + def test_password_not_in_sanitized_message(self): + password = "s3cret!p@ss" + exc = _make_password_leaking_error(password) + sanitized = DBCredentialProvider._sanitize_exception(exc) + assert password not in str(sanitized) + assert password not in repr(sanitized) + + def test_exception_chain_broken_with_from_none(self): + exc = _make_auth_error(pgcode="28P01") + sanitized = DBCredentialProvider._sanitize_exception(exc) + try: + raise sanitized from None + except SanitizedConnectionError as caught: + assert caught.__cause__ is None + assert caught.__context__ is None + + +# --- Credential resolution --- + + +class TestGetCredentials: + def test_static_returns_config_credentials(self, static_provider): + creds = static_provider.get_credentials() + assert set(creds.keys()) == {"host", "port", "user", "password", "dbname"} + assert creds["host"] == CONFIG.database.server + assert creds["port"] == int(CONFIG.database.port) + assert creds["user"] == CONFIG.database.user + assert isinstance(creds["port"], int) + + def test_static_is_not_dynamic(self, static_provider): + assert not static_provider.is_dynamic + + def test_readonly_falls_back_to_primary_when_no_readonly_server( + self, static_provider + ): + creds = static_provider.get_credentials(readonly=True) + assert creds["host"] == CONFIG.database.server + assert creds["user"] == CONFIG.database.user + + def test_readonly_uses_readonly_fields(self): + readonly_settings = DatabaseSettings( + readonly_server="replica", + readonly_port="5433", + readonly_user="ro_user", + readonly_password="ro_pass", + readonly_db="ro_db", + ) + with ( + patch("fides.config.secrets.static_provider.CONFIG") as mock_sp_config, + patch("fides.common.db_credential_provider.CONFIG") as mock_dcp_config, + patch( + "fides.common.db_credential_provider.get_secret_provider" + ) as mock_get, + ): + mock_sp_config.database = readonly_settings + mock_dcp_config.database = readonly_settings + mock_dcp_config.test_mode = False + mock_get.return_value = StaticSecretProvider() + + provider = DBCredentialProvider() + creds = provider.get_credentials(readonly=True) + assert creds == { + "host": "replica", + "port": 5433, + "user": "ro_user", + "password": "ro_pass", + "dbname": "ro_db", + } + + def test_dynamic_fetches_user_password_from_secret(self, dynamic_provider): + provider, _, mock_secret_provider, _ = dynamic_provider + assert provider.is_dynamic + creds = provider.get_credentials() + assert creds["user"] == "secret_user" + assert creds["password"] == "secret_pass" + assert creds["host"] == CONFIG.database.server + mock_secret_provider.get_secret.assert_called_once_with("db-creds") + + def test_dynamic_without_credential_secret_name_raises(self, dynamic_provider): + provider, mock_config, _, _ = dynamic_provider + mock_config.database.credential_secret_name = None + with pytest.raises(ValueError, match="credential_secret_name is not set"): + provider.get_credentials() + + def test_dynamic_readonly_falls_back_to_primary_secret_id(self, dynamic_provider): + provider, _, mock_secret_provider, _ = dynamic_provider + provider.get_credentials(readonly=True) + mock_secret_provider.get_secret.assert_called_once_with("db-creds") + + +# --- Database URL construction --- + + +class TestGetDatabaseUrl: + def test_returns_valid_url(self, static_provider): + url = static_provider.get_database_url() + assert url.startswith("postgresql+psycopg2://") + assert CONFIG.database.server in url + + def test_includes_connection_params(self): + """SSL and other params from CONFIG.database.params are appended as query params.""" + with ( + patch("fides.config.secrets.static_provider.CONFIG") as mock_sp_config, + patch("fides.common.db_credential_provider.CONFIG") as mock_dcp_config, + patch( + "fides.common.db_credential_provider.get_secret_provider" + ) as mock_get, + ): + mock_sp_config.database = DatabaseSettings( + params={"sslmode": "require", "sslrootcert": "/path/to/cert"}, + ) + mock_dcp_config.database = mock_sp_config.database + mock_dcp_config.test_mode = False + mock_get.return_value = StaticSecretProvider() + + provider = DBCredentialProvider() + url = provider.get_database_url() + assert "sslmode=require" in url + assert "sslrootcert=/path/to/cert" in url + + @pytest.mark.parametrize( + "user,password", + [ + ("user@domain", "p@ss"), + ("user", "pass%word"), + ("user", "pass/word"), + ("user", "pass#word"), + ("user", "p@ss#w%rd/123"), + ], + ids=["at-sign", "percent", "slash", "hash", "mixed-special"], + ) + def test_special_characters_are_url_encoded(self, user, password): + """Credentials with special characters must be URL-encoded so the + resulting URL is parseable by SQLAlchemy / libpq.""" + with ( + patch("fides.config.secrets.static_provider.CONFIG") as mock_sp_config, + patch("fides.common.db_credential_provider.CONFIG") as mock_dcp_config, + patch( + "fides.common.db_credential_provider.get_secret_provider" + ) as mock_get, + ): + mock_sp_config.database = DatabaseSettings(user=user, password=password) + mock_dcp_config.database = mock_sp_config.database + mock_dcp_config.test_mode = False + mock_get.return_value = StaticSecretProvider() + + provider = DBCredentialProvider() + url = provider.get_database_url() + + # Raw special chars should not appear unescaped in the URL + # (the user:password section is between :// and @) + user_pass_section = url.split("://")[1].split("@")[0] + assert ( + "@" not in user_pass_section.split(":")[0] or "%40" in user_pass_section + ) + assert "#" not in user_pass_section + assert "/" not in user_pass_section + + +# --- Connection retry --- + + +class TestConnectWithRetry: + def test_happy_path_merges_connect_kwargs_and_credentials(self, static_provider): + connect_fn = MagicMock(return_value="connection") + result = static_provider.connect_with_retry(connect_fn, {"keepalives": 1}) + assert result == "connection" + call_kwargs = connect_fn.call_args[1] + assert call_kwargs["host"] == CONFIG.database.server + assert call_kwargs["keepalives"] == 1 + + def test_non_auth_error_raises_sanitized(self, static_provider): + connect_fn = MagicMock(side_effect=RuntimeError("connection refused")) + with pytest.raises(SanitizedConnectionError) as exc_info: + static_provider.connect_with_retry(connect_fn, {}) + assert CONFIG.database.raw_password not in str(exc_info.value) + assert exc_info.value.__cause__ is None + + def test_static_does_not_retry_auth_error(self, static_provider): + connect_fn = MagicMock(side_effect=_make_auth_error(pgcode="28P01")) + with pytest.raises(SanitizedConnectionError): + static_provider.connect_with_retry(connect_fn, {}) + connect_fn.assert_called_once() + + @pytest.mark.parametrize( + "exc", + [ + _make_auth_error(pgcode="28P01"), + _make_auth_error(pgcode="28000"), + _make_auth_error(sqlstate="28P01"), + _make_auth_error(sqlstate="28000"), + ], + ids=["pgcode-28P01", "pgcode-28000", "sqlstate-28P01", "sqlstate-28000"], + ) + def test_dynamic_retries_on_auth_error_and_succeeds(self, dynamic_provider, exc): + provider, _, mock_secret_provider, mock_time = dynamic_provider + connect_fn = MagicMock(side_effect=[exc, "connection"]) + result = provider.connect_with_retry(connect_fn, {}) + + assert result == "connection" + assert connect_fn.call_count == 2 + mock_secret_provider.invalidate.assert_called_once_with("db-creds") + mock_time.sleep.assert_called_once_with(_AUTH_RETRY_DELAY) + + def test_dynamic_retry_fails_raises_sanitized(self, dynamic_provider): + provider, _, _, _ = dynamic_provider + connect_fn = MagicMock( + side_effect=[ + _make_auth_error(pgcode="28P01"), + _make_auth_error(pgcode="28P01"), + ] + ) + with pytest.raises(SanitizedConnectionError) as exc_info: + provider.connect_with_retry(connect_fn, {}) + assert exc_info.value.__cause__ is None + assert connect_fn.call_count == 2 + + def test_psycopg2_wrong_password_against_real_db(self): + """ + End-to-end: attempts to connect to the test database and + checks the real psycopg2 auth failure is caught and sanitized. + """ + wrong_password = "definitely-wrong-password-xyz" + db = CONFIG.database + wrong_db = DatabaseSettings( + server=db.server, + port=db.port, + user=db.user, + password=wrong_password, + db=db.db, + test_db=db.test_db, + ) + with ( + patch("fides.config.secrets.static_provider.CONFIG") as mock_sp_config, + patch("fides.common.db_credential_provider.CONFIG") as mock_dcp_config, + patch( + "fides.common.db_credential_provider.get_secret_provider" + ) as mock_get, + ): + mock_sp_config.database = wrong_db + mock_dcp_config.database = wrong_db + mock_dcp_config.test_mode = CONFIG.test_mode + mock_get.return_value = StaticSecretProvider() + + provider = DBCredentialProvider() + with pytest.raises(SanitizedConnectionError) as exc_info: + provider.connect_with_retry(psycopg2.connect, {}) + assert wrong_password not in str(exc_info.value) + assert exc_info.value.__cause__ is None + + async def test_asyncpg_wrong_password_against_real_db(self): + """ + End-to-end: attempts to connect to the test database and + checks the real asyncpg auth failure is caught and sanitized. + """ + _dbapi = AsyncAdapt_asyncpg_dbapi(asyncpg) + wrong_password = "definitely-wrong-password-xyz" + db = CONFIG.database + wrong_db = DatabaseSettings( + server=db.server, + port=db.port, + user=db.user, + password=wrong_password, + db=db.db, + test_db=db.test_db, + ) + + def _connect_asyncpg(**kwargs): + kw = { + "host": kwargs.pop("host"), + "port": kwargs.pop("port"), + "user": kwargs.pop("user"), + "password": kwargs.pop("password"), + "database": kwargs.pop("dbname"), + } + kw.update(kwargs) + raw_conn = await_only(asyncpg.connect(**kw)) + return AsyncAdapt_asyncpg_connection(_dbapi, raw_conn) + + with ( + patch("fides.config.secrets.static_provider.CONFIG") as mock_sp_config, + patch("fides.common.db_credential_provider.CONFIG") as mock_dcp_config, + patch( + "fides.common.db_credential_provider.get_secret_provider" + ) as mock_get, + ): + mock_sp_config.database = wrong_db + mock_dcp_config.database = wrong_db + mock_dcp_config.test_mode = CONFIG.test_mode + mock_get.return_value = StaticSecretProvider() + + provider = DBCredentialProvider() + with pytest.raises(SanitizedConnectionError) as exc_info: + provider.connect_with_retry(_connect_asyncpg, {}) + assert wrong_password not in str(exc_info.value) + assert exc_info.value.__cause__ is None + + +# --- Credential leakage --- + + +class TestCredentialLeakage: + def test_password_not_in_sanitized_exception(self, static_provider): + password = "v3ry-s3cret-p@ssw0rd!" + connect_fn = MagicMock(side_effect=_make_password_leaking_error(password)) + with pytest.raises(SanitizedConnectionError) as exc_info: + static_provider.connect_with_retry(connect_fn, {}) + assert password not in str(exc_info.value) + assert password not in repr(exc_info.value) + + def test_password_not_in_log_output_during_retry(self, dynamic_provider, caplog): + provider, _, mock_secret_provider, _ = dynamic_provider + password = "super-secret-123" + mock_secret_provider.get_secret.return_value = SecretValue( + {"username": "u", "password": password} + ) + connect_fn = MagicMock( + side_effect=[ + _make_auth_error(pgcode="28P01"), + _make_auth_error(pgcode="28P01"), + ] + ) + with pytest.raises(SanitizedConnectionError): + provider.connect_with_retry(connect_fn, {}) + assert password not in caplog.text diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/config/secrets/__init__.py b/tests/config/secrets/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/config/secrets/test_aws_secrets_manager_provider.py b/tests/config/secrets/test_aws_secrets_manager_provider.py new file mode 100644 index 00000000000..2ed60cfdd6b --- /dev/null +++ b/tests/config/secrets/test_aws_secrets_manager_provider.py @@ -0,0 +1,564 @@ +import json +import threading +import time +from unittest.mock import MagicMock, patch + +import boto3 +import pytest +from botocore.exceptions import ClientError +from moto import mock_aws + +from fides.config.secrets.aws_secrets_manager_provider import ( + AWSSecretsManagerProvider, +) +from fides.config.secrets.base import SecretProviderError, SecretValue + +REGION = "us-east-1" +SECRET_NAME = "test/db-creds" +SECRET_DATA = {"username": "testuser", "password": "testpass"} + + +@pytest.fixture() +def aws_env(): + """Set up a moto mock with a pre-created secret.""" + with mock_aws(): + client = boto3.client("secretsmanager", region_name=REGION) + client.create_secret( + Name=SECRET_NAME, + SecretString=json.dumps(SECRET_DATA), + ) + yield client + + +class TestBasicFetch: + def test_get_secret_fetches_from_aws(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION) + secret = provider.get_secret(SECRET_NAME) + assert secret["username"] == "testuser" + assert secret["password"] == "testpass" + + def test_get_secret_returns_secret_value_type(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION) + secret = provider.get_secret(SECRET_NAME) + assert isinstance(secret, SecretValue) + + def test_get_secret_contains_expected_fields(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION) + secret = provider.get_secret(SECRET_NAME) + assert "username" in secret + assert "password" in secret + + def test_unknown_secret_raises(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION) + with pytest.raises(SecretProviderError, match="permanently inaccessible"): + provider.get_secret("nonexistent-secret") + + def test_invalid_json_raises_without_leaking_secret(self, aws_env): + """If the secret value is not valid JSON, the error must not + contain the raw secret string (which could be a plain password).""" + raw_password = "super-secret-p@ssw0rd!" + aws_env.update_secret( + SecretId=SECRET_NAME, + SecretString=raw_password, + ) + provider = AWSSecretsManagerProvider(region_name=REGION) + with pytest.raises(SecretProviderError) as exc_info: + provider.get_secret(SECRET_NAME) + + # Walk the full exception chain and verify the password + # doesn't appear anywhere + exc = exc_info.value + while exc is not None: + assert raw_password not in str(exc) + # Also check JSONDecodeError's .doc attribute if present + if hasattr(exc, "doc"): + raise AssertionError("JSONDecodeError with .doc leaked into chain") + exc = exc.__cause__ + + def test_binary_secret_raises_clear_error(self, aws_env): + """Secrets stored as SecretBinary should raise a clear error + rather than a confusing KeyError.""" + aws_env.delete_secret(SecretId=SECRET_NAME, ForceDeleteWithoutRecovery=True) + aws_env.create_secret( + Name=SECRET_NAME, + SecretBinary=b"\x00\x01\x02binary-data", + ) + provider = AWSSecretsManagerProvider(region_name=REGION) + with pytest.raises(SecretProviderError, match="stored as binary"): + provider.get_secret(SECRET_NAME) + + +class TestCaching: + def test_second_call_uses_cache(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION, cache_ttl_seconds=60.0) + v1 = provider.get_secret(SECRET_NAME) + + # Sabotage the client — if cache works, this won't be called + provider._client.get_secret_value = MagicMock( + side_effect=Exception("should not be called") + ) + v2 = provider.get_secret(SECRET_NAME) + assert v2["username"] == v1["username"] + + def test_cache_expires_after_ttl(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, cache_ttl_seconds=10.0 + ) + provider.get_secret(SECRET_NAME) + + # Advance past TTL + time_value[0] = 200.0 + + # Update the secret so we can detect a re-fetch + aws_env.update_secret( + SecretId=SECRET_NAME, + SecretString=json.dumps({"username": "rotated", "password": "newpass"}), + ) + v2 = provider.get_secret(SECRET_NAME) + assert v2["username"] == "rotated" + + def test_cache_hit_within_ttl(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, cache_ttl_seconds=60.0 + ) + provider.get_secret(SECRET_NAME) + + # Advance time but stay within TTL + time_value[0] = 130.0 + + # Update secret — should NOT be picked up + aws_env.update_secret( + SecretId=SECRET_NAME, + SecretString=json.dumps({"username": "rotated", "password": "newpass"}), + ) + v2 = provider.get_secret(SECRET_NAME) + assert v2["username"] == "testuser" + + +class TestStaleWhileRevalidate: + def test_stale_value_served_on_fetch_failure(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + ) + provider.get_secret(SECRET_NAME) + + # Advance past TTL but within stale window + time_value[0] = 120.0 + + # Make fetches fail + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + secret = provider.get_secret(SECRET_NAME) + assert secret["username"] == "testuser" + + def test_stale_value_served_after_invalidation_and_fetch_failure(self, aws_env): + """When fetched_at=0 (post-invalidation) and the fetch fails, the + provider should serve the stale cached value rather than raising.""" + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=60.0, + ) + provider.get_secret(SECRET_NAME) + + # Invalidate (sets fetched_at=0), then make fetches fail + provider.invalidate(SECRET_NAME) + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + + # Should serve stale value, not raise + secret = provider.get_secret(SECRET_NAME) + assert secret["username"] == "testuser" + + def test_hard_failure_after_stale_ttl(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=60.0, + ) + provider.get_secret(SECRET_NAME) + + # Advance past TTL + stale TTL + time_value[0] = 200.0 + + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + with pytest.raises(SecretProviderError, match="stale cache has expired"): + provider.get_secret(SECRET_NAME) + + def test_no_cached_value_fetch_failure_raises(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION) + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + with pytest.raises(SecretProviderError, match="no cached value available"): + provider.get_secret(SECRET_NAME) + + +class TestPermanentErrors: + """Permanent AWS errors (secret deleted, IAM revoked, etc.) should fail + immediately and clear the cache, not serve stale credentials.""" + + @pytest.mark.parametrize( + "error_code", + [ + "ResourceNotFoundException", + "AccessDeniedException", + "DecryptionFailureException", + "InvalidRequestException", + ], + ) + def test_permanent_error_raises_immediately(self, aws_env, error_code): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + ) + provider.get_secret(SECRET_NAME) + + # Expire TTL, then simulate permanent error + time_value[0] = 120.0 + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": error_code, "Message": "revoked"}}, + "GetSecretValue", + ) + ) + + with pytest.raises(SecretProviderError, match="permanently inaccessible"): + provider.get_secret(SECRET_NAME) + + def test_permanent_error_clears_cache(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + ) + provider.get_secret(SECRET_NAME) + + # Expire TTL, simulate permanent error + time_value[0] = 120.0 + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "revoked"}}, + "GetSecretValue", + ) + ) + + with pytest.raises(SecretProviderError): + provider.get_secret(SECRET_NAME) + + # Cache should be cleared — switch to a transient error to + # verify the cached value is gone (not just re-triggering + # the permanent error path) + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + with pytest.raises(SecretProviderError, match="no cached value available"): + provider.get_secret(SECRET_NAME) + + def test_transient_error_still_serves_stale(self, aws_env): + """InternalServiceError is transient — should still serve stale.""" + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + ) + provider.get_secret(SECRET_NAME) + + time_value[0] = 120.0 + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + + secret = provider.get_secret(SECRET_NAME) + assert secret["username"] == "testuser" + + +class TestInvalidate: + def test_invalidate_forces_refetch(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, cache_ttl_seconds=300.0 + ) + provider.get_secret(SECRET_NAME) + + # Update secret in SM + aws_env.update_secret( + SecretId=SECRET_NAME, + SecretString=json.dumps({"username": "rotated", "password": "newpass"}), + ) + + # Without invalidation, still cached + assert provider.get_secret(SECRET_NAME)["username"] == "testuser" + + # Invalidate and re-fetch + provider.invalidate(SECRET_NAME) + assert provider.get_secret(SECRET_NAME)["username"] == "rotated" + + def test_invalidate_unknown_secret_is_noop(self, aws_env): + provider = AWSSecretsManagerProvider(region_name=REGION) + provider.invalidate("nonexistent") # should not raise + + +class TestCircuitBreaker: + def test_invalidate_within_cooldown_is_noop(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + circuit_breaker_cooldown_seconds=30.0, + ) + provider.get_secret(SECRET_NAME) + + # Expire TTL + time_value[0] = 120.0 + + # Make fetches fail — triggers last_failed_at + call_count = 0 + + def counting_fail(**kwargs): + nonlocal call_count + call_count += 1 + raise ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + + provider._client.get_secret_value = counting_fail + + # First fetch after expiry — fails, serves stale, sets last_failed_at + provider.get_secret(SECRET_NAME) + assert call_count == 1 + + # Invalidate within cooldown — should be a no-op + provider.invalidate(SECRET_NAME) + + # Next get_secret should NOT re-fetch (circuit breaker active) + calls_before = call_count + provider.get_secret(SECRET_NAME) + assert call_count == calls_before + + def test_invalidate_after_cooldown_resets(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + circuit_breaker_cooldown_seconds=30.0, + ) + provider.get_secret(SECRET_NAME) + + # Expire TTL, sabotage client + time_value[0] = 120.0 + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + provider.get_secret(SECRET_NAME) # sets last_failed_at + + # Advance past cooldown + time_value[0] = 160.0 + + # Restore client + provider._client = boto3.client("secretsmanager", region_name=REGION) + + # Invalidate now works (past cooldown) + provider.invalidate(SECRET_NAME) + secret = provider.get_secret(SECRET_NAME) + assert secret["username"] == "testuser" + + def test_circuit_resets_on_successful_ttl_refresh(self, aws_env): + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, + cache_ttl_seconds=10.0, + cache_stale_ttl_seconds=1800.0, + circuit_breaker_cooldown_seconds=30.0, + ) + provider.get_secret(SECRET_NAME) + + # Expire TTL, fail once to set last_failed_at + time_value[0] = 120.0 + provider._client.get_secret_value = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "InternalServiceError", "Message": "boom"}}, + "GetSecretValue", + ) + ) + provider.get_secret(SECRET_NAME) + + # Restore client, advance past both cooldown and TTL + provider._client = boto3.client("secretsmanager", region_name=REGION) + time_value[0] = 155.0 # past cooldown (120 + 30) and TTL expired + + # Successful fetch should clear last_failed_at + secret = provider.get_secret(SECRET_NAME) + assert secret["username"] == "testuser" + + # Verify circuit is reset — check the entry directly + entry = provider._cache[SECRET_NAME] + assert entry.last_failed_at == 0.0 + + +class TestThreadSafety: + def test_concurrent_access_single_fetch(self, aws_env): + """Multiple threads hitting get_secret after invalidation should + result in only 1 actual Secrets Manager call.""" + time_value = [100.0] + + with patch( + "fides.config.secrets.aws_secrets_manager_provider.time" + ) as mock_time: + mock_time.monotonic = lambda: time_value[0] + + provider = AWSSecretsManagerProvider( + region_name=REGION, cache_ttl_seconds=300.0 + ) + provider.get_secret(SECRET_NAME) + + # Invalidate to force refetch + provider.invalidate(SECRET_NAME) + + # Wrap client to count calls with some latency + real_get = provider._client.get_secret_value + call_count = 0 + call_lock = threading.Lock() + + def counting_get(**kwargs): + nonlocal call_count + with call_lock: + call_count += 1 + # Simulate latency to widen the race window + time.sleep(0.05) + return real_get(**kwargs) + + provider._client.get_secret_value = counting_get + + # Launch threads + num_threads = 10 + barrier = threading.Barrier(num_threads) + results = [None] * num_threads + + def worker(idx): + barrier.wait() + results[idx] = provider.get_secret(SECRET_NAME) + + threads = [ + threading.Thread(target=worker, args=(i,)) for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + # Only 1 thread should have actually fetched + assert call_count == 1, f"Expected 1 fetch, got {call_count}" + # All threads got correct results + assert all(r["username"] == "testuser" for r in results) diff --git a/tests/config/secrets/test_factory.py b/tests/config/secrets/test_factory.py new file mode 100644 index 00000000000..77c7d17446a --- /dev/null +++ b/tests/config/secrets/test_factory.py @@ -0,0 +1,120 @@ +import os +from unittest.mock import patch + +import pytest +from moto import mock_aws +from pydantic import ValidationError + +from fides.config.secrets import get_secret_provider, reset_secret_provider +from fides.config.secrets.aws_secrets_manager_provider import ( + AWSSecretsManagerProvider, +) +from fides.config.secrets.base import SecretProviderError, SecretValue +from fides.config.secrets.factory import create_secret_provider +from fides.config.secrets.static_provider import StaticSecretProvider +from fides.config.secrets_settings import AWSSecretsManagerSettings, SecretsSettings + + +class TestSecretsSettings: + def test_aws_provider_without_explicit_config_uses_defaults(self): + """When provider is aws but no config section provided, + settings are constructed with defaults (region from boto3 chain).""" + settings = SecretsSettings(provider="aws_secrets_manager") + assert settings.aws_secrets_manager is not None + assert settings.aws_secrets_manager.region is None + + def test_aws_provider_with_aws_config_passes(self): + settings = SecretsSettings( + provider="aws_secrets_manager", + aws_secrets_manager={"region": "us-east-1"}, + ) + assert settings.aws_secrets_manager is not None + assert settings.aws_secrets_manager.region == "us-east-1" + + def test_aws_provider_from_env_vars(self): + with patch.dict( + os.environ, + { + "FIDES__SECRETS__AWS_SECRETS_MANAGER__REGION": "eu-west-1", + }, + ): + settings = SecretsSettings(provider="aws_secrets_manager") + assert settings.aws_secrets_manager.region == "eu-west-1" + + def test_aws_config_region_defaults_to_none(self): + settings = AWSSecretsManagerSettings() + assert settings.region is None + + def test_static_provider_without_aws_config_passes(self): + settings = SecretsSettings(provider="static") + assert settings.aws_secrets_manager is None + + +class TestCreateSecretProvider: + def test_static_provider(self): + settings = SecretsSettings(provider="static") + provider = create_secret_provider(settings) + assert isinstance(provider, StaticSecretProvider) + + @mock_aws + def test_aws_secrets_manager_provider(self): + settings = SecretsSettings( + provider="aws_secrets_manager", + aws_secrets_manager={ + "region": "us-east-1", + "cache_ttl_seconds": 120.0, + }, + ) + provider = create_secret_provider(settings) + assert isinstance(provider, AWSSecretsManagerProvider) + assert provider._cache_ttl == 120.0 + + def test_aws_provider_with_missing_aws_config_raises_in_factory(self): + """Bypass Pydantic validation to test the factory's own guard.""" + settings = SecretsSettings() + settings.provider = "aws_secrets_manager" # type: ignore[assignment] + settings.aws_secrets_manager = None + with pytest.raises( + SecretProviderError, match="aws_secrets_manager is not configured" + ): + create_secret_provider(settings) + + def test_unknown_provider_raises_at_validation(self): + with pytest.raises(ValidationError, match="literal_error"): + SecretsSettings(provider="vault") + + def test_unknown_provider_raises_in_factory(self): + """Bypass Pydantic validation to test the factory's own guard.""" + settings = SecretsSettings() + settings.provider = "unknown" # type: ignore[assignment] + with pytest.raises(SecretProviderError, match="Unknown secrets provider"): + create_secret_provider(settings) + + def test_default_provider_is_static(self): + settings = SecretsSettings() + assert settings.provider == "static" + provider = create_secret_provider(settings) + assert isinstance(provider, StaticSecretProvider) + + +class TestGetSecretProvider: + def setup_method(self): + reset_secret_provider() + + def teardown_method(self): + reset_secret_provider() + + def test_returns_provider(self): + provider = get_secret_provider() + assert isinstance(provider, StaticSecretProvider) + + def test_returns_same_instance(self): + first = get_secret_provider() + second = get_secret_provider() + assert first is second + + def test_reset_forces_new_instance(self): + first = get_secret_provider() + reset_secret_provider() + second = get_secret_provider() + assert first is not second diff --git a/tests/config/secrets/test_secret_value.py b/tests/config/secrets/test_secret_value.py new file mode 100644 index 00000000000..69f9941c3a3 --- /dev/null +++ b/tests/config/secrets/test_secret_value.py @@ -0,0 +1,90 @@ +import pytest + +from fides.config.secrets.base import SecretValue + + +class TestSecretValue: + def test_subscript_access(self): + sv = SecretValue({"username": "admin", "password": "s3cret"}) + assert sv["username"] == "admin" + assert sv["password"] == "s3cret" + + def test_missing_key_raises_key_error(self): + sv = SecretValue({"username": "admin"}) + with pytest.raises(KeyError): + _ = sv["nonexistent"] + + def test_contains(self): + sv = SecretValue({"username": "admin"}) + assert "username" in sv + assert "missing" not in sv + + def test_repr_is_redacted(self): + sv = SecretValue({"password": "super-secret"}) + assert repr(sv) == "" + + def test_str_is_redacted(self): + sv = SecretValue({"password": "super-secret"}) + assert str(sv) == "" + + def test_fstring_is_redacted(self): + sv = SecretValue({"password": "super-secret"}) + assert f"value={sv}" == "value=" + + def test_equality(self): + a = SecretValue({"k": "v"}) + b = SecretValue({"k": "v"}) + assert a == b + + def test_inequality(self): + a = SecretValue({"k": "v1"}) + b = SecretValue({"k": "v2"}) + assert a != b + + def test_equality_with_non_secret_value(self): + sv = SecretValue({"k": "v"}) + assert sv != {"k": "v"} + + def test_contains_checks_fields(self): + sv = SecretValue({"username": "admin", "password": "s3cret"}) + assert "username" in sv + assert "password" in sv + assert "other" not in sv + + def test_dict_conversion_blocked(self): + sv = SecretValue({"password": "s3cret"}) + with pytest.raises(TypeError): + dict(sv) + + def test_unpacking_blocked(self): + sv = SecretValue({"password": "s3cret"}) + with pytest.raises(TypeError): + {**sv} + + def test_vars_blocked(self): + sv = SecretValue({"password": "s3cret"}) + with pytest.raises(TypeError): + vars(sv) + + def test_no_dict(self): + sv = SecretValue({"password": "s3cret"}) + assert not hasattr(sv, "__dict__") + + def test_pickle_blocked(self): + import pickle # local import — only used in this test + + sv = SecretValue({"password": "s3cret"}) + with pytest.raises(TypeError, match="cannot be pickled"): + pickle.dumps(sv) + + def test_getstate_blocked(self): + sv = SecretValue({"password": "s3cret"}) + with pytest.raises(TypeError, match="cannot be serialized"): + sv.__getstate__() + + def test_copy_blocked(self): + import copy # local import — only used in this test + + sv = SecretValue({"password": "s3cret"}) + with pytest.raises(TypeError): + copy.copy(sv) diff --git a/tests/config/secrets/test_static_provider.py b/tests/config/secrets/test_static_provider.py new file mode 100644 index 00000000000..2ea19bb3efd --- /dev/null +++ b/tests/config/secrets/test_static_provider.py @@ -0,0 +1,42 @@ +import pytest + +from fides.config import CONFIG +from fides.config.secrets.base import SecretProviderError, SecretValue +from fides.config.secrets.static_provider import ( + DATABASE_CREDENTIALS_KEY, + DATABASE_READONLY_CREDENTIALS_KEY, + StaticSecretProvider, +) + + +class TestStaticSecretProvider: + def test_caches_db_credentials_from_config(self): + provider = StaticSecretProvider() + secret = provider.get_secret(DATABASE_CREDENTIALS_KEY) + assert isinstance(secret, SecretValue) + assert secret["username"] == CONFIG.database.user + assert secret["password"] == CONFIG.database.raw_password + + def test_unknown_id_raises(self): + provider = StaticSecretProvider() + with pytest.raises(SecretProviderError, match="Unknown secret_id"): + provider.get_secret("nonexistent") + + def test_invalidate_is_noop(self): + provider = StaticSecretProvider() + provider.invalidate(DATABASE_CREDENTIALS_KEY) + assert ( + provider.get_secret(DATABASE_CREDENTIALS_KEY)["username"] + == CONFIG.database.user + ) + + def test_invalidate_unknown_id_does_not_raise(self): + provider = StaticSecretProvider() + provider.invalidate("nonexistent") + + def test_readonly_credentials_absent_when_no_readonly_server(self): + if CONFIG.database.readonly_server: + pytest.skip("readonly_server is configured in this environment") + provider = StaticSecretProvider() + with pytest.raises(SecretProviderError, match="Unknown secret_id"): + provider.get_secret(DATABASE_READONLY_CREDENTIALS_KEY) diff --git a/tests/ctl/core/config/test_config.py b/tests/ctl/core/config/test_config.py index 3e18474e1a7..21826b19d2a 100644 --- a/tests/ctl/core/config/test_config.py +++ b/tests/ctl/core/config/test_config.py @@ -7,7 +7,11 @@ from pydantic import ValidationError from fides.api.db.database import get_alembic_config -from fides.config import check_required_webserver_config_values, get_config +from fides.config import ( + build_config, + check_required_webserver_config_values, + get_config, +) from fides.config.database_settings import DatabaseSettings from fides.config.redis_settings import RedisSettings from fides.config.security_settings import SecuritySettings @@ -67,6 +71,7 @@ def test_get_config_default() -> None: """Check that get_config loads default values when given an empty TOML.""" config = get_config() assert config.database.api_engine_pool_size == 50 + assert config.database.pool_recycle is None assert config.security.env == "prod" assert config.security.app_encryption_key == "" assert config.logging.level == "INFO" @@ -242,6 +247,28 @@ def test_get_alembic_config_with_special_char_in_database_url(): get_alembic_config(database_url) +@pytest.mark.unit +def test_database_settings_pool_recycle_defaults_to_none() -> None: + """pool_recycle is optional and defaults to None.""" + db_settings = DatabaseSettings() + assert db_settings.pool_recycle is None + + +@pytest.mark.unit +def test_database_settings_pool_recycle_accepts_positive() -> None: + """pool_recycle accepts a positive integer.""" + db_settings = DatabaseSettings(pool_recycle=1800) + assert db_settings.pool_recycle == 1800 + + +@pytest.mark.unit +@pytest.mark.parametrize("value", [0, -1, -5], ids=["zero", "neg_one", "neg_five"]) +def test_database_settings_pool_recycle_rejects_invalid(value: int) -> None: + """pool_recycle must be > 0 when set.""" + with pytest.raises(ValidationError): + DatabaseSettings(pool_recycle=value) + + @pytest.mark.unit def test_database_settings_migration_role_defaults_to_none() -> None: """migration_role is optional and defaults to None.""" @@ -778,3 +805,69 @@ def test_readonly_async_uri_ssl_handling(self): assert "ssl=" in parsed.query # sslrootcert should be removed from query params assert "sslrootcert" not in parsed.query + + +@pytest.mark.unit +class TestDatabaseCredentialSecretIdValidation: + """Validate cross-section coherence between secrets.provider and database.credential_secret_name.""" + + def test_static_provider_with_credential_secret_name_raises(self) -> None: + with pytest.raises(ValidationError) as exc: + build_config( + { + "secrets": {"provider": "static"}, + "database": { + "credential_secret_name": "arn:aws:secretsmanager:us-east-1:123:secret:db-creds" + }, + } + ) + assert "credential_secret_name" in str(exc.value) + assert "static" in str(exc.value) + + def test_static_provider_with_readonly_credential_secret_name_raises(self) -> None: + with pytest.raises(ValidationError) as exc: + build_config( + { + "secrets": {"provider": "static"}, + "database": { + "readonly_credential_secret_name": "arn:aws:secretsmanager:us-east-1:123:secret:ro-creds" + }, + } + ) + assert "readonly_credential_secret_name" in str(exc.value) + assert "static" in str(exc.value) + + def test_aws_provider_without_credential_secret_name_raises(self) -> None: + with pytest.raises(ValidationError) as exc: + build_config( + { + "secrets": { + "provider": "aws_secrets_manager", + "aws_secrets_manager": {"region": "us-east-1"}, + }, + } + ) + assert "credential_secret_name is not set" in str(exc.value) + + def test_aws_provider_with_credential_secret_name_passes(self) -> None: + config = build_config( + { + "secrets": { + "provider": "aws_secrets_manager", + "aws_secrets_manager": {"region": "us-east-1"}, + }, + "database": { + "credential_secret_name": "arn:aws:secretsmanager:us-east-1:123:secret:db-creds" + }, + } + ) + assert ( + config.database.credential_secret_name + == "arn:aws:secretsmanager:us-east-1:123:secret:db-creds" + ) + assert config.database.readonly_credential_secret_name is None + + def test_static_provider_without_secret_ids_passes(self) -> None: + config = build_config({}) + assert config.database.credential_secret_name is None + assert config.database.readonly_credential_secret_name is None diff --git a/tests/lib/test_engine_creators.py b/tests/lib/test_engine_creators.py new file mode 100644 index 00000000000..3f715555e21 --- /dev/null +++ b/tests/lib/test_engine_creators.py @@ -0,0 +1,215 @@ +"""Tests for engine creator factories and helpers.""" + +import datetime +import ssl +from unittest.mock import MagicMock, patch + +import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from fides.common.engine_creators import ( + ASYNC_DIALECT_URL, + SYNC_DIALECT_URL, + _build_ssl_context, + _convert_asyncpg_params, + make_async_creator, + make_sync_creator, +) +from fides.config import CONFIG +from fides.config.database_settings import DatabaseSettings + + +@pytest.fixture() +def self_signed_cert(tmp_path): + """Generate a self-signed CA cert and return the file path.""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test-ca")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after( + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1) + ) + .sign(key, hashes.SHA256()) + ) + cert_file = tmp_path / "ca.pem" + cert_file.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + return str(cert_file) + + +class TestRawPassword: + """Verify raw_password round-trips passwords with special characters.""" + + @pytest.mark.parametrize( + "password", + [ + "simple", + "p@ssw0rd", + "pass#word", + "pass%word", + "pass/word", + "p@ss#w%rd/123", + "has spaces", + "has+plus", + ], + ) + def test_raw_password_round_trip(self, password: str) -> None: + """Constructing DatabaseSettings with special-char passwords + should produce a raw_password that matches the original.""" + settings = DatabaseSettings(password=password) + assert settings.raw_password == password + + @pytest.mark.parametrize("password", ["p@ssw0rd", "pass%word"]) + def test_raw_readonly_password_round_trip(self, password: str) -> None: + """readonly_password should also round-trip through quote_plus.""" + settings = DatabaseSettings( + readonly_server="replica", readonly_password=password + ) + assert settings.raw_readonly_password == password + + def test_raw_readonly_password_none_when_not_set(self) -> None: + settings = DatabaseSettings() + assert settings.raw_readonly_password is None + + def test_pre_encoded_password_treated_as_literal(self) -> None: + """Passwords are always treated as raw values, never as pre-encoded. + + If a user sets their password to "foo%40bar" (literally containing + the characters %, 4, 0), raw_password returns "foo%40bar" — NOT + "foo@bar". This matches the old URI-based path where escape_password + would double-encode %40 to %2540 in the URI, and psycopg2 would + decode it back to %40. + """ + settings = DatabaseSettings(password="foo%40bar") + assert settings.raw_password == "foo%40bar" + + +class TestConvertAsyncpgParams: + def test_converts_sslmode_to_ssl(self) -> None: + params = {"sslmode": "require", "other": "value"} + result = _convert_asyncpg_params(params) + assert "sslmode" not in result + assert result["ssl"] == "require" + assert result["other"] == "value" + + def test_drops_sslrootcert(self) -> None: + params = {"sslrootcert": "/path/to/cert.pem", "other": "value"} + result = _convert_asyncpg_params(params) + assert "sslrootcert" not in result + assert result["other"] == "value" + + def test_does_not_mutate_input(self) -> None: + params = {"sslmode": "require", "sslrootcert": "/path"} + _convert_asyncpg_params(params) + assert "sslmode" in params + assert "sslrootcert" in params + + def test_empty_params(self) -> None: + assert _convert_asyncpg_params({}) == {} + + +class TestBuildSslContext: + def test_returns_none_without_sslrootcert(self) -> None: + assert _build_ssl_context({}) is None + assert _build_ssl_context({"sslmode": "require"}) is None + + def test_returns_context_with_valid_sslrootcert(self, self_signed_cert) -> None: + """Success path: a valid CA cert produces a usable SSLContext.""" + ctx = _build_ssl_context({"sslrootcert": self_signed_cert}) + assert isinstance(ctx, ssl.SSLContext) + assert ctx.verify_mode == ssl.CERT_REQUIRED + + def test_returns_none_with_invalid_cert(self, tmp_path) -> None: + cert_file = tmp_path / "bad.pem" + cert_file.write_text("not a cert") + with pytest.raises(ssl.SSLError): + _build_ssl_context({"sslrootcert": str(cert_file)}) + + +class TestMakeSyncCreator: + def test_returns_callable(self) -> None: + creator = make_sync_creator() + assert callable(creator) + + def test_creator_opens_working_connection(self) -> None: + """The sync creator should produce a real psycopg2 connection.""" + creator = make_sync_creator() + conn = creator() + try: + cur = conn.cursor() + cur.execute("SELECT 1") + assert cur.fetchone() == (1,) + finally: + conn.close() + + def test_creator_with_connect_args(self) -> None: + """connect_args like keepalives should be forwarded.""" + creator = make_sync_creator( + connect_args={"keepalives": 1, "keepalives_idle": 30} + ) + conn = creator() + try: + cur = conn.cursor() + cur.execute("SELECT 1") + assert cur.fetchone() == (1,) + finally: + conn.close() + + def test_engine_with_sync_creator(self) -> None: + """A full engine using the sync creator can execute queries.""" + creator = make_sync_creator() + engine = create_engine(SYNC_DIALECT_URL, creator=creator, pool_size=1) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() + + +class TestMakeAsyncCreator: + def test_returns_callable(self) -> None: + creator = make_async_creator() + assert callable(creator) + + async def test_engine_with_async_creator(self) -> None: + """A full async engine using the async creator can execute queries.""" + creator = make_async_creator() + engine = create_async_engine(ASYNC_DIALECT_URL, creator=creator, pool_size=1) + try: + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + await engine.dispose() + + @patch("fides.common.engine_creators.AsyncAdapt_asyncpg_connection") + @patch("fides.common.engine_creators.await_only", side_effect=lambda coro: coro) + @patch("fides.common.engine_creators.asyncpg") + def test_ssl_context_not_overwritten_by_async_params( + self, mock_asyncpg, mock_await, mock_adapt_conn, self_signed_cert + ) -> None: + """When both sslrootcert and sslmode are configured, the SSLContext + must not be overwritten by the raw ssl string from async_params.""" + with patch.object( + CONFIG.database, + "params", + {"sslmode": "verify-full", "sslrootcert": self_signed_cert}, + ): + creator = make_async_creator() + creator() + + connect_kwargs = mock_asyncpg.connect.call_args[1] + assert isinstance(connect_kwargs["ssl"], ssl.SSLContext), ( + f"Expected SSLContext but got {connect_kwargs['ssl']!r} — " + "async_params overwrote the ssl_context" + ) diff --git a/tests/lib/test_session.py b/tests/lib/test_session.py index 8d74f94cb59..7293cd99461 100644 --- a/tests/lib/test_session.py +++ b/tests/lib/test_session.py @@ -1,6 +1,8 @@ import pytest +from sqlalchemy import text from fides.api.db import session +from fides.common.engine_creators import make_sync_creator from fides.config import get_config @@ -20,3 +22,93 @@ def test_get_session_test_modes(self, test_mode: bool) -> None: db_engine = session.get_db_engine(config=config) config.test_mode = original_value assert db_engine + + def test_get_engine_with_creator(self) -> None: + """Engine created via creator= can execute queries.""" + creator = make_sync_creator() + engine = session.get_db_engine(creator=creator, pool_size=1) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() + + def test_creator_with_keepalives_raises(self) -> None: + """Passing both creator and keepalives params is an error.""" + creator = make_sync_creator() + with pytest.raises( + ValueError, + match="keepalives_idle/interval/count cannot be used with creator", + ): + session.get_db_engine(creator=creator, keepalives_idle=30) + + def test_creator_with_database_uri_raises(self) -> None: + """Passing both creator and database_uri is an error.""" + creator = make_sync_creator() + with pytest.raises( + ValueError, + match="database_uri/config cannot be used with creator", + ): + session.get_db_engine( + creator=creator, database_uri="postgresql://localhost/db" + ) + + def test_creator_with_config_raises(self) -> None: + """Passing both creator and config is an error.""" + creator = make_sync_creator() + config = get_config() + with pytest.raises( + ValueError, + match="database_uri/config cannot be used with creator", + ): + session.get_db_engine(creator=creator, config=config) + + def test_config_with_keepalives(self) -> None: + """URI path with keepalives produces a working engine.""" + config = get_config() + engine = session.get_db_engine( + config=config, + pool_size=1, + keepalives_idle=30, + keepalives_interval=10, + keepalives_count=5, + ) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() + + def test_pool_recycle_passed_to_engine(self) -> None: + """pool_recycle is forwarded to the underlying QueuePool.""" + creator = make_sync_creator() + engine = session.get_db_engine(creator=creator, pool_size=1, pool_recycle=900) + try: + assert engine.pool._recycle == 900 + finally: + engine.dispose() + + def test_pool_recycle_default(self) -> None: + """Default pool_recycle (None) leaves SQLAlchemy's default of -1.""" + creator = make_sync_creator() + engine = session.get_db_engine(creator=creator, pool_size=1) + try: + assert engine.pool._recycle == -1 + finally: + engine.dispose() + + def test_disable_pooling(self) -> None: + """disable_pooling uses NullPool — no connections are kept.""" + from sqlalchemy.pool import NullPool + + config = get_config() + engine = session.get_db_engine(config=config, disable_pooling=True) + try: + assert isinstance(engine.pool, NullPool) + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() diff --git a/tests/ops/api/test_deps.py b/tests/ops/api/test_deps.py index 8dc11215c28..802bba1af1f 100644 --- a/tests/ops/api/test_deps.py +++ b/tests/ops/api/test_deps.py @@ -31,6 +31,14 @@ def mock_config_changed_db_engine_settings(): CONFIG.database.api_engine_max_overflow = max_overflow +@pytest.fixture +def mock_config_pool_recycle(): + original = CONFIG.database.pool_recycle + CONFIG.database.pool_recycle = 1800 + yield + CONFIG.database.pool_recycle = original + + @pytest.mark.usefixtures("mock_config") def test_get_cache_not_enabled(): with pytest.raises(RedisNotConfigured): @@ -53,3 +61,12 @@ def test_get_api_session(config_fixture, request): pool: QueuePool = engine.pool assert pool.size() == pool_size assert pool._max_overflow == max_overflow + + +@pytest.mark.usefixtures("mock_config_pool_recycle") +def test_get_api_session_pool_recycle(): + session_management._engine = None + session: Session = get_api_session() + engine: Engine = session.get_bind() + pool: QueuePool = engine.pool + assert pool._recycle == CONFIG.database.pool_recycle diff --git a/tests/ops/tasks/test_database_task.py b/tests/ops/tasks/test_database_task.py index c2907ade17a..ce799a0b64d 100644 --- a/tests/ops/tasks/test_database_task.py +++ b/tests/ops/tasks/test_database_task.py @@ -19,9 +19,12 @@ def mock_config_changed_db_engine_settings(self): CONFIG.database.task_engine_pool_size = pool_size + 5 max_overflow = CONFIG.database.task_engine_max_overflow CONFIG.database.task_engine_max_overflow = max_overflow + 5 + pool_recycle = CONFIG.database.pool_recycle + CONFIG.database.pool_recycle = 1800 yield CONFIG.database.task_engine_pool_size = pool_size CONFIG.database.task_engine_max_overflow = max_overflow + CONFIG.database.pool_recycle = pool_recycle @pytest.fixture def recovering_session_maker(self): @@ -42,6 +45,15 @@ def always_failing_session_maker(self): mock_maker.side_effect = OperationalError("connection failed", None, None) return mock_maker + @pytest.mark.usefixtures("mock_config_changed_db_engine_settings") + def test_task_engine_pool_recycle(self): + """pool_recycle from config is applied to the task engine.""" + task = DatabaseTask() + task._task_engine = None + task._sessionmaker = None + task.get_new_session() + assert task._task_engine.pool._recycle == CONFIG.database.pool_recycle + def test_retry_on_operational_error(self, recovering_session_maker): """Test that session creation retries on OperationalError""" diff --git a/uv.lock b/uv.lock index e4982e42656..b2ce44a5a7c 100644 --- a/uv.lock +++ b/uv.lock @@ -1177,7 +1177,7 @@ dev = [ { name = "faker", specifier = "==14.1.0" }, { name = "freezegun", specifier = "==1.5.5" }, { name = "gitpython", specifier = "==3.1.41" }, - { name = "moto", extras = ["s3"], specifier = "==5.1.22" }, + { name = "moto", extras = ["s3", "secretsmanager"], specifier = "==5.1.22" }, { name = "mypy", specifier = "==1.10.0" }, { name = "nox", specifier = ">=2025.11" }, { name = "pre-commit", specifier = "==2.20.0" },