diff --git a/CLAUDE.md b/CLAUDE.md index 12599b8..b725391 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -25,7 +25,7 @@ The CI `DB_DSN` format: `postgresql+asyncpg://postgres:postgres@localhost:5432/p The package (`db_retry/`) exposes five public symbols via `__init__.py`: -- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the exception chain via `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override. +- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the outer `__cause__`/`__context__` chain to find any `DBAPIError`, then inspects `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). The chain walk lets retries fire when the `DBAPIError` is re-raised by a wrapper (e.g. advanced-alchemy's `wrap_sqlalchemy_exception()` surfacing it as `RepositoryError`/`IntegrityError`). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override. - **`build_connection_factory`** (`connections.py`) — returns an async callable suitable for SQLAlchemy's `async_engine_from_config`. Handles multi-host DSNs by randomizing host order (load balancing) and attempting all hosts on timeout before raising `TargetServerAttributeNotMatched`. diff --git a/README.md b/README.md index 6522b2a..e7c393c 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,7 @@ export DB_RETRY_RETRIES_NUMBER=5 ### Retry Decorator - `@postgres_retry` - Decorator for async functions that should retry on database errors (uses `DB_RETRY_RETRIES_NUMBER`) - `@postgres_retry(retries=N)` - Override retry count per callsite +- Retries also fire when the retriable `asyncpg` error is wrapped by [`advanced-alchemy`](https://github.com/litestar-org/advanced-alchemy)'s `wrap_sqlalchemy_exception()` (i.e. surfaced as `RepositoryError` / `IntegrityError`); the handler walks the `__cause__` / `__context__` chain. ### Connection Utilities - `build_connection_factory(url, timeout)` - Creates a connection factory for multi-host setups diff --git a/db_retry/retry.py b/db_retry/retry.py index eb9cee6..fb640b4 100644 --- a/db_retry/retry.py +++ b/db_retry/retry.py @@ -12,15 +12,23 @@ logger = logging.getLogger(__name__) -def _retry_handler(exception: BaseException) -> bool: - if ( +def _is_retriable_dbapi_error(exception: BaseException) -> bool: + return ( isinstance(exception, DBAPIError) - and hasattr(exception, "orig") and exception.orig is not None and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError)) - ): - logger.debug("postgres_retry, retrying") - return True + ) + + +def _retry_handler(exception: BaseException) -> bool: + current: BaseException | None = exception + seen: set[int] = set() + while current is not None and id(current) not in seen: + seen.add(id(current)) + if _is_retriable_dbapi_error(current): + logger.debug("postgres_retry, retrying") + return True + current = current.__cause__ or current.__context__ logger.debug("postgres_retry, giving up on retry") return False diff --git a/pyproject.toml b/pyproject.toml index 12989f3..6e414e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dev = [ "pytest", "pytest-cov", "pytest-asyncio", + "advanced-alchemy", ] lint = [ "ruff", diff --git a/tests/test_retry.py b/tests/test_retry.py index 7abe751..ff6185a 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,5 +1,6 @@ import pytest import sqlalchemy +from advanced_alchemy.exceptions import RepositoryError, wrap_sqlalchemy_exception from sqlalchemy.exc import DBAPIError from sqlalchemy.ext import asyncio as sa_async @@ -44,6 +45,49 @@ async def raise_error() -> None: assert call_count == expected_calls +@pytest.mark.parametrize( + ("error_code", "expected_calls"), + [ + ("08000", 2), + ("08003", 2), + ("40001", 2), + ("40002", 1), + ], +) +async def test_postgres_retry_advanced_alchemy( + async_engine: sa_async.AsyncEngine, + error_code: str, + expected_calls: int, +) -> None: + async with async_engine.connect() as connection: + await connection.execute( + sqlalchemy.text( + f""" + CREATE OR REPLACE FUNCTION raise_error() + RETURNS VOID AS $$ + BEGIN + RAISE SQLSTATE '{error_code}'; + END; + $$ LANGUAGE plpgsql; + """, + ), + ) + + call_count = 0 + + @postgres_retry + async def raise_error() -> None: + nonlocal call_count + call_count += 1 + with wrap_sqlalchemy_exception(): + await connection.execute(sqlalchemy.text("SELECT raise_error()")) + + with pytest.raises(RepositoryError): + await raise_error() + + assert call_count == expected_calls + + async def test_postgres_retry_with_retries(async_engine: sa_async.AsyncEngine) -> None: async with async_engine.connect() as connection: await connection.execute(