Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The CI `DB_DSN` format: `postgresql+asyncpg://postgres:postgres@localhost:5432/p

The package (`db_retry/`) exposes five public symbols via `__init__.py`:

- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the exception chain via `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override.
- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the outer `__cause__`/`__context__` chain to find any `DBAPIError`, then inspects `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). The chain walk lets retries fire when the `DBAPIError` is re-raised by a wrapper (e.g. advanced-alchemy's `wrap_sqlalchemy_exception()` surfacing it as `RepositoryError`/`IntegrityError`). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override.

- **`build_connection_factory`** (`connections.py`) — returns an async callable suitable for SQLAlchemy's `async_engine_from_config`. Handles multi-host DSNs by randomizing host order (load balancing) and attempting all hosts on timeout before raising `TargetServerAttributeNotMatched`.

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ export DB_RETRY_RETRIES_NUMBER=5
### Retry Decorator
- `@postgres_retry` - Decorator for async functions that should retry on database errors (uses `DB_RETRY_RETRIES_NUMBER`)
- `@postgres_retry(retries=N)` - Override retry count per callsite
- Retries also fire when the retriable `asyncpg` error is wrapped by [`advanced-alchemy`](https://github.com/litestar-org/advanced-alchemy)'s `wrap_sqlalchemy_exception()` (i.e. surfaced as `RepositoryError` / `IntegrityError`); the handler walks the `__cause__` / `__context__` chain.

### Connection Utilities
- `build_connection_factory(url, timeout)` - Creates a connection factory for multi-host setups
Expand Down
20 changes: 14 additions & 6 deletions db_retry/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,23 @@
logger = logging.getLogger(__name__)


def _retry_handler(exception: BaseException) -> bool:
if (
def _is_retriable_dbapi_error(exception: BaseException) -> bool:
return (
isinstance(exception, DBAPIError)
and hasattr(exception, "orig")
and exception.orig is not None
and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError))
):
logger.debug("postgres_retry, retrying")
return True
)


def _retry_handler(exception: BaseException) -> bool:
current: BaseException | None = exception
seen: set[int] = set()
while current is not None and id(current) not in seen:
seen.add(id(current))
if _is_retriable_dbapi_error(current):
logger.debug("postgres_retry, retrying")
return True
current = current.__cause__ or current.__context__

logger.debug("postgres_retry, giving up on retry")
return False
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-asyncio",
"advanced-alchemy",
]
lint = [
"ruff",
Expand Down
44 changes: 44 additions & 0 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import sqlalchemy
from advanced_alchemy.exceptions import RepositoryError, wrap_sqlalchemy_exception
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext import asyncio as sa_async

Expand Down Expand Up @@ -44,6 +45,49 @@ async def raise_error() -> None:
assert call_count == expected_calls


@pytest.mark.parametrize(
("error_code", "expected_calls"),
[
("08000", 2),
("08003", 2),
("40001", 2),
("40002", 1),
],
)
async def test_postgres_retry_advanced_alchemy(
async_engine: sa_async.AsyncEngine,
error_code: str,
expected_calls: int,
) -> None:
async with async_engine.connect() as connection:
await connection.execute(
sqlalchemy.text(
f"""
CREATE OR REPLACE FUNCTION raise_error()
RETURNS VOID AS $$
BEGIN
RAISE SQLSTATE '{error_code}';
END;
$$ LANGUAGE plpgsql;
""",
),
)

call_count = 0

@postgres_retry
async def raise_error() -> None:
nonlocal call_count
call_count += 1
with wrap_sqlalchemy_exception():
await connection.execute(sqlalchemy.text("SELECT raise_error()"))

with pytest.raises(RepositoryError):
await raise_error()

assert call_count == expected_calls


async def test_postgres_retry_with_retries(async_engine: sa_async.AsyncEngine) -> None:
async with async_engine.connect() as connection:
await connection.execute(
Expand Down
Loading