diff --git a/README.md b/README.md index 5c310d1..3dec825 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,32 @@ pip install httpware[all] # everything declared above (pydantic, msgsp ## Quickstart -> Requires: `pip install httpware[pydantic]` +**Async usage:** + +```python +import asyncio + +from httpware import AsyncClient + +async def main() -> None: + async with AsyncClient(base_url="https://example.test") as client: + response = await client.get("/users/42") + print(response.json()) + +asyncio.run(main()) +``` + +**Sync usage:** + +```python +from httpware import Client + +with Client(base_url="https://example.test") as client: + response = client.get("/users/42") + print(response.json()) +``` + +Typed decoding via `response_model=` works in both worlds — requires `pip install httpware[pydantic]`: ```python from httpware import AsyncClient diff --git a/docs/errors.md b/docs/errors.md index 06a117f..a327190 100644 --- a/docs/errors.md +++ b/docs/errors.md @@ -4,6 +4,8 @@ For the resilience-specific errors (`RetryBudgetExhaustedError`, `BulkheadFullError`) see the [Resilience reference](resilience.md). +The status-keyed exception tree is shared between `Client` and `AsyncClient`. Catching `NotFoundError` in sync code uses the same import as catching it in async code (`from httpware import NotFoundError`). + ## The exception tree ``` diff --git a/docs/index.md b/docs/index.md index 8bc5f6c..581fb79 100644 --- a/docs/index.md +++ b/docs/index.md @@ -19,9 +19,34 @@ pip install httpware[msgspec] # MsgspecDecoder ## First request +**Async usage:** + ```python import asyncio +from httpware import AsyncClient + +async def main() -> None: + async with AsyncClient(base_url="https://example.test") as client: + response = await client.get("/users/42") + print(response.json()) + +asyncio.run(main()) +``` + +**Sync usage:** + +```python +from httpware import Client + +with Client(base_url="https://example.test") as client: + response = client.get("/users/42") + print(response.json()) +``` + +Typed decoding via `response_model=` works the same way in both worlds: + +```python from httpware import AsyncClient from pydantic import BaseModel @@ -35,9 +60,6 @@ async def main() -> None: async with AsyncClient(base_url="https://api.example.com") as client: user = await client.get("/users/1", response_model=User) print(user.name) - - -asyncio.run(main()) ``` ### With resilience middleware diff --git a/docs/middleware.md b/docs/middleware.md index 5199fcf..98b0d44 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -136,6 +136,62 @@ After this runs, every `httpware` HTTP call gets an `HTTP ` span from th For production, swap `ConsoleSpanExporter` for your OTLP/Jaeger/Zipkin exporter. See the [OpenTelemetry Python docs](https://opentelemetry.io/docs/languages/python/) for the full SDK setup. +## Sync middleware + +The same protocol shape, sync flavor. Use these when wiring middleware into a sync `Client` instead of `AsyncClient`. + +```python +from httpware import Middleware, Next, before_request, after_response, on_error +from httpware.middleware.chain import compose +``` + +A sync `Middleware` is a structural protocol — any callable with the right signature satisfies it: + +```python +import httpx2 + +from httpware import Client +from httpware.middleware import Next + + +class LoggingMiddleware: + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + print(f"-> {request.method} {request.url}") + response = next(request) + print(f"<- {response.status_code}") + return response + + +with Client(base_url="https://api.example.com", middleware=[LoggingMiddleware()]) as client: + client.get("/users/1") +``` + +Phase decorators (`@before_request`, `@after_response`, `@on_error`) have the same semantics as their `@async_*` siblings, but wrap sync functions: + +```python +import uuid + +import httpx2 + +from httpware import Client, before_request + + +@before_request +def add_request_id(request: httpx2.Request) -> httpx2.Request: + return httpx2.Request( + request.method, + request.url, + headers={**request.headers, "X-Request-ID": uuid.uuid4().hex}, + content=request.content, + ) + + +with Client(base_url="https://api.example.com", middleware=[add_request_id]) as client: + client.get("/users/1") +``` + +Sync and async middleware classes do not interop: a `Middleware` cannot be passed to `AsyncClient(middleware=...)` and vice versa. Pick the flavor matching your client. + ## See also - **`planning/engineering.md` §3 (Seam A)** — the formal protocol contract and why the chain is frozen at construction. diff --git a/docs/resilience.md b/docs/resilience.md index 70352ad..79e0d37 100644 --- a/docs/resilience.md +++ b/docs/resilience.md @@ -168,6 +168,62 @@ Flipping the order (`[AsyncRetry, AsyncBulkhead]`) means each retry attempt grab Cross-cutting middleware that emit per-call state (e.g., the Request-ID middleware in the [Middleware guide](middleware.md)) should sit outside `AsyncRetry` for the same reason — so all attempts of one call share one ID rather than getting a fresh ID per attempt. +## Sync Retry and Bulkhead + +The sync flavors mirror the async ones for use with `Client`. Same parameter set, same defaults, same `RetryBudget` (which is safe to share across sync and async clients in the same process). + +### `Retry` + +```python +from httpware.middleware.resilience import Retry +``` + +| Parameter | Default | Effect | +|---|---|---| +| `max_attempts` | `3` | Total tries (including the first). `1` disables retries entirely; `<1` raises `ValueError`. | +| `base_delay` | `0.1` (s) | Floor for the full-jitter exponential backoff. | +| `max_delay` | `5.0` (s) | Ceiling for backoff. | +| `retry_status_codes` | `frozenset({408, 429, 502, 503, 504})` | Status codes considered retryable. | +| `retry_methods` | `frozenset({"GET", "HEAD", "OPTIONS", "PUT", "DELETE"})` | Idempotent methods only by default. POST excluded; pass an explicit frozenset including `"POST"` to retry it. | +| `respect_retry_after` | `True` | When the response carries a `Retry-After` header on a retryable status, sleep for the header value (clamped to `max_delay`) instead of the jittered backoff. | +| `budget` | `RetryBudget()` (default-configured) | The token bucket. Pass a shared `RetryBudget` instance to apply one budget across multiple clients — sync, async, or both. | + +`Retry` uses `time.sleep` between attempts. `Retry-After`, streaming-body refusal, exhaustion behavior, and `RetryBudgetExhaustedError` semantics are identical to `AsyncRetry`. + +For a whole-attempt wall-clock bound, use `httpx2.Timeout` on the wrapped client or pass `timeout=` per request. `httpware` does not own a structured-cancellation timeout knob. + +### `Bulkhead` + +```python +from httpware.middleware.resilience import Bulkhead +``` + +| Parameter | Default | Effect | +|---|---|---| +| `max_concurrent` | **REQUIRED** | Maximum in-flight requests. `<1` raises `ValueError`. | +| `acquire_timeout` | `1.0` (s) | How long to wait for a slot before raising `BulkheadFullError`. `None` waits forever; `0` fails fast. `<0` raises `ValueError`. | + +`Bulkhead` is backed by `threading.Semaphore`. Slot release follows the same `try/finally` contract as `AsyncBulkhead` — success, exception, and (in sync land) interrupt-style exceptions all release the slot. + +> **Per-world Bulkhead.** A `Bulkhead` (sync) and an `AsyncBulkhead` are separate primitives backed by `threading.Semaphore` and `asyncio.Semaphore` respectively. A single Bulkhead instance cannot enforce a joint cap across sync + async clients in the same process. If you need that, create both with the same `max_concurrent`; the OS will not coordinate the two but the policy intent is documented. + +### Composition with sync `Client` + +```python +from httpware import Client +from httpware.middleware.resilience import Bulkhead, Retry + + +with Client( + base_url="https://api.example.com", + middleware=[ + Bulkhead(max_concurrent=10), + Retry(), + ], +) as client: + client.get("/users/1") +``` + ## See also - **[Middleware guide](middleware.md)** — write your own resilience middleware against the same protocol `AsyncRetry` and `AsyncBulkhead` use. diff --git a/docs/testing.md b/docs/testing.md index 105b6b5..677499f 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -28,6 +28,29 @@ The handler can be sync or async; `httpx2.MockTransport` supports both. The test If you use `pytest-asyncio` in auto-mode (`asyncio_mode = "auto"` under `[tool.pytest.ini_options]`), async test functions don't need the `@pytest.mark.asyncio` decorator. +### Sync `Client` + +The same pattern works for the sync `Client` — pass an `httpx2.Client` (not `httpx2.AsyncClient`) built on `httpx2.MockTransport`: + +```python +from http import HTTPStatus + +import httpx2 + +from httpware import Client + + +def test_get_returns_typed_response() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, json={"ok": True}) + + with Client(httpx2_client=httpx2.Client(transport=httpx2.MockTransport(handler))) as client: + response = client.get("https://example.test/x") + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"ok": True} +``` + ## Recording / stateful handlers For tests that need to vary the response by call count or assert on the requests that came in, use a handler with instance state: diff --git a/planning/engineering.md b/planning/engineering.md index a8cffaa..6aeedfb 100644 --- a/planning/engineering.md +++ b/planning/engineering.md @@ -8,6 +8,8 @@ This doc is the single distilled reference for `httpware` design rationale, prot The next release renames the async middleware surface to use the `Async*`/`async_*` prefix (aligning with httpx2's convention) and removes the seldom-used `attempt_timeout=` kwarg from `AsyncRetry` — see `planning/specs/2026-06-07-sync-client-design.md` for the rationale. +The same release also adds a sync `Client` with full feature parity (typed decoding, middleware chain, `Retry`/`Bulkhead`, `stream()`). `RetryBudget` is now thread-safe (one class, both worlds). Sync `Bulkhead` uses `threading.Semaphore` and cannot share an instance with `AsyncBulkhead`. See `planning/specs/2026-06-07-sync-client-design.md`. + The 0.1.0 release attempted to own a full abstraction over the underlying HTTP client. v0.2 walks that back: `httpx2` is part of the public surface. ## 2. Architectural invariants (CI-enforced) @@ -28,10 +30,10 @@ A protocol seam is a documented internal boundary. AI agents and contributors mu The 0.1.0 seams numbered 1 (Middleware↔Transport) and 4 (Transport↔httpx2) have collapsed into the `AsyncClient` terminal — there is no transport abstraction in v0.2. -### Seam A: `AsyncClient ↔ AsyncMiddleware` +### Seam A: `Client`/`AsyncClient` ↔ `Middleware`/`AsyncMiddleware` - **Where:** `src/httpware/client.py` ↔ `src/httpware/middleware/`. -- **Contract:** the `AsyncMiddleware` chain is composed once via `compose_async` at `AsyncClient.__init__` and frozen for the client's lifetime. The chain bottom (the "terminal") is internal: it calls `self._httpx2_client.send(request)`, maps `httpx2` errors to `httpware` errors, and raises a `StatusError` subclass on 4xx/5xx. The continuation type passed to each middleware is `AsyncNext`. +- **Contract:** the middleware chain is composed once at client construction and frozen for the client's lifetime. Both worlds follow the same contract; the only difference is the per-world type: `AsyncClient` composes `AsyncMiddleware` via `compose_async` (the continuation type is `AsyncNext`), and `Client` composes `Middleware` via `compose` (the continuation type is `Next`). Both `compose` and `compose_async` live in `src/httpware/middleware/chain.py`. The chain bottom (the "terminal") is internal: it calls `self._httpx2_client.send(request)`, maps `httpx2` errors to `httpware` errors, and raises a `StatusError` subclass on 4xx/5xx. Same lifecycle rules in both worlds. - **Rule:** mutating the chain after construction is not supported. Per-request behavior goes through `httpx2.Request.extensions` or through `extensions=` kwargs at call sites. ### Seam B: `AsyncClient ↔ ResponseDecoder` @@ -65,29 +67,29 @@ The error-mapping table (what `httpx2` exception maps to which `httpware` except ## 5. Module layout -Current tree (v0.2): +Current tree: ```text src/httpware/ -├── __init__.py # public exports +├── __init__.py # public exports (both worlds at top level) ├── py.typed -├── client.py # AsyncClient -├── errors.py # status-keyed exception tree + NetworkError + RetryBudgetExhaustedError + BulkheadFullError +├── client.py # Client (sync) + AsyncClient (async) +├── errors.py # status-keyed exception tree (shared) ├── middleware/ -│ ├── __init__.py # AsyncMiddleware protocol, AsyncNext type, @async_before_request/@async_after_response/@async_on_error -│ ├── chain.py # compose_async(middleware, terminal) -> AsyncNext +│ ├── __init__.py # Middleware + AsyncMiddleware, Next + AsyncNext, decorators +│ ├── chain.py # compose + compose_async │ └── resilience/ -│ ├── __init__.py # re-exports AsyncBulkhead, AsyncRetry, RetryBudget -│ ├── bulkhead.py # AsyncBulkhead middleware (concurrency limiter) -│ ├── budget.py # RetryBudget (Finagle-style token bucket) -│ ├── retry.py # AsyncRetry middleware -│ └── _backoff.py # full-jitter exponential backoff helper (private) -├── decoders/ -│ ├── __init__.py # ResponseDecoder protocol -│ ├── pydantic.py # PydanticDecoder (extra: pydantic) -│ └── msgspec.py # MsgspecDecoder (extra: msgspec) +│ ├── __init__.py # re-exports both worlds + RetryBudget +│ ├── bulkhead.py # Bulkhead + AsyncBulkhead +│ ├── budget.py # RetryBudget (thread-safe; shared) +│ ├── retry.py # Retry + AsyncRetry +│ └── _backoff.py # full-jitter helper (shared) +├── decoders/ # shared (ResponseDecoder + adapters) └── _internal/ - └── import_checker.py # is_msgspec_installed, is_pydantic_installed + ├── exception_mapping.py # map_httpx2_exception (shared) + ├── import_checker.py # is_*_installed flags + ├── observability.py # _emit_event + └── status.py # _raise_on_status_error, _is_streaming_body_*, STREAMING_BODY_MARKER ``` **Deleted relative to 0.1.0:** `request.py`, `response.py`, `config.py`, `transports/` (Transport protocol + Httpx2Transport), `_internal/auth.py`, `_internal/chain.py`. The `RecordedTransport` testing helper is gone; tests inject `httpx2.MockTransport` via `httpx2_client=` instead. diff --git a/planning/releases/0.8.0.md b/planning/releases/0.8.0.md new file mode 100644 index 0000000..216c9d7 --- /dev/null +++ b/planning/releases/0.8.0.md @@ -0,0 +1,61 @@ +# httpware 0.8.0 — Sync Client + httpx2-aligned naming + +**Breaking release.** Renames the async middleware surface to use the `Async*`/`async_*` prefix (matching httpx2's convention), drops `Retry(attempt_timeout=...)`, and adds a fully-featured sync `Client`. + +If you have existing async code, migration is one mechanical pass through your imports — see "Breaking changes" below. + +## What's new + +- **Sync `Client`.** Full parity with `AsyncClient`: typed response decoding, middleware chain, `Retry` + `Bulkhead`, `stream()` context manager, lifecycle (`with` + `close()`), and `httpx2.Client` injection. Designed for CLI tools, scripts, Django sync views, Jupyter, and threaded service workers. +- **Sync `Middleware` + `Next` + decorators.** `from httpware import Middleware, Next, before_request, after_response, on_error`. Same protocol shape as async; bodies are sync. +- **Sync `Retry` and `Bulkhead`.** Same resilience semantics as their async siblings, with `time.sleep` and `threading.Semaphore`. Sync `Retry` shares `RetryBudget` with async — one instance is safe across both worlds. +- **`RetryBudget` is now thread-safe** via an internal `threading.Lock`. Async users see no behavioral difference; the overhead is invisible (~50–100 ns per op). +- **Shared helpers in `_internal/`.** `map_httpx2_exception`, `_raise_on_status_error`, the streaming-body marker, and the body predicates moved to `_internal/exception_mapping.py` and `_internal/status.py`. No public-API change other than the exports listed below. + +## Breaking changes + +### Renames + +| Old name | New name | +|---|---| +| `httpware.Middleware` | `httpware.AsyncMiddleware` | +| `httpware.Next` | `httpware.AsyncNext` | +| `httpware.Retry` | `httpware.AsyncRetry` | +| `httpware.Bulkhead` | `httpware.AsyncBulkhead` | +| `httpware.before_request` | `httpware.async_before_request` | +| `httpware.after_response` | `httpware.async_after_response` | +| `httpware.on_error` | `httpware.async_on_error` | +| `httpware.middleware.chain.compose` | `httpware.middleware.chain.compose_async` | + +### Removals + +- `Retry(attempt_timeout=...)` / `AsyncRetry(attempt_timeout=...)` is **removed**. It used `asyncio.timeout` to bound the whole attempt as a structured cancellation; this had no clean sync equivalent and is mostly covered by `httpx2.Timeout` (per-phase I/O bounds) for typical use cases. Users who genuinely need whole-attempt wall-clock bounds can compose their own timeout middleware. + +### New names that previously meant something else + +The unprefixed `Middleware`, `Next`, `Retry`, `Bulkhead`, `before_request`, `after_response`, `on_error` now refer to **sync** types. Code that imports them and expects async behavior will break at type-check time (or at the first `await` site). + +## Migration + +A one-pass sed/regex covers most of the work: + +```bash +# in your project root: +git ls-files '*.py' | xargs sed -i.bak \ + -e 's/from httpware import \(.*\)\bMiddleware\b/from httpware import \1AsyncMiddleware/g' \ + -e 's/from httpware import \(.*\)\bNext\b/from httpware import \1AsyncNext/g' \ + -e 's/from httpware import \(.*\)\bRetry\b/from httpware import \1AsyncRetry/g' \ + -e 's/from httpware import \(.*\)\bBulkhead\b/from httpware import \1AsyncBulkhead/g' \ + -e 's/from httpware import \(.*\)\bbefore_request\b/from httpware import \1async_before_request/g' \ + -e 's/from httpware import \(.*\)\bafter_response\b/from httpware import \1async_after_response/g' \ + -e 's/from httpware import \(.*\)\bon_error\b/from httpware import \1async_on_error/g' +``` + +Then update the symbol references in the file bodies (your type checker will guide you). If you were using `Retry(attempt_timeout=...)`, remove the kwarg and rely on `httpx2.Timeout` or write a minimal timeout middleware. + +## References + +- Design spec: [`planning/specs/2026-06-07-sync-client-design.md`](../specs/2026-06-07-sync-client-design.md) +- Implementation plan: [`planning/plans/2026-06-07-sync-client-plan.md`](../plans/2026-06-07-sync-client-plan.md) +- Engineering notes: [`planning/engineering.md`](../engineering.md) §3 Seam A, §5 module layout +- Source spec parent (httpx convention): [`planning/archive/specs/2026-06-03-thin-httpx2-wrapper-design.md`](../archive/specs/2026-06-03-thin-httpx2-wrapper-design.md) diff --git a/src/httpware/__init__.py b/src/httpware/__init__.py index 03a7bf8..a2a9dd8 100644 --- a/src/httpware/__init__.py +++ b/src/httpware/__init__.py @@ -1,6 +1,6 @@ -"""httpware — thin async HTTP client wrapper over httpx2.""" +"""httpware — thin async + sync HTTP client wrapper over httpx2.""" -from httpware.client import AsyncClient +from httpware.client import AsyncClient, Client from httpware.decoders import ResponseDecoder from httpware.errors import ( STATUS_TO_EXCEPTION, @@ -26,11 +26,16 @@ from httpware.middleware import ( AsyncMiddleware, AsyncNext, + Middleware, + Next, + after_response, async_after_response, async_before_request, async_on_error, + before_request, + on_error, ) -from httpware.middleware.resilience import AsyncBulkhead, AsyncRetry, RetryBudget +from httpware.middleware.resilience import AsyncBulkhead, AsyncRetry, Bulkhead, Retry, RetryBudget __all__ = [ @@ -41,16 +46,21 @@ "AsyncNext", "AsyncRetry", "BadRequestError", + "Bulkhead", "BulkheadFullError", + "Client", "ClientError", "ClientStatusError", "ConflictError", "ForbiddenError", "InternalServerError", + "Middleware", "NetworkError", + "Next", "NotFoundError", "RateLimitedError", "ResponseDecoder", + "Retry", "RetryBudget", "RetryBudgetExhaustedError", "ServerStatusError", @@ -60,7 +70,10 @@ "TransportError", "UnauthorizedError", "UnprocessableEntityError", + "after_response", "async_after_response", "async_before_request", "async_on_error", + "before_request", + "on_error", ] diff --git a/src/httpware/_internal/exception_mapping.py b/src/httpware/_internal/exception_mapping.py new file mode 100644 index 0000000..035d422 --- /dev/null +++ b/src/httpware/_internal/exception_mapping.py @@ -0,0 +1,28 @@ +"""httpx2 -> httpware exception mapping. + +Pure function used by both Client._terminal and AsyncClient._terminal, +and by both stream() methods. Clause ordering: TimeoutException -> +InvalidURL/CookieConflict -> NetworkError -> HTTPError (subclass before +parent so the right type wins). +""" + +import httpx2 + +from httpware.errors import NetworkError, TimeoutError, TransportError # noqa: A004 + + +def map_httpx2_exception(exc: BaseException) -> NetworkError | TimeoutError | TransportError: + """Map an httpx2 exception to its httpware equivalent. + + Order is significant: more-specific httpx2 types must match before more + general ones. We return the mapped exception; the caller does `raise ... from exc`. + """ + if isinstance(exc, httpx2.TimeoutException): + return TimeoutError(str(exc)) + if isinstance(exc, (httpx2.InvalidURL, httpx2.CookieConflict)): + return TransportError(str(exc)) + if isinstance(exc, httpx2.NetworkError): + return NetworkError(str(exc)) + if isinstance(exc, httpx2.HTTPError): + return TransportError(str(exc)) + return TransportError(str(exc)) # pragma: no cover — defensive default; httpx2.HTTPError is the root diff --git a/src/httpware/_internal/status.py b/src/httpware/_internal/status.py new file mode 100644 index 0000000..f7465f0 --- /dev/null +++ b/src/httpware/_internal/status.py @@ -0,0 +1,47 @@ +"""Status-code dispatch + streaming-body detection. + +Shared by Client and AsyncClient. The STREAMING_BODY_MARKER is the public +extensions key both Retry and AsyncRetry read; renaming it is breaking. +""" + +from http import HTTPStatus + +import httpx2 + +from httpware.errors import STATUS_TO_EXCEPTION, ClientStatusError, ServerStatusError + + +STREAMING_BODY_MARKER = "httpware.streaming_body" +"""Set on ``httpx2.Request.extensions`` when content/data/files is a non-replayable +iterable (async-iterable for AsyncClient, sync iterator/generator for Client). +Retry / AsyncRetry read this marker to refuse retrying a streamed-body request +(the consumed iterator cannot replay across attempts).""" + + +def _raise_on_status_error(response: httpx2.Response) -> None: + """Raise the appropriate StatusError subclass for a 4xx/5xx response. No-op for 2xx/3xx.""" + status = response.status_code + if HTTPStatus.BAD_REQUEST <= status < 600: # noqa: PLR2004 — 600 is the synthetic upper bound for 5xx + exc_class = STATUS_TO_EXCEPTION.get( + status, + ClientStatusError if status < HTTPStatus.INTERNAL_SERVER_ERROR else ServerStatusError, + ) + raise exc_class(response) + + +def _is_streaming_body_async(value: object) -> bool: + """Return True if value is an async-iterable that cannot be safely replayed for retry.""" + if value is None: + return False + if isinstance(value, (bytes, bytearray, memoryview, str, dict)): + return False + return hasattr(value, "__aiter__") + + +def _is_streaming_body_sync(value: object) -> bool: + """Return True if value is a sync iterable body that cannot be safely replayed for retry.""" + if value is None: + return False + if isinstance(value, (bytes, bytearray, memoryview, str, dict, list, tuple)): + return False + return hasattr(value, "__iter__") diff --git a/src/httpware/client.py b/src/httpware/client.py index d5b8704..3527564 100644 --- a/src/httpware/client.py +++ b/src/httpware/client.py @@ -1,24 +1,24 @@ -"""AsyncClient — the thin httpx2 wrapper.""" +"""Client + AsyncClient — thin httpx2 wrappers with typed decoding and middleware.""" import contextlib import typing -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Iterator, Sequence from http import HTTPStatus import httpx2 from httpware._internal import import_checker -from httpware.decoders import ResponseDecoder -from httpware.errors import ( - STATUS_TO_EXCEPTION, - ClientStatusError, - NetworkError, - ServerStatusError, - TimeoutError, # noqa: A004 - TransportError, +from httpware._internal.exception_mapping import map_httpx2_exception +from httpware._internal.status import ( + STREAMING_BODY_MARKER, + _is_streaming_body_async, + _is_streaming_body_sync, + _raise_on_status_error, ) -from httpware.middleware import AsyncMiddleware, AsyncNext -from httpware.middleware.chain import compose_async +from httpware.decoders import ResponseDecoder +from httpware.errors import TransportError +from httpware.middleware import AsyncMiddleware, AsyncNext, Middleware, Next +from httpware.middleware.chain import compose, compose_async T = typing.TypeVar("T") @@ -26,12 +26,12 @@ _FORWARDED_KWARG_NAMES = ("base_url", "headers", "params", "cookies", "timeout", "limits", "auth") _HTTPX2_CLIENT_CONFLICT_MESSAGE = ( - "AsyncClient(httpx2_client=...) cannot be combined with any of " - f"{_FORWARDED_KWARG_NAMES}; configure the httpx2.AsyncClient you pass instead." + "httpx2_client=... cannot be combined with any of " + f"{_FORWARDED_KWARG_NAMES}; configure the httpx2 client you pass instead." ) _DEFAULT_DECODER_MISSING_MESSAGE = ( - "AsyncClient(decoder=None) defaults to PydanticDecoder, which requires the " + "decoder=None defaults to PydanticDecoder, which requires the " "'pydantic' extra. Either install it (`pip install httpware[pydantic]`) or " "pass an explicit decoder=..." ) @@ -50,41 +50,21 @@ async def _httpx2_exception_mapper() -> AsyncIterator[None]: """Map httpx2 exceptions to httpware exceptions. Shared by AsyncClient._terminal and stream().""" try: yield - except httpx2.TimeoutException as exc: - raise TimeoutError(str(exc)) from exc - except (httpx2.InvalidURL, httpx2.CookieConflict) as exc: - raise TransportError(str(exc)) from exc - except httpx2.NetworkError as exc: - raise NetworkError(str(exc)) from exc except httpx2.HTTPError as exc: - raise TransportError(str(exc)) from exc - - -def _raise_on_status_error(response: httpx2.Response) -> None: - """Raise the appropriate StatusError subclass for a 4xx/5xx response. No-op for 2xx/3xx.""" - status = response.status_code - if HTTPStatus.BAD_REQUEST <= status < 600: # noqa: PLR2004 — 600 is the synthetic upper bound for 5xx - exc_class = STATUS_TO_EXCEPTION.get( - status, - ClientStatusError if status < HTTPStatus.INTERNAL_SERVER_ERROR else ServerStatusError, - ) - raise exc_class(response) - - -STREAMING_BODY_MARKER = "httpware.streaming_body" -"""Key set on ``httpx2.Request.extensions`` by ``_request_with_body`` when content/data/files is an async-iterable. - -``AsyncRetry.__call__`` reads this marker to refuse retrying a streamed-body request -(the consumed iterator cannot replay across attempts).""" + raise map_httpx2_exception(exc) from exc + except (httpx2.InvalidURL, httpx2.CookieConflict) as exc: + raise map_httpx2_exception(exc) from exc -def _is_streaming_body(value: typing.Any) -> bool: - """Return True if value is an async-iterable that cannot be safely replayed for retry.""" - if value is None: - return False - if isinstance(value, (bytes, bytearray, memoryview, str, dict)): - return False - return hasattr(value, "__aiter__") +@contextlib.contextmanager +def _httpx2_exception_mapper_sync() -> Iterator[None]: + """Map httpx2 exceptions to httpware exceptions. Sync sibling of _httpx2_exception_mapper.""" + try: + yield + except httpx2.HTTPError as exc: + raise map_httpx2_exception(exc) from exc + except (httpx2.InvalidURL, httpx2.CookieConflict) as exc: + raise map_httpx2_exception(exc) from exc class AsyncClient: @@ -216,7 +196,7 @@ async def _request_with_body( # noqa: PLR0913, C901 — mirrors httpx2 per-meth if files is not None: kwargs["files"] = files request = self._httpx2_client.build_request(method, url, **kwargs) - if _is_streaming_body(content) or _is_streaming_body(data) or _is_streaming_body(files): + if _is_streaming_body_async(content) or _is_streaming_body_async(data) or _is_streaming_body_async(files): request.extensions[STREAMING_BODY_MARKER] = True return await self.send(request, response_model=response_model) @@ -778,3 +758,694 @@ async def aclose(self) -> None: """ if self._owns_client and not self._httpx2_client.is_closed: await self._httpx2_client.aclose() + + +class Client: + """Sync HTTP client: thin wrapper around httpx2 with typed decoding and middleware.""" + + _httpx2_client: httpx2.Client + _owns_client: bool + _decoder: ResponseDecoder + _user_middleware: tuple[Middleware, ...] + _dispatch: Next + + def __init__( # noqa: PLR0913 — wide constructor is the cost of a single-call API + self, + *, + base_url: str = "", + headers: dict[str, str] | None = None, + params: dict[str, str] | None = None, + cookies: dict[str, str] | None = None, + timeout: httpx2.Timeout | float | None = None, + limits: httpx2.Limits | None = None, + auth: httpx2.Auth | None = None, + httpx2_client: httpx2.Client | None = None, + decoder: ResponseDecoder | None = None, + middleware: Sequence[Middleware] = (), + ) -> None: + if httpx2_client is not None: + forwarded = { + "base_url": base_url, + "headers": headers, + "params": params, + "cookies": cookies, + "timeout": timeout, + "limits": limits, + "auth": auth, + } + if any(value not in (None, "") for value in forwarded.values()): + raise TypeError(_HTTPX2_CLIENT_CONFLICT_MESSAGE) + self._httpx2_client = httpx2_client + self._owns_client = False + else: + kwargs: dict[str, typing.Any] = {} + if base_url: + kwargs["base_url"] = base_url + if headers is not None: + kwargs["headers"] = headers + if params is not None: + kwargs["params"] = params + if cookies is not None: + kwargs["cookies"] = cookies + if timeout is not None: + kwargs["timeout"] = timeout + if limits is not None: + kwargs["limits"] = limits + if auth is not None: + kwargs["auth"] = auth + self._httpx2_client = httpx2.Client(**kwargs) + self._owns_client = True + + self._decoder = decoder if decoder is not None else _default_pydantic_decoder() + self._user_middleware = tuple(middleware) + self._dispatch = compose(self._user_middleware, self._terminal) + + def _terminal(self, request: httpx2.Request) -> httpx2.Response: + try: + with _httpx2_exception_mapper_sync(): + response = self._httpx2_client.send(request) + except RuntimeError as exc: + if "closed" in str(exc): + raise TransportError(str(exc)) from exc + raise + _raise_on_status_error(response) + return response + + def __enter__(self) -> typing.Self: + """Enter the sync context manager; return self.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object, + ) -> None: + """Exit the sync context manager; close the underlying client only if owned.""" + if self._owns_client and not self._httpx2_client.is_closed: + self._httpx2_client.close() + + def close(self) -> None: + """Close the underlying httpx2 client if we own it. + + Idempotent — safe to call after ``__exit__`` or another ``close()`` call. + Use this when the client is not managed by ``with`` (e.g., wired into a + DI container's lifecycle). Mirrors AsyncClient.aclose(). + """ + if self._owns_client and not self._httpx2_client.is_closed: + self._httpx2_client.close() + + @typing.overload + def send(self, request: httpx2.Request, *, response_model: None = None) -> httpx2.Response: ... + + @typing.overload + def send(self, request: httpx2.Request, *, response_model: type[T]) -> T: ... + + def send( + self, + request: httpx2.Request, + *, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send `request` through the middleware chain. Decode if `response_model` is set.""" + response = self._dispatch(request) + if response_model is None: + return response + return self._decoder.decode(response.content, response_model) + + def build_request(self, method: str, url: str, **kwargs: typing.Any) -> httpx2.Request: + """Delegate request construction to the wrapped httpx2.Client.""" + return self._httpx2_client.build_request(method, url, **kwargs) + + def _request_with_body( # noqa: PLR0913, C901 — mirrors httpx2 per-method signatures; kwargs-forwarding complexity is structural + self, + method: str, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + kwargs: dict[str, typing.Any] = {} + if params is not None: + kwargs["params"] = params + if headers is not None: + kwargs["headers"] = headers + if cookies is not None: + kwargs["cookies"] = cookies + if timeout is not httpx2.USE_CLIENT_DEFAULT: + kwargs["timeout"] = timeout + if extensions is not None: + kwargs["extensions"] = extensions + if json is not None: + kwargs["json"] = json + if content is not None: + kwargs["content"] = content + if data is not None: + kwargs["data"] = data + if files is not None: + kwargs["files"] = files + request = self._httpx2_client.build_request(method, url, **kwargs) + if _is_streaming_body_sync(content) or _is_streaming_body_sync(data) or _is_streaming_body_sync(files): + request.extensions[STREAMING_BODY_MARKER] = True + return self.send(request, response_model=response_model) + + @typing.overload + def get( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def get( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def get( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a GET request.""" + return self._request_with_body( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + response_model=response_model, + ) + + @typing.overload + def post( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def post( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def post( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a POST request.""" + return self._request_with_body( + "POST", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + json=json, + content=content, + data=data, + files=files, + response_model=response_model, + ) + + @typing.overload + def put( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def put( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def put( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a PUT request.""" + return self._request_with_body( + "PUT", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + json=json, + content=content, + data=data, + files=files, + response_model=response_model, + ) + + @typing.overload + def patch( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def patch( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def patch( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a PATCH request.""" + return self._request_with_body( + "PATCH", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + json=json, + content=content, + data=data, + files=files, + response_model=response_model, + ) + + @typing.overload + def delete( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def delete( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def delete( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a DELETE request.""" + return self._request_with_body( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + json=json, + content=content, + data=data, + files=files, + response_model=response_model, + ) + + @typing.overload + def head( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def head( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def head( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a HEAD request.""" + return self._request_with_body( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + response_model=response_model, + ) + + @typing.overload + def options( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def options( + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def options( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send an OPTIONS request.""" + return self._request_with_body( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + response_model=response_model, + ) + + @typing.overload + def request( + self, + method: str, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: None = None, + ) -> httpx2.Response: ... + + @typing.overload + def request( + self, + method: str, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T], + ) -> T: ... + + def request( # noqa: PLR0913 — mirrors httpx2 per-method signatures + self, + method: str, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + response_model: type[T] | None = None, + ) -> httpx2.Response | T: + """Send a request with an arbitrary HTTP method.""" + return self._request_with_body( + method, + url, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + json=json, + content=content, + data=data, + files=files, + response_model=response_model, + ) + + @contextlib.contextmanager + def stream( # noqa: PLR0913, C901 — mirrors httpx2 per-method signatures; kwargs-forwarding complexity is structural + self, + method: str, + url: str, + *, + params: typing.Any | None = None, + headers: typing.Any | None = None, + cookies: typing.Any | None = None, + timeout: typing.Any = httpx2.USE_CLIENT_DEFAULT, + extensions: typing.Any | None = None, + json: typing.Any | None = None, + content: typing.Any | None = None, + data: typing.Any | None = None, + files: typing.Any | None = None, + ) -> Iterator[httpx2.Response]: + """Stream an HTTP response. Bypasses the middleware chain. + + Yields an httpx2.Response; consume the body via response.iter_bytes(), + response.iter_text(), response.iter_lines(), or response.iter_raw(). + The body is NOT pre-read for 2xx/3xx (streaming preserved); the response + is closed when the context exits. + + Bypasses the middleware chain (no Retry, no Bulkhead, no user-installed + middleware) — matches AsyncClient.stream() behavior. + + Auto-raises StatusError subclasses on 4xx/5xx. On error the response + body is pre-read so exc.response.content is accessible. + + Maps httpx2 exceptions raised during the request OR body consumption to + httpware exceptions via _httpx2_exception_mapper_sync. + """ + kwargs: dict[str, typing.Any] = {} + if params is not None: + kwargs["params"] = params + if headers is not None: + kwargs["headers"] = headers + if cookies is not None: + kwargs["cookies"] = cookies + if timeout is not httpx2.USE_CLIENT_DEFAULT: + kwargs["timeout"] = timeout + if extensions is not None: + kwargs["extensions"] = extensions + if json is not None: + kwargs["json"] = json + if content is not None: + kwargs["content"] = content + if data is not None: + kwargs["data"] = data + if files is not None: + kwargs["files"] = files + + with _httpx2_exception_mapper_sync(), self._httpx2_client.stream(method, url, **kwargs) as response: + if HTTPStatus.BAD_REQUEST <= response.status_code < 600: # noqa: PLR2004 — 600 is the synthetic upper bound for 5xx + response.read() # pre-read body so exc.response.content works + _raise_on_status_error(response) + yield response diff --git a/src/httpware/middleware/__init__.py b/src/httpware/middleware/__init__.py index 8eb7eed..4920854 100644 --- a/src/httpware/middleware/__init__.py +++ b/src/httpware/middleware/__init__.py @@ -1,8 +1,8 @@ -"""AsyncMiddleware protocol, AsyncNext type, and phase-shortcut decorators. +"""Middleware + AsyncMiddleware protocols, Next + AsyncNext types, and phase-shortcut decorators. -AsyncMiddleware operates directly on httpx2.Request / httpx2.Response — there is +Middleware operates directly on httpx2.Request / httpx2.Response — there is no httpware-owned request type. The chain is composed at AsyncClient.__init__ -(see client.py) and frozen for the client's lifetime. +or Client.__init__ (see client.py) and frozen for the client's lifetime. """ from collections.abc import Awaitable, Callable @@ -75,3 +75,69 @@ def __repr__(self) -> str: return f"" # ty: ignore[unresolved-attribute] return _OnErrorMiddleware() + + +Next: TypeAlias = Callable[[httpx2.Request], httpx2.Response] + + +@runtime_checkable +class Middleware(Protocol): + """Structural protocol every sync middleware satisfies.""" + + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + """Process `request`; call `next(request)` to forward, or synthesize a Response.""" + ... + + +def before_request(f: Callable[[httpx2.Request], httpx2.Request]) -> Middleware: + """Wrap a sync request transform into a Middleware.""" + + class _BeforeRequestMiddleware: + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + return next(f(request)) + + def __repr__(self) -> str: + return f"" # ty: ignore[unresolved-attribute] + + return _BeforeRequestMiddleware() + + +def after_response( + f: Callable[[httpx2.Request, httpx2.Response], httpx2.Response], +) -> Middleware: + """Wrap a sync response transform into a Middleware.""" + + class _AfterResponseMiddleware: + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + response = next(request) + return f(request, response) + + def __repr__(self) -> str: + return f"" # ty: ignore[unresolved-attribute] + + return _AfterResponseMiddleware() + + +def on_error( + f: Callable[[httpx2.Request, Exception], httpx2.Response | None], +) -> Middleware: + """Wrap a sync error handler into a Middleware. + + Catches Exception (not BaseException), so KeyboardInterrupt / SystemExit propagate. + Handler returning None re-raises; returning a Response replaces the failure. + """ + + class _OnErrorMiddleware: + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + try: + return next(request) + except Exception as exc: + result = f(request, exc) + if result is None: + raise + return result + + def __repr__(self) -> str: + return f"" # ty: ignore[unresolved-attribute] + + return _OnErrorMiddleware() diff --git a/src/httpware/middleware/chain.py b/src/httpware/middleware/chain.py index 281a5c0..6ed5533 100644 --- a/src/httpware/middleware/chain.py +++ b/src/httpware/middleware/chain.py @@ -7,10 +7,11 @@ if typing.TYPE_CHECKING: - from httpware.middleware import AsyncMiddleware + from httpware.middleware import AsyncMiddleware, Middleware _AsyncNext: typing.TypeAlias = Callable[[httpx2.Request], Awaitable[httpx2.Response]] +_Next: typing.TypeAlias = Callable[[httpx2.Request], httpx2.Response] def compose_async(middleware: "Sequence[AsyncMiddleware]", terminal: _AsyncNext) -> _AsyncNext: @@ -29,3 +30,21 @@ async def call(request: httpx2.Request) -> httpx2.Response: return await layer(request, inner) return call + + +def compose(middleware: "Sequence[Middleware]", terminal: _Next) -> _Next: + """Fold sync `middleware` into a single callable around sync `terminal`. + + The first middleware in the sequence is the outermost wrapper. + """ + dispatch: _Next = terminal + for layer in reversed(middleware): + dispatch = _wrap_sync(layer, dispatch) + return dispatch + + +def _wrap_sync(layer: "Middleware", inner: _Next) -> _Next: + def call(request: httpx2.Request) -> httpx2.Response: + return layer(request, inner) + + return call diff --git a/src/httpware/middleware/resilience/__init__.py b/src/httpware/middleware/resilience/__init__.py index 79c0c4f..0c7d5ce 100644 --- a/src/httpware/middleware/resilience/__init__.py +++ b/src/httpware/middleware/resilience/__init__.py @@ -1,8 +1,8 @@ -"""Resilience primitives: AsyncBulkhead, AsyncRetry middleware, and RetryBudget token bucket.""" +"""Resilience primitives: Bulkhead/AsyncBulkhead, Retry/AsyncRetry, RetryBudget.""" from httpware.middleware.resilience.budget import RetryBudget -from httpware.middleware.resilience.bulkhead import AsyncBulkhead -from httpware.middleware.resilience.retry import AsyncRetry +from httpware.middleware.resilience.bulkhead import AsyncBulkhead, Bulkhead +from httpware.middleware.resilience.retry import AsyncRetry, Retry -__all__ = ["AsyncBulkhead", "AsyncRetry", "RetryBudget"] +__all__ = ["AsyncBulkhead", "AsyncRetry", "Bulkhead", "Retry", "RetryBudget"] diff --git a/src/httpware/middleware/resilience/budget.py b/src/httpware/middleware/resilience/budget.py index 16c8be9..0e49ce1 100644 --- a/src/httpware/middleware/resilience/budget.py +++ b/src/httpware/middleware/resilience/budget.py @@ -1,11 +1,14 @@ """Finagle-style token-bucket retry budget. See planning/specs/2026-06-05-retry-and-retry-budget-design.md for the contract. -No locking: asyncio runs coroutines cooperatively on a single thread, so deque -mutations between await points are atomic with respect to other coroutines on -the same event loop. Cross-thread use is out of scope. + +Thread-safe and asyncio-safe: all mutations go through a threading.Lock. +A single RetryBudget instance is safe to share across threads, across +coroutines on one event loop, and across (sync Client, AsyncClient) pairs +in the same process. """ +import threading import time from collections import deque from collections.abc import Callable @@ -31,10 +34,12 @@ def __init__( self._min_retries_per_sec = min_retries_per_sec self._percent_can_retry = percent_can_retry self._now = _now + self._lock = threading.Lock() self._deposits: deque[float] = deque() self._withdrawn: deque[float] = deque() def _purge(self, now: float) -> None: + # Caller must hold self._lock. # Strict `< cutoff` keeps entries at exactly `now - ttl`: window is [now - ttl, now]. cutoff = now - self._ttl while self._deposits and self._deposits[0] < cutoff: @@ -45,8 +50,9 @@ def _purge(self, now: float) -> None: def deposit(self) -> None: """Record a request (success or failure attempt). Adds one token.""" now = self._now() - self._purge(now) - self._deposits.append(now) + with self._lock: + self._purge(now) + self._deposits.append(now) def try_withdraw(self) -> bool: """Atomically attempt to spend one retry token. @@ -55,10 +61,11 @@ def try_withdraw(self) -> bool: Never blocks. """ now = self._now() - self._purge(now) - floor = int(self._min_retries_per_sec * self._ttl) - ceiling = int(len(self._deposits) * self._percent_can_retry) + floor - if len(self._withdrawn) >= ceiling: - return False - self._withdrawn.append(now) - return True + with self._lock: + self._purge(now) + floor = int(self._min_retries_per_sec * self._ttl) + ceiling = int(len(self._deposits) * self._percent_can_retry) + floor + if len(self._withdrawn) >= ceiling: + return False + self._withdrawn.append(now) + return True diff --git a/src/httpware/middleware/resilience/bulkhead.py b/src/httpware/middleware/resilience/bulkhead.py index a33e6dd..d723d62 100644 --- a/src/httpware/middleware/resilience/bulkhead.py +++ b/src/httpware/middleware/resilience/bulkhead.py @@ -13,12 +13,13 @@ import asyncio import logging +import threading import httpx2 from httpware._internal.observability import _emit_event from httpware.errors import BulkheadFullError -from httpware.middleware import AsyncNext +from httpware.middleware import AsyncNext, Next _MAX_CONCURRENT_INVALID = "max_concurrent must be >= 1" @@ -89,3 +90,59 @@ async def __call__(self, request: httpx2.Request, next: AsyncNext) -> httpx2.Res return await next(request) finally: self._sem.release() + + +class Bulkhead: + """Sync concurrency limiter backed by threading.Semaphore. + + Bulkhead is the sharable unit — pass the same instance to multiple + Client(middleware=[shared]) calls to enforce a joint cap across clients. + + Bulkhead is per-world: a single instance cannot be shared between a Client + and an AsyncClient (the underlying semaphore primitives differ). To cap + a sync+async mixed workload, use a Bulkhead and an AsyncBulkhead with + matching max_concurrent. + """ + + def __init__( + self, + *, + max_concurrent: int, + acquire_timeout: float | None = 1.0, + ) -> None: + if max_concurrent < 1: + raise ValueError(_MAX_CONCURRENT_INVALID) + if acquire_timeout is not None and acquire_timeout < 0: + raise ValueError(_ACQUIRE_TIMEOUT_INVALID) + self._max_concurrent = max_concurrent + self._acquire_timeout = acquire_timeout + self._sem = threading.Semaphore(max_concurrent) + + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + """Acquire a slot (bounded by acquire_timeout), invoke next, release.""" + # threading.Semaphore.acquire(timeout=None) blocks until acquired; + # acquire(timeout=0) returns immediately (True if a slot was available, + # False otherwise). Both match AsyncBulkhead's contract. + acquired = self._sem.acquire(timeout=self._acquire_timeout) + if not acquired: + _emit_event( + _LOGGER, + "bulkhead.rejected", + level=logging.WARNING, + message="bulkhead rejected request — acquire_timeout exceeded", + attributes={ + "max_concurrent": self._max_concurrent, + "acquire_timeout": self._acquire_timeout, + "method": request.method, + "url": str(request.url), + }, + ) + raise BulkheadFullError( + max_concurrent=self._max_concurrent, + acquire_timeout=self._acquire_timeout, + ) + + try: + return next(request) + finally: + self._sem.release() diff --git a/src/httpware/middleware/resilience/retry.py b/src/httpware/middleware/resilience/retry.py index 344ff0b..8387308 100644 --- a/src/httpware/middleware/resilience/retry.py +++ b/src/httpware/middleware/resilience/retry.py @@ -11,15 +11,16 @@ import datetime import email.utils import logging +import time from collections.abc import Awaitable, Callable from http import HTTPStatus import httpx2 from httpware._internal.observability import _emit_event -from httpware.client import STREAMING_BODY_MARKER +from httpware._internal.status import STREAMING_BODY_MARKER from httpware.errors import NetworkError, RetryBudgetExhaustedError, StatusError, TimeoutError # noqa: A004 -from httpware.middleware import AsyncNext +from httpware.middleware import AsyncNext, Next from httpware.middleware.resilience._backoff import full_jitter_delay from httpware.middleware.resilience.budget import RetryBudget @@ -196,3 +197,134 @@ async def __call__(self, request: httpx2.Request, next: AsyncNext) -> httpx2.Res msg = "unreachable" # pragma: no cover raise AssertionError(msg) # pragma: no cover + + +class Retry: + """Sync retry middleware. Mirror of AsyncRetry; uses time.sleep instead of asyncio.sleep.""" + + def __init__( # noqa: PLR0913 — retry policy has many orthogonal knobs; a dataclass would be worse + self, + *, + max_attempts: int = 3, + base_delay: float = 0.1, + max_delay: float = 5.0, + retry_status_codes: frozenset[int] = DEFAULT_RETRY_STATUS_CODES, + retry_methods: frozenset[str] = DEFAULT_IDEMPOTENT_METHODS, + respect_retry_after: bool = True, + budget: RetryBudget | None = None, + _sleep: Callable[[float], None] = time.sleep, + ) -> None: + if max_attempts < 1: + raise ValueError(_MAX_ATTEMPTS_INVALID) + self.max_attempts = max_attempts + self.base_delay = base_delay + self.max_delay = max_delay + self.retry_status_codes = retry_status_codes + self.retry_methods = retry_methods + self.respect_retry_after = respect_retry_after + self.budget = budget if budget is not None else RetryBudget() + self._sleep = _sleep + + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002, C901, PLR0912, PLR0915 — same complexity rationale as AsyncRetry + """Process a request through the sync retry loop. See AsyncRetry for full contract.""" + method_eligible = request.method.upper() in self.retry_methods + last_exc: BaseException | None = None + last_response: httpx2.Response | None = None + + for attempt in range(self.max_attempts): + is_last = attempt + 1 >= self.max_attempts + self.budget.deposit() + try: + return next(request) + except StatusError as exc: + retryable_status = exc.response.status_code in self.retry_status_codes + if not method_eligible or not retryable_status: + if retryable_status and request.extensions.get(STREAMING_BODY_MARKER): + exc.add_note(_STREAMING_BODY_REFUSAL_NOTE) + raise + last_exc = exc + last_response = exc.response + except (NetworkError, TimeoutError) as exc: + if not method_eligible: + if request.extensions.get(STREAMING_BODY_MARKER): + exc.add_note(_STREAMING_BODY_REFUSAL_NOTE) + raise + last_exc = exc + last_response = None + + # ---- retryable failure path + if request.extensions.get(STREAMING_BODY_MARKER): + if last_exc is None: # pragma: no cover — invariant from except branch + msg = "Retry: streaming-body refusal reached with no last_exc" + raise AssertionError(msg) + last_exc.add_note(_STREAMING_BODY_REFUSAL_NOTE) + _emit_event( + _LOGGER, + "retry.streaming_refused", + level=logging.WARNING, + message="retry refused — request body is a stream that cannot replay", + attributes={ + "method": request.method, + "url": str(request.url), + "last_exception_type": type(last_exc).__qualname__, + }, + ) + raise last_exc + + if is_last: + if last_exc is None: # pragma: no cover — structural invariant from except branch + msg = "Retry: last_exc unset on final attempt — unreachable" + raise AssertionError(msg) + last_exc.add_note(f"httpware: gave up after {attempt + 1} attempts") + _emit_event( + _LOGGER, + "retry.giving_up", + level=logging.WARNING, + message=f"retry gave up after {attempt + 1} attempts", + attributes={ + "attempts": attempt + 1, + "method": request.method, + "url": str(request.url), + "last_status": last_response.status_code if last_response is not None else None, + "last_exception_type": type(last_exc).__qualname__, + }, + ) + raise last_exc + + if not self.budget.try_withdraw(): + _emit_event( + _LOGGER, + "retry.budget_refused", + level=logging.WARNING, + message=f"retry budget refused after {attempt + 1} attempts", + attributes={ + "attempts": attempt + 1, + "method": request.method, + "url": str(request.url), + "last_status": last_response.status_code if last_response is not None else None, + }, + ) + raise RetryBudgetExhaustedError( + last_response=last_response, + last_exception=last_exc, + attempts=attempt + 1, + ) from last_exc + + retry_after: float | None = None + if self.respect_retry_after and last_response is not None: + header = last_response.headers.get("Retry-After") + if header is not None: + retry_after = _parse_retry_after(header) + + if retry_after is not None: + delay = min(retry_after, self.max_delay) + else: + delay = full_jitter_delay( + attempt, + base_delay=self.base_delay, + max_delay=self.max_delay, + ) + self._sleep(delay) + + msg = "unreachable" # pragma: no cover + raise AssertionError(msg) # pragma: no cover diff --git a/tests/test_bulkhead_sync.py b/tests/test_bulkhead_sync.py new file mode 100644 index 0000000..c427e67 --- /dev/null +++ b/tests/test_bulkhead_sync.py @@ -0,0 +1,184 @@ +"""Tests for the sync Bulkhead middleware. + +Mirror of test_bulkhead.py for sync semantics. Uses threading for the +concurrency-cap proofs. +""" + +import logging +import threading +import time +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from http import HTTPStatus + +import httpx2 +import pytest + +from httpware import Client +from httpware.errors import BulkheadFullError +from httpware.middleware.resilience.bulkhead import Bulkhead + + +_MAX_CONCURRENT_1 = 1 +_MAX_CONCURRENT_2 = 2 +_ACQUIRE_TIMEOUT_FAST = 0.01 +_ACQUIRE_TIMEOUT_SHORT = 0.05 +_ACQUIRE_TIMEOUT_LONG = 0.5 + + +class _SlowHandler: + """Mock handler that blocks for ``delay`` seconds before returning 200 OK.""" + + def __init__(self, delay: float) -> None: + self.delay = delay + self.lock = threading.Lock() + self.in_flight = 0 + self.max_in_flight = 0 + self.calls = 0 + + def __call__(self, request: httpx2.Request) -> httpx2.Response: + with self.lock: + self.calls += 1 + self.in_flight += 1 + self.max_in_flight = max(self.max_in_flight, self.in_flight) + try: + time.sleep(self.delay) + return httpx2.Response(HTTPStatus.OK, request=request) + finally: + with self.lock: + self.in_flight -= 1 + + +def _client( + handler: Callable[[httpx2.Request], httpx2.Response], + *, + bulkhead: Bulkhead, +) -> Client: + transport = httpx2.MockTransport(handler) + return Client( + httpx2_client=httpx2.Client(transport=transport), + middleware=[bulkhead], + ) + + +def test_max_concurrent_zero_rejected() -> None: + with pytest.raises(ValueError, match="max_concurrent must be >= 1"): + Bulkhead(max_concurrent=0) + + +def test_max_concurrent_negative_rejected() -> None: + with pytest.raises(ValueError, match="max_concurrent must be >= 1"): + Bulkhead(max_concurrent=-1) + + +def test_negative_acquire_timeout_rejected() -> None: + with pytest.raises(ValueError, match="acquire_timeout must be >= 0"): + Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=-0.1) + + +def test_acquire_timeout_zero_accepted() -> None: + bulkhead = Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=0) + assert bulkhead._acquire_timeout == 0 # noqa: SLF001 + + +def test_acquire_timeout_none_accepted() -> None: + bulkhead = Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=None) + assert bulkhead._acquire_timeout is None # noqa: SLF001 + + +def test_succeeds_when_slot_available() -> None: + handler = _SlowHandler(delay=0.0) + client = _client(handler, bulkhead=Bulkhead(max_concurrent=_MAX_CONCURRENT_2)) + response = client.get("https://example.test/x") + assert response.status_code == HTTPStatus.OK + assert handler.calls == 1 + + +def test_serializes_at_capacity() -> None: + """With max_concurrent=1 and 3 concurrent threads, in-flight count never exceeds 1.""" + handler = _SlowHandler(delay=0.02) + client = _client( + handler, + bulkhead=Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=None), + ) + with ThreadPoolExecutor(max_workers=3) as ex: + futures = [ex.submit(client.get, f"https://example.test/{i}") for i in "abc"] + for f in futures: + f.result() + assert handler.calls == 3 # noqa: PLR2004 + assert handler.max_in_flight == 1 + + +def test_acquire_timeout_rejects_when_no_slot_available() -> None: + handler = _SlowHandler(delay=0.1) + client = _client( + handler, + bulkhead=Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=_ACQUIRE_TIMEOUT_FAST), + ) + + holder = threading.Thread(target=client.get, args=("https://example.test/hold",)) + holder.start() + # Give the holder time to acquire the only slot + time.sleep(0.01) + try: + with pytest.raises(BulkheadFullError) as info: + client.get("https://example.test/blocked") + assert info.value.max_concurrent == _MAX_CONCURRENT_1 + assert info.value.acquire_timeout == _ACQUIRE_TIMEOUT_FAST + finally: + holder.join() + + +def test_releases_slot_on_exception() -> None: + """A handler that raises must still cause the slot to be released.""" + calls = [] + + def boom(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + calls.append(1) + msg = "kaboom" + raise RuntimeError(msg) + + transport = httpx2.MockTransport(boom) + bulkhead = Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=_ACQUIRE_TIMEOUT_SHORT) + client = Client(httpx2_client=httpx2.Client(transport=transport), middleware=[bulkhead]) + + with pytest.raises(RuntimeError, match="kaboom"): + client.get("https://example.test/x") + # Second call must succeed (slot was released) — handler still raises, but bulkhead doesn't reject + with pytest.raises(RuntimeError, match="kaboom"): + client.get("https://example.test/y") + assert len(calls) == 2 # noqa: PLR2004 — both attempts reached the handler + + +def test_emits_rejected_event(caplog: pytest.LogCaptureFixture) -> None: + handler = _SlowHandler(delay=0.1) + client = _client( + handler, + bulkhead=Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=_ACQUIRE_TIMEOUT_FAST), + ) + holder = threading.Thread(target=client.get, args=("https://example.test/hold",)) + holder.start() + time.sleep(0.01) + try: + with caplog.at_level(logging.WARNING, logger="httpware.bulkhead"), pytest.raises(BulkheadFullError): + client.get("https://example.test/blocked") + assert any("bulkhead rejected" in r.getMessage() for r in caplog.records) + finally: + holder.join() + + +def test_acquire_timeout_none_blocks_until_slot_available() -> None: + """With acquire_timeout=None, the call should block until a slot frees up.""" + handler = _SlowHandler(delay=0.05) + client = _client( + handler, + bulkhead=Bulkhead(max_concurrent=_MAX_CONCURRENT_1, acquire_timeout=None), + ) + holder = threading.Thread(target=client.get, args=("https://example.test/hold",)) + holder.start() + time.sleep(0.005) # ensure holder has the slot + # This should not raise; it should wait for the slot. + response = client.get("https://example.test/wait") + holder.join() + assert response.status_code == HTTPStatus.OK + assert handler.calls == 2 # noqa: PLR2004 diff --git a/tests/test_client_stream_sync.py b/tests/test_client_stream_sync.py new file mode 100644 index 0000000..85319a5 --- /dev/null +++ b/tests/test_client_stream_sync.py @@ -0,0 +1,306 @@ +"""Tests for Client.stream() — sync sibling of test_client_stream.py.""" + +import typing +from http import HTTPStatus + +import httpx2 +import pytest + +from httpware import ( + Client, + ClientStatusError, + NetworkError, + NotFoundError, + ServerStatusError, + ServiceUnavailableError, + TransportError, +) +from httpware import ( + TimeoutError as HttpwareTimeoutError, +) +from httpware.middleware import Middleware, Next + + +_UNKNOWN_4XX = 418 # I'm a teapot +_UNKNOWN_5XX = 599 +_REDIRECT_3XX = 301 +_NOT_FOUND = 404 +_SERVICE_UNAVAILABLE = 503 + + +def _client(handler: typing.Callable[[httpx2.Request], httpx2.Response]) -> Client: + transport = httpx2.MockTransport(handler) + return Client(httpx2_client=httpx2.Client(transport=transport)) + + +def test_streams_response_body_successfully() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, content=b"chunk1chunk2chunk3") + + client = _client(handler) + with client.stream("GET", "https://example.test/x") as response: + assert response.status_code == HTTPStatus.OK + chunks = list(response.iter_bytes()) + assert b"".join(chunks) == b"chunk1chunk2chunk3" + + +def test_auto_raises_on_4xx_with_body_preread() -> None: + body = b'{"error": "not found"}' + + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(_NOT_FOUND, request=request, content=body) + + client = _client(handler) + with pytest.raises(NotFoundError) as info, client.stream("GET", "https://example.test/missing"): + pytest.fail("should have raised before reaching block body") # pragma: no cover + assert info.value.response.status_code == _NOT_FOUND + assert info.value.response.content == body # body was pre-read; accessible + + +def test_auto_raises_on_5xx_with_body_preread() -> None: + body = b"degraded" + + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(_SERVICE_UNAVAILABLE, request=request, content=body) + + client = _client(handler) + with pytest.raises(ServiceUnavailableError) as info, client.stream("GET", "https://example.test/x"): + pytest.fail("unreachable") # pragma: no cover + assert info.value.response.content == body + + +def test_auto_raises_unknown_4xx_falls_back_to_client_status_error() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(_UNKNOWN_4XX, request=request) + + client = _client(handler) + with pytest.raises(ClientStatusError) as info, client.stream("GET", "https://example.test/x"): + pytest.fail("unreachable") # pragma: no cover + assert type(info.value) is ClientStatusError + assert info.value.response.status_code == _UNKNOWN_4XX + + +def test_auto_raises_unknown_5xx_falls_back_to_server_status_error() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(_UNKNOWN_5XX, request=request) + + client = _client(handler) + with pytest.raises(ServerStatusError) as info, client.stream("GET", "https://example.test/x"): + pytest.fail("unreachable") # pragma: no cover + assert type(info.value) is ServerStatusError + assert info.value.response.status_code == _UNKNOWN_5XX + + +def test_3xx_does_not_raise() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(_REDIRECT_3XX, request=request, headers={"location": "/y"}) + + client = _client(handler) + with client.stream("GET", "https://example.test/x") as response: + assert response.status_code == _REDIRECT_3XX + + +def test_network_error_during_request_maps_to_network_error() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "connect refused" + raise httpx2.ConnectError(msg) + + client = _client(handler) + with pytest.raises(NetworkError, match="connect refused"), client.stream("GET", "https://example.test/x"): + pytest.fail("unreachable") # pragma: no cover + + +def test_network_error_during_body_consumption_maps_to_network_error() -> None: + def streaming_body() -> typing.Iterator[bytes]: + yield b"first chunk" + msg = "read failed mid-stream" + raise httpx2.ReadError(msg) + + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, content=streaming_body()) + + client = _client(handler) + + def consume() -> None: + with client.stream("GET", "https://example.test/x") as response: + for _ in response.iter_bytes(): + pass + + with pytest.raises(NetworkError, match="read failed mid-stream"): + consume() + + +def test_timeout_during_stream_maps_to_httpware_timeout() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "read timeout" + raise httpx2.ReadTimeout(msg) + + client = _client(handler) + with pytest.raises(HttpwareTimeoutError, match="read timeout"), client.stream("GET", "https://example.test/x"): + pytest.fail("unreachable") # pragma: no cover + + +def test_invalid_url_maps_to_bare_transport_error() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "bad url" + raise httpx2.InvalidURL(msg) + + client = _client(handler) + with pytest.raises(TransportError) as info, client.stream("GET", "https://example.test/x"): + pytest.fail("unreachable") # pragma: no cover + assert not isinstance(info.value, NetworkError) + + +def test_user_exception_in_block_propagates_unchanged() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, content=b"data") + + client = _client(handler) + + def trigger() -> None: + with client.stream("GET", "https://example.test/x"): + msg = "user explosion" + raise ValueError(msg) + + with pytest.raises(ValueError, match="user explosion"): + trigger() + + +def test_bypasses_middleware_chain() -> None: + """stream() must not invoke any middleware in the chain.""" + invocations = {"n": 0} + + class _RecordingMiddleware: + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 # pragma: no cover + invocations["n"] += 1 + return next(request) + + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, content=b"x") + + transport = httpx2.MockTransport(handler) + middleware: Middleware = _RecordingMiddleware() + client = Client( + httpx2_client=httpx2.Client(transport=transport), + middleware=[middleware], + ) + + with client.stream("GET", "https://example.test/x") as response: + for _ in response.iter_bytes(): + pass + + assert invocations["n"] == 0 + + +def test_forwards_kwargs_to_httpx2() -> None: + seen: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen.append(request) + return httpx2.Response(HTTPStatus.OK, request=request, content=b"") + + client = _client(handler) + with client.stream( + "GET", + "https://example.test/x", + params={"q": "value"}, + headers={"X-Custom": "1"}, + cookies={"sid": "abc"}, + ) as response: + _ = list(response.iter_bytes()) + + request = seen[0] + assert request.url.params["q"] == "value" + assert request.headers["x-custom"] == "1" + assert request.headers["cookie"] == "sid=abc" + + +def test_stream_with_content_kwarg() -> None: + seen: list[bytes] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen.append(request.content) + return httpx2.Response(HTTPStatus.OK, request=request, content=b"") + + client = _client(handler) + with client.stream("POST", "https://example.test/upload", content=b"payload") as response: + _ = list(response.iter_bytes()) + + assert seen[0] == b"payload" + + +def test_stream_with_sync_iterable_content() -> None: + """stream() bypass means sync-iterable bodies work without the streaming-body marker mechanism.""" + seen_calls: list[int] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen_calls.append(1) + return httpx2.Response(HTTPStatus.OK, request=request, content=b"") + + def streamed_body() -> typing.Iterator[bytes]: + yield b"chunk1" + yield b"chunk2" + + client = _client(handler) + with client.stream("POST", "https://example.test/upload", content=streamed_body()) as response: + _ = list(response.iter_bytes()) + + assert seen_calls == [1] + + +def test_stream_with_timeout_kwarg() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, content=b"ok") + + client = _client(handler) + with client.stream("GET", "https://example.test/x", timeout=5.0) as response: + _ = list(response.iter_bytes()) + assert response.status_code == HTTPStatus.OK + + +def test_stream_with_json_kwarg() -> None: + seen: list[bytes] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen.append(request.content) + return httpx2.Response(HTTPStatus.OK, request=request, content=b"ok") + + client = _client(handler) + with client.stream("POST", "https://example.test/x", json={"key": "value"}) as response: + _ = list(response.iter_bytes()) + assert b"key" in seen[0] + + +def test_stream_with_data_and_extensions_kwargs() -> None: + seen: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen.append(request) + return httpx2.Response(HTTPStatus.OK, request=request, content=b"ok") + + client = _client(handler) + with client.stream( + "POST", + "https://example.test/x", + data={"field": "val"}, + extensions={"timeout": {"connect": 5}}, + ) as response: + _ = list(response.iter_bytes()) + assert seen[0].headers["content-type"].startswith("application/x-www-form-urlencoded") + + +def test_stream_with_files_kwarg() -> None: + seen: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen.append(request) + return httpx2.Response(HTTPStatus.OK, request=request, content=b"ok") + + client = _client(handler) + with client.stream( + "POST", + "https://example.test/x", + files={"upload": ("hello.txt", b"hello", "text/plain")}, + ) as response: + _ = list(response.iter_bytes()) + assert "multipart/form-data" in seen[0].headers["content-type"] diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py new file mode 100644 index 0000000..e270f85 --- /dev/null +++ b/tests/test_client_sync.py @@ -0,0 +1,341 @@ +"""Tests for the sync Client — construction, methods, lifecycle, error mapping.""" + +from http import HTTPStatus + +import httpx2 +import pydantic +import pytest + +from httpware import Client, NotFoundError +from httpware.decoders.pydantic import PydanticDecoder +from httpware.errors import TransportError + + +# ---------- Construction ---------- + + +def test_construction_with_no_args_works() -> None: + client = Client() + assert isinstance(client, Client) + client.close() + + +def test_construction_with_forwarded_kwargs() -> None: + client = Client( + base_url="https://example.test", + headers={"x-shared": "1"}, + params={"trace": "yes"}, + timeout=10.0, + ) + assert isinstance(client, Client) + client.close() + + +def test_construction_with_caller_owned_httpx2_client() -> None: + transport = httpx2.MockTransport(lambda req: httpx2.Response(200, request=req)) + caller = httpx2.Client(transport=transport) + client = Client(httpx2_client=caller) + assert isinstance(client, Client) + caller.close() + + +@pytest.mark.parametrize( + "kwargs", + [ + {"base_url": "https://example.test"}, + {"headers": {"x": "1"}}, + {"params": {"x": "1"}}, + {"cookies": {"x": "1"}}, + {"timeout": 5.0}, + {"limits": httpx2.Limits(max_connections=10)}, + {"auth": httpx2.BasicAuth("u", "p")}, + ], +) +def test_caller_owned_client_with_forwarded_kwargs_is_typeerror(kwargs: dict) -> None: + transport = httpx2.MockTransport(lambda req: httpx2.Response(200, request=req)) + caller = httpx2.Client(transport=transport) + with pytest.raises(TypeError, match="httpx2_client"): + Client(httpx2_client=caller, **kwargs) + caller.close() + + +def test_default_decoder_is_pydantic_decoder() -> None: + client = Client() + assert isinstance(client._decoder, PydanticDecoder) # noqa: SLF001 + client.close() + + +def test_explicit_decoder_is_honored() -> None: + class _Stub: + def decode(self, content: bytes, model: type) -> object: # noqa: ARG002 # pragma: no cover + return None + + client = Client(decoder=_Stub()) + assert isinstance(client._decoder, _Stub) # noqa: SLF001 + client.close() + + +@pytest.mark.parametrize( + "kwargs", + [ + {"cookies": {"session": "abc"}}, + {"limits": httpx2.Limits(max_connections=5)}, + {"auth": httpx2.BasicAuth("user", "pass")}, + ], +) +def test_construction_with_optional_forwarded_kwargs(kwargs: dict) -> None: + """Exercises cookies/limits/auth branches in __init__ when no httpx2_client is supplied.""" + client = Client(**kwargs) + assert isinstance(client, Client) + client.close() + + +def test_explicit_middleware_is_honored() -> None: + class _Tag: + def __call__(self, request, next) -> httpx2.Response: # noqa: A002, ANN001 # pragma: no cover + return next(request) + + client = Client(middleware=(_Tag(),)) + assert len(client._user_middleware) == 1 # noqa: SLF001 + client.close() + + +# ---------- Methods ---------- + + +def _echo_handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response( + HTTPStatus.OK, + request=request, + json={ + "method": request.method, + "url": str(request.url), + "headers": dict(request.headers), + "content": request.content.decode() if request.content else "", + }, + ) + + +def _client_with_handler(handler, **kwargs) -> Client: # noqa: ANN001, ANN003 + transport = httpx2.MockTransport(handler) + return Client(httpx2_client=httpx2.Client(transport=transport, **kwargs)) + + +def test_get_returns_httpx2_response() -> None: + client = _client_with_handler(_echo_handler) + response = client.get("https://example.test/x") + assert isinstance(response, httpx2.Response) + assert response.json()["method"] == "GET" + + +@pytest.mark.parametrize( + "method_name", + ["get", "post", "put", "patch", "delete", "head", "options"], +) +def test_each_per_method_helper_uses_correct_verb(method_name: str) -> None: + client = _client_with_handler(_echo_handler) + method = getattr(client, method_name) + response = method("https://example.test/x") + assert response.json()["method"] == method_name.upper() + + +def test_post_json_body_serialized() -> None: + client = _client_with_handler(_echo_handler) + response = client.post("https://example.test/x", json={"k": "v"}) + payload = response.json() + assert "application/json" in payload["headers"]["content-type"] + assert payload["content"] == '{"k":"v"}' + + +def test_get_with_params_forwards_query() -> None: + captured: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + captured.append(request) + return httpx2.Response(HTTPStatus.OK, request=request) + + client = _client_with_handler(handler) + client.get("https://example.test/x", params={"a": "1"}) + assert "a=1" in str(captured[0].url) + + +def test_get_with_headers_merges() -> None: + captured: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + captured.append(request) + return httpx2.Response(HTTPStatus.OK, request=request) + + client = _client_with_handler(handler) + client.get("https://example.test/x", headers={"x-trace": "abc"}) + assert captured[0].headers["x-trace"] == "abc" + + +def test_get_raises_typed_status_error_on_404() -> None: + client = _client_with_handler(lambda req: httpx2.Response(HTTPStatus.NOT_FOUND, request=req)) + with pytest.raises(NotFoundError): + client.get("https://example.test/missing") + + +def test_request_method_takes_arbitrary_verb() -> None: + client = _client_with_handler(_echo_handler) + response = client.request("PROPFIND", "https://example.test/x") + assert response.json()["method"] == "PROPFIND" + + +def test_base_url_is_applied() -> None: + captured: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + captured.append(request) + return httpx2.Response(HTTPStatus.OK, request=request) + + transport = httpx2.MockTransport(handler) + underlying = httpx2.Client(transport=transport, base_url="https://example.test") + client = Client(httpx2_client=underlying) + client.get("/relative") + assert str(captured[0].url) == "https://example.test/relative" + + +def test_get_with_cookies_forwarded() -> None: + """Exercises the cookies branch in _request_with_body.""" + captured: list[httpx2.Request] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + captured.append(request) + return httpx2.Response(HTTPStatus.OK, request=request) + + client = _client_with_handler(handler) + client.get("https://example.test/x", cookies={"token": "abc"}) + assert "token=abc" in captured[0].headers.get("cookie", "") + + +def test_get_with_explicit_timeout() -> None: + """Exercises the timeout branch in _request_with_body.""" + client = _client_with_handler(_echo_handler) + response = client.get("https://example.test/x", timeout=5.0) + assert response.status_code == HTTPStatus.OK + + +def test_get_with_extensions() -> None: + """Exercises the extensions branch in _request_with_body.""" + client = _client_with_handler(_echo_handler) + response = client.get("https://example.test/x", extensions={"trace": True}) + assert response.status_code == HTTPStatus.OK + + +def test_post_with_content_body() -> None: + """Exercises the content branch in _request_with_body.""" + client = _client_with_handler(_echo_handler) + response = client.post("https://example.test/x", content=b"raw-bytes") + assert response.json()["content"] == "raw-bytes" + + +def test_post_with_data_body() -> None: + """Exercises the data branch in _request_with_body.""" + client = _client_with_handler(_echo_handler) + response = client.post("https://example.test/x", data={"field": "value"}) + assert response.status_code == HTTPStatus.OK + + +def test_post_with_files_body() -> None: + """Exercises the files branch in _request_with_body.""" + client = _client_with_handler(_echo_handler) + response = client.post("https://example.test/x", files={"upload": b"file-content"}) + assert response.status_code == HTTPStatus.OK + + +def test_runtime_error_without_closed_reraises() -> None: + """Exercises the RuntimeError re-raise branch in _terminal (error not containing 'closed').""" + + def boom(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "unexpected internal failure" + raise RuntimeError(msg) + + client = _client_with_handler(boom) + with pytest.raises(RuntimeError, match="unexpected internal failure"): + client.get("https://example.test/x") + + +def test_terminal_runtime_error_with_closed_maps_to_transport_error() -> None: + """A RuntimeError mentioning 'closed' should be remapped to TransportError.""" + transport = httpx2.MockTransport(lambda req: httpx2.Response(HTTPStatus.OK, request=req)) + underlying = httpx2.Client(transport=transport) + client = Client(httpx2_client=underlying) + underlying.close() + with pytest.raises(TransportError): + client.get("https://example.test/x") + + +def test_send_with_response_model_decodes() -> None: + """Exercises the response_model decode path in send().""" + + class _User(pydantic.BaseModel): + id: int + name: str + + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.OK, request=request, json={"id": 1, "name": "alice"}) + + client = _client_with_handler(handler) + user = client.get("https://example.test/u", response_model=_User) + assert isinstance(user, _User) + assert user.id == 1 + assert user.name == "alice" + + +def test_build_request_delegates_to_underlying() -> None: + client = _client_with_handler(_echo_handler) + req = client.build_request("GET", "https://example.test/x") + assert isinstance(req, httpx2.Request) + assert req.method == "GET" + + +# ---------- Lifecycle ---------- + + +def test_exit_closes_owned_httpx2_client() -> None: + client = Client() + with client: + pass + assert client._httpx2_client.is_closed # noqa: SLF001 + + +def test_exit_does_not_close_borrowed_httpx2_client() -> None: + transport = httpx2.MockTransport(lambda req: httpx2.Response(HTTPStatus.OK, request=req)) + underlying = httpx2.Client(transport=transport) + client = Client(httpx2_client=underlying) + with client: + pass + assert not underlying.is_closed + underlying.close() + + +def test_exit_is_idempotent_for_owned_client() -> None: + client = Client() + with client: + pass + # Second use should not raise + client.__exit__(None, None, None) + + +def test_close_closes_owned_httpx2_client() -> None: + client = Client() + client.close() + assert client._httpx2_client.is_closed # noqa: SLF001 + + +def test_close_is_idempotent_for_owned_client() -> None: + client = Client() + client.close() + client.close() + assert client._httpx2_client.is_closed # noqa: SLF001 + + +def test_close_does_not_close_borrowed_httpx2_client() -> None: + transport = httpx2.MockTransport(lambda req: httpx2.Response(HTTPStatus.OK, request=req)) + underlying = httpx2.Client(transport=transport) + client = Client(httpx2_client=underlying) + client.close() + assert not underlying.is_closed + underlying.close() diff --git a/tests/test_middleware_sync.py b/tests/test_middleware_sync.py new file mode 100644 index 0000000..26cc789 --- /dev/null +++ b/tests/test_middleware_sync.py @@ -0,0 +1,158 @@ +"""Tests for the sync Middleware protocol, Next type, chain composition, and decorators.""" + +from http import HTTPStatus + +import httpx2 +import pytest + +from httpware.middleware import ( + Middleware, + Next, + after_response, + before_request, + on_error, +) +from httpware.middleware.chain import compose + + +def _make_request(url: str = "https://example.test/x") -> httpx2.Request: + return httpx2.Request("GET", url) + + +def _make_response(status: int = HTTPStatus.OK, *, request: httpx2.Request | None = None) -> httpx2.Response: + if request is None: # pragma: no cover + request = _make_request() + return httpx2.Response(status, request=request) + + +def test_middleware_protocol_is_runtime_checkable() -> None: + class _OkMiddleware: + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 # pragma: no cover + return next(request) + + assert isinstance(_OkMiddleware(), Middleware) + + +def test_empty_chain_calls_terminal_directly() -> None: + seen: list[httpx2.Request] = [] + + def terminal(request: httpx2.Request) -> httpx2.Response: + seen.append(request) + return _make_response(200, request=request) + + dispatch = compose((), terminal) + request = _make_request() + response = dispatch(request) + assert response.status_code == HTTPStatus.OK + assert seen == [request] + + +def test_chain_runs_middleware_in_order() -> None: + order: list[str] = [] + + class _M: + def __init__(self, label: str) -> None: + self.label = label + + def __call__(self, request: httpx2.Request, next: Next) -> httpx2.Response: # noqa: A002 + order.append(f"{self.label}.before") + response = next(request) + order.append(f"{self.label}.after") + return response + + def terminal(request: httpx2.Request) -> httpx2.Response: + order.append("terminal") + return _make_response(200, request=request) + + dispatch = compose((_M("a"), _M("b")), terminal) + dispatch(_make_request()) + assert order == ["a.before", "b.before", "terminal", "b.after", "a.after"] + + +def test_before_request_decorator_transforms_request() -> None: + @before_request + def add_header(request: httpx2.Request) -> httpx2.Request: + return httpx2.Request(request.method, request.url, headers={**request.headers, "X-Custom": "1"}) + + captured: list[httpx2.Request] = [] + + def terminal(request: httpx2.Request) -> httpx2.Response: + captured.append(request) + return _make_response(200, request=request) + + dispatch = compose((add_header,), terminal) + dispatch(_make_request()) + assert captured[0].headers["x-custom"] == "1" + + +def test_after_response_decorator_transforms_response() -> None: + @after_response + def upgrade_status(request: httpx2.Request, response: httpx2.Response) -> httpx2.Response: + return httpx2.Response(HTTPStatus.IM_USED, request=request, headers=response.headers, content=response.content) + + def terminal(request: httpx2.Request) -> httpx2.Response: + return _make_response(HTTPStatus.OK, request=request) + + dispatch = compose((upgrade_status,), terminal) + response = dispatch(_make_request()) + assert response.status_code == HTTPStatus.IM_USED + + +def test_on_error_decorator_can_translate_exception() -> None: + @on_error + def swallow(request: httpx2.Request, exc: Exception) -> httpx2.Response | None: + if isinstance(exc, RuntimeError) and str(exc) == "boom": + return _make_response(HTTPStatus.SERVICE_UNAVAILABLE, request=request) + return None # pragma: no cover + + def terminal(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "boom" + raise RuntimeError(msg) + + dispatch = compose((swallow,), terminal) + response = dispatch(_make_request()) + assert response.status_code == HTTPStatus.SERVICE_UNAVAILABLE + + +def test_on_error_returns_none_reraises() -> None: + @on_error + def passthrough( + request: httpx2.Request, # noqa: ARG001 + exc: Exception, # noqa: ARG001 + ) -> httpx2.Response | None: + return None + + def terminal(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "boom" + raise RuntimeError(msg) + + dispatch = compose((passthrough,), terminal) + with pytest.raises(RuntimeError, match="boom"): + dispatch(_make_request()) + + +def test_before_request_repr() -> None: + @before_request + def my_transform(request: httpx2.Request) -> httpx2.Request: + return request # pragma: no cover + + assert "before_request" in repr(my_transform) + assert "my_transform" in repr(my_transform) + + +def test_after_response_repr() -> None: + @after_response + def my_transform(request: httpx2.Request, response: httpx2.Response) -> httpx2.Response: # noqa: ARG001 + return response # pragma: no cover + + assert "after_response" in repr(my_transform) + assert "my_transform" in repr(my_transform) + + +def test_on_error_repr() -> None: + @on_error + def my_handler(request: httpx2.Request, exc: Exception) -> httpx2.Response | None: # noqa: ARG001 + return None # pragma: no cover + + assert "on_error" in repr(my_handler) + assert "my_handler" in repr(my_handler) diff --git a/tests/test_optional_extras_pydantic_missing.py b/tests/test_optional_extras_pydantic_missing.py index ea1a896..4dab302 100644 --- a/tests/test_optional_extras_pydantic_missing.py +++ b/tests/test_optional_extras_pydantic_missing.py @@ -10,7 +10,7 @@ import pytest -from httpware import AsyncClient +from httpware import AsyncClient, Client from httpware.decoders.pydantic import PydanticDecoder @@ -30,6 +30,14 @@ def test_async_client_default_decoder_raises_when_pydantic_missing() -> None: AsyncClient() +def test_sync_client_default_decoder_raises_when_pydantic_missing() -> None: + with ( + patch("httpware._internal.import_checker.is_pydantic_installed", False), + pytest.raises(ImportError, match=r"httpware\[pydantic\]"), + ): + Client() + + def test_async_client_accepts_explicit_decoder_without_pydantic() -> None: """An explicit decoder= escapes the fail-fast even when pydantic is 'missing'.""" diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 70526a3..ba3d2fd 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -33,30 +33,38 @@ def test_expected_exports() -> None: "AsyncMiddleware", "AsyncNext", "AsyncRetry", + "BadRequestError", + "Bulkhead", "BulkheadFullError", - "NetworkError", - "ResponseDecoder", - "RetryBudget", - "RetryBudgetExhaustedError", + "Client", "ClientError", - "TransportError", - "TimeoutError", - "StatusError", "ClientStatusError", - "ServerStatusError", - "BadRequestError", - "UnauthorizedError", + "ConflictError", "ForbiddenError", + "InternalServerError", + "Middleware", + "NetworkError", + "Next", "NotFoundError", - "ConflictError", - "UnprocessableEntityError", "RateLimitedError", - "InternalServerError", + "ResponseDecoder", + "Retry", + "RetryBudget", + "RetryBudgetExhaustedError", + "ServerStatusError", "ServiceUnavailableError", "STATUS_TO_EXCEPTION", - "async_before_request", + "StatusError", + "TimeoutError", + "TransportError", + "UnauthorizedError", + "UnprocessableEntityError", + "after_response", "async_after_response", + "async_before_request", "async_on_error", + "before_request", + "on_error", } missing = expected - set(httpware.__all__) assert not missing, f"expected exports missing from __all__: {missing}" diff --git a/tests/test_retry.py b/tests/test_retry.py index 89f14a6..7d0da10 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -15,7 +15,7 @@ import pytest from httpware import AsyncClient, NotFoundError, ServiceUnavailableError, TransportError -from httpware.client import _is_streaming_body +from httpware._internal.status import _is_streaming_body_async as _is_streaming_body from httpware.errors import NetworkError, RetryBudgetExhaustedError from httpware.middleware.resilience.budget import RetryBudget from httpware.middleware.resilience.retry import ( diff --git a/tests/test_retry_budget_threadsafety.py b/tests/test_retry_budget_threadsafety.py new file mode 100644 index 0000000..1ebc489 --- /dev/null +++ b/tests/test_retry_budget_threadsafety.py @@ -0,0 +1,63 @@ +"""Thread-safety test for RetryBudget. + +Sync Client may share a RetryBudget across a ThreadPoolExecutor. Concurrent +deposit() / try_withdraw() calls must not corrupt the internal deques. We +spawn many threads doing many ops and assert no exception, sane counters. +""" + +import threading + +from httpware.middleware.resilience.budget import RetryBudget + + +_N_THREADS = 16 +_N_OPS_PER_THREAD = 1000 + + +def test_concurrent_deposit_withdraw_does_not_corrupt() -> None: + budget = RetryBudget(ttl=60.0, min_retries_per_sec=1000.0, percent_can_retry=0.5) + errors: list[BaseException] = [] + barrier = threading.Barrier(_N_THREADS) + + def worker() -> None: + try: + barrier.wait() + for _ in range(_N_OPS_PER_THREAD): + budget.deposit() + budget.try_withdraw() + except BaseException as exc: # noqa: BLE001 — collect any failure for the assert # pragma: no cover — defensive harness; passes mean this branch is not taken + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(_N_THREADS)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [] + # Each thread did _N_OPS_PER_THREAD deposits; budget must have accepted them all + # (and possibly some withdrawals — we don't assert withdrawn count; the ceiling + # formula doesn't guarantee how many succeed). + assert len(budget._deposits) <= _N_THREADS * _N_OPS_PER_THREAD # noqa: SLF001 — internal state check + assert len(budget._deposits) > 0 # noqa: SLF001 + + +def test_concurrent_only_deposit_count_matches() -> None: + budget = RetryBudget(ttl=60.0) + barrier = threading.Barrier(_N_THREADS) + + def worker() -> None: + barrier.wait() + for _ in range(_N_OPS_PER_THREAD): + budget.deposit() + + threads = [threading.Thread(target=worker) for _ in range(_N_THREADS)] + for t in threads: + t.start() + for t in threads: + t.join() + + # With no withdraws and no TTL expiry (60s window, sub-second test), every + # deposit lands in the deque. Exact equality proves no deposits were lost + # to a race. + assert len(budget._deposits) == _N_THREADS * _N_OPS_PER_THREAD # noqa: SLF001 diff --git a/tests/test_retry_sync.py b/tests/test_retry_sync.py new file mode 100644 index 0000000..fba802b --- /dev/null +++ b/tests/test_retry_sync.py @@ -0,0 +1,479 @@ +"""Tests for the sync Retry middleware. + +Mirror of test_retry.py. Mocks the transport via httpx2.MockTransport; +injects a recording ``_sleep`` callable so the suite runs instantly. +""" + +import datetime +import email.utils +import logging +import typing +from collections.abc import Callable +from http import HTTPStatus + +import httpx2 +import pytest + +from httpware import Client, NotFoundError, ServiceUnavailableError +from httpware._internal.status import STREAMING_BODY_MARKER, _is_streaming_body_sync +from httpware.errors import NetworkError, RetryBudgetExhaustedError, StatusError, TransportError +from httpware.errors import TimeoutError as HttpwareTimeoutError +from httpware.middleware.resilience.budget import RetryBudget +from httpware.middleware.resilience.retry import ( + DEFAULT_IDEMPOTENT_METHODS, + DEFAULT_RETRY_STATUS_CODES, + Retry, +) + + +class _SleepRecorder: + def __init__(self) -> None: + self.calls: list[float] = [] + + def __call__(self, delay: float) -> None: + self.calls.append(delay) + + +class _ResponseSequence: + def __init__(self, statuses: list[int]) -> None: + self._statuses = list(statuses) + self.calls: int = 0 + + def __call__(self, request: httpx2.Request) -> httpx2.Response: + self.calls += 1 + status = self._statuses.pop(0) if self._statuses else HTTPStatus.OK + return httpx2.Response(status, request=request) + + +class _ResponseSequenceWithHeaders: + """Mock handler that returns (status, headers) tuples in sequence.""" + + def __init__(self, responses: list[tuple[int, dict[str, str]]]) -> None: + self._responses = list(responses) + self.calls = 0 + + def __call__(self, request: httpx2.Request) -> httpx2.Response: + self.calls += 1 + status, headers = self._responses.pop(0) + return httpx2.Response(status, request=request, headers=headers) + + +def _client(handler: Callable[[httpx2.Request], httpx2.Response], *, retry: Retry) -> Client: + transport = httpx2.MockTransport(handler) + return Client( + httpx2_client=httpx2.Client(transport=transport), + middleware=[retry], + ) + + +def _zero_budget() -> RetryBudget: + """Return a budget that always refuses withdrawal (floor=0, percent=0).""" + return RetryBudget(ttl=10.0, min_retries_per_sec=0.0, percent_can_retry=0.0) + + +def test_default_retry_status_codes_match_spec() -> None: + # Module-level constant is shared with AsyncRetry; this test mirrors test_retry.py. + assert frozenset({408, 429, 502, 503, 504}) == DEFAULT_RETRY_STATUS_CODES + + +def test_default_idempotent_methods_match_spec() -> None: + assert frozenset({"GET", "HEAD", "OPTIONS", "PUT", "DELETE"}) == DEFAULT_IDEMPOTENT_METHODS + + +def test_succeeds_first_try_no_sleep() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.OK]) + client = _client(handler, retry=Retry(_sleep=sleeper)) + response = client.get("https://example.test/x") + assert response.status_code == HTTPStatus.OK + assert handler.calls == 1 + assert sleeper.calls == [] + + +def test_retries_503_then_succeeds() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.OK]) + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=0.02)) + response = client.get("https://example.test/x") + assert response.status_code == HTTPStatus.OK + assert handler.calls == 2 # noqa: PLR2004 + assert len(sleeper.calls) == 1 + assert 0.0 <= sleeper.calls[0] <= 0.02 # noqa: PLR2004 + + +def test_gives_up_after_max_attempts_and_reraises_status_error() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE] * 3) + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=0.02, max_attempts=3)) + with pytest.raises(ServiceUnavailableError) as info: + client.get("https://example.test/x") + assert handler.calls == 3 # noqa: PLR2004 + assert len(sleeper.calls) == 2 # noqa: PLR2004 + notes = getattr(info.value, "__notes__", []) + assert any("gave up after 3 attempts" in note for note in notes) + + +def test_does_not_retry_non_retryable_status() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.NOT_FOUND]) + client = _client(handler, retry=Retry(_sleep=sleeper)) + with pytest.raises(NotFoundError): + client.get("https://example.test/missing") + assert handler.calls == 1 + assert sleeper.calls == [] + + +def test_does_not_retry_non_idempotent_method() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE]) + client = _client(handler, retry=Retry(_sleep=sleeper)) + with pytest.raises(ServiceUnavailableError): + client.post("https://example.test/x") # POST is not idempotent by default + assert handler.calls == 1 + + +def test_max_attempts_one_means_no_retries() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE]) + client = _client(handler, retry=Retry(_sleep=sleeper, max_attempts=1)) + with pytest.raises(ServiceUnavailableError): + client.get("https://example.test/x") + assert handler.calls == 1 + assert sleeper.calls == [] + + +def test_max_attempts_zero_rejected() -> None: + with pytest.raises(ValueError, match="max_attempts must be >= 1"): + Retry(max_attempts=0) + + +def test_streamed_body_request_is_refused() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE]) + client = _client(handler, retry=Retry(_sleep=sleeper)) + + # Manually craft a request with the streaming-body marker set. + request = httpx2.Request("GET", "https://example.test/x") + request.extensions[STREAMING_BODY_MARKER] = True + + with pytest.raises(ServiceUnavailableError) as info: + client.send(request) + + notes = getattr(info.value, "__notes__", []) + assert any("stream that cannot replay" in note for note in notes) + assert sleeper.calls == [] # no retry attempted; no backoff + + +def test_streaming_body_refusal_emits_log_event(caplog: pytest.LogCaptureFixture) -> None: + """Cover the streaming-body refusal _emit_event branch in sync Retry.""" + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE]) + client = _client(handler, retry=Retry(_sleep=sleeper)) + + request = httpx2.Request("GET", "https://example.test/x") + request.extensions[STREAMING_BODY_MARKER] = True + + with caplog.at_level(logging.WARNING, logger="httpware.retry"), pytest.raises(ServiceUnavailableError): + client.send(request) + assert any("retry refused" in r.getMessage() for r in caplog.records) + + +def test_streaming_body_refusal_on_non_idempotent_method() -> None: + """Streaming-body marker added to exception even when method isn't idempotent.""" + sleeper = _SleepRecorder() + + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "transient" + raise httpx2.ConnectError(msg) + + client = _client(handler, retry=Retry(_sleep=sleeper)) + request = client.build_request("POST", "https://example.test/x") + request.extensions[STREAMING_BODY_MARKER] = True + with pytest.raises(NetworkError) as info: + client.send(request) + notes = getattr(info.value, "__notes__", []) + assert any("stream that cannot replay" in note for note in notes) + + +def test_streaming_body_refusal_status_error_on_non_idempotent_method() -> None: + """Status-error path: non-idempotent + retryable status + streaming marker -> note added.""" + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE]) + client = _client(handler, retry=Retry(_sleep=sleeper)) + request = client.build_request("POST", "https://example.test/x") + request.extensions[STREAMING_BODY_MARKER] = True + with pytest.raises(ServiceUnavailableError) as info: + client.send(request) + notes = getattr(info.value, "__notes__", []) + assert any("stream that cannot replay" in note for note in notes) + + +def test_client_post_with_sync_generator_content_marks_extensions() -> None: + """Posting with a sync generator body sets the streaming marker on request.extensions.""" + seen_extensions: list[dict[str, object]] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen_extensions.append(dict(request.extensions)) + return httpx2.Response(HTTPStatus.OK, request=request) + + def streamed_body() -> typing.Iterator[bytes]: + yield b"chunk1" + yield b"chunk2" + + transport = httpx2.MockTransport(handler) + client = Client(httpx2_client=httpx2.Client(transport=transport)) + client.post("https://example.test/upload", content=streamed_body()) + + assert len(seen_extensions) == 1 + assert seen_extensions[0].get(STREAMING_BODY_MARKER) is True + + +def test_client_post_with_list_content_does_not_mark_extensions() -> None: + """A list body is replayable; should NOT be marked as streaming.""" + seen_extensions: list[dict[str, object]] = [] + + def handler(request: httpx2.Request) -> httpx2.Response: + seen_extensions.append(dict(request.extensions)) + return httpx2.Response(HTTPStatus.OK, request=request) + + transport = httpx2.MockTransport(handler) + client = Client(httpx2_client=httpx2.Client(transport=transport)) + client.post("https://example.test/upload", content=[b"chunk1", b"chunk2"]) + + assert len(seen_extensions) == 1 + assert STREAMING_BODY_MARKER not in seen_extensions[0] + + +def test_budget_exhausted_raises_with_payload() -> None: + sleeper = _SleepRecorder() + # Tiny budget: 0 floor, 0 retries. + budget = RetryBudget(ttl=10.0, min_retries_per_sec=0.0, percent_can_retry=0.0) + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.OK]) + client = _client(handler, retry=Retry(_sleep=sleeper, budget=budget, max_attempts=3)) + with pytest.raises(RetryBudgetExhaustedError) as info: + client.get("https://example.test/x") + assert info.value.attempts == 1 + assert info.value.last_response is not None + assert info.value.last_response.status_code == HTTPStatus.SERVICE_UNAVAILABLE + + +def test_budget_exhausted_on_network_error_carries_exception_not_response() -> None: + sleeper = _SleepRecorder() + + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "transient" + raise httpx2.ConnectError(msg) + + client = _client( + handler, + retry=Retry(_sleep=sleeper, budget=_zero_budget(), base_delay=0.01, max_delay=0.02), + ) + with pytest.raises(RetryBudgetExhaustedError) as info: + client.get("https://example.test/x") + assert info.value.last_response is None + assert isinstance(info.value.last_exception, NetworkError) + + +def test_retry_after_seconds_honored() -> None: + sleeper = _SleepRecorder() + + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response( + HTTPStatus.TOO_MANY_REQUESTS, + request=request, + headers={"Retry-After": "1"}, + ) + + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=0.5, max_attempts=2)) + with pytest.raises(StatusError): + client.get("https://example.test/x") + # Retry-After=1 clamped to max_delay=0.5 + assert sleeper.calls == [0.5] + + +def test_retry_after_http_date_overrides_backoff() -> None: + sleeper = _SleepRecorder() + future = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=3) + http_date = email.utils.format_datetime(future, usegmt=True) + handler = _ResponseSequenceWithHeaders( + [ + (HTTPStatus.SERVICE_UNAVAILABLE, {"Retry-After": http_date}), + (HTTPStatus.OK, {}), + ] + ) + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=10.0)) + response = client.get("https://example.test/x") + assert response.status_code == HTTPStatus.OK + assert len(sleeper.calls) == 1 + assert 2.0 <= sleeper.calls[0] <= 4.0 # noqa: PLR2004 + + +def test_malformed_retry_after_falls_back_to_backoff() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequenceWithHeaders( + [ + (HTTPStatus.SERVICE_UNAVAILABLE, {"Retry-After": "not-a-number"}), + (HTTPStatus.OK, {}), + ] + ) + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=0.05)) + client.get("https://example.test/x") + assert len(sleeper.calls) == 1 + assert 0.0 <= sleeper.calls[0] <= 0.05 # noqa: PLR2004 + + +def test_respect_retry_after_false_ignores_header() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequenceWithHeaders( + [ + (HTTPStatus.SERVICE_UNAVAILABLE, {"Retry-After": "5"}), + (HTTPStatus.OK, {}), + ] + ) + client = _client( + handler, + retry=Retry(_sleep=sleeper, respect_retry_after=False, base_delay=0.01, max_delay=0.02), + ) + client.get("https://example.test/x") + assert len(sleeper.calls) == 1 + assert 0.0 <= sleeper.calls[0] <= 0.02 # noqa: PLR2004 + + +def test_retries_on_network_error() -> None: + sleeper = _SleepRecorder() + call_count = {"n": 0} + + def handler(request: httpx2.Request) -> httpx2.Response: + call_count["n"] += 1 + if call_count["n"] < 2: # noqa: PLR2004 + msg = "transient" + raise httpx2.ConnectError(msg) + return httpx2.Response(HTTPStatus.OK, request=request) + + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=0.02)) + response = client.get("https://example.test/x") + assert response.status_code == HTTPStatus.OK + assert call_count["n"] == 2 # noqa: PLR2004 + assert len(sleeper.calls) == 1 + + +def test_retries_on_httpware_timeout_error() -> None: + sleeper = _SleepRecorder() + call_count = {"n": 0} + + def handler(request: httpx2.Request) -> httpx2.Response: + call_count["n"] += 1 + if call_count["n"] < 2: # noqa: PLR2004 + msg = "read timeout" + raise httpx2.ReadTimeout(msg) + return httpx2.Response(HTTPStatus.OK, request=request) + + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_delay=0.02)) + response = client.get("https://example.test/x") + assert response.status_code == HTTPStatus.OK + assert call_count["n"] == 2 # noqa: PLR2004 + assert isinstance(HttpwareTimeoutError("x"), HttpwareTimeoutError) # type smoke + + +def test_does_not_retry_on_bare_transport_error_like_invalid_url() -> None: + sleeper = _SleepRecorder() + + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "bad url" + raise httpx2.InvalidURL(msg) + + client = _client(handler, retry=Retry(_sleep=sleeper)) + with pytest.raises(TransportError) as info: + client.get("https://example.test/x") + assert not isinstance(info.value, NetworkError) + assert sleeper.calls == [] + + +def test_network_error_exhaustion_reraises_with_note() -> None: + sleeper = _SleepRecorder() + + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + msg = "never works" + raise httpx2.ConnectError(msg) + + client = _client(handler, retry=Retry(_sleep=sleeper, max_attempts=2, base_delay=0.01, max_delay=0.02)) + with pytest.raises(NetworkError) as info: + client.get("https://example.test/x") + notes = getattr(info.value, "__notes__", []) + assert any("gave up after 2 attempts" in note for note in notes) + + +def test_does_not_retry_network_error_on_non_idempotent_method() -> None: + sleeper = _SleepRecorder() + call_count = {"n": 0} + + def handler(request: httpx2.Request) -> httpx2.Response: # noqa: ARG001 + call_count["n"] += 1 + msg = "transient" + raise httpx2.ConnectError(msg) + + client = _client(handler, retry=Retry(_sleep=sleeper)) + with pytest.raises(NetworkError): + client.post("https://example.test/x", json={"x": 1}) + assert call_count["n"] == 1 + assert sleeper.calls == [] + + +def test_retries_post_when_method_explicitly_included() -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.OK]) + client = _client( + handler, + retry=Retry( + _sleep=sleeper, + base_delay=0.01, + max_delay=0.02, + retry_methods=frozenset({"GET", "POST"}), + ), + ) + response = client.post("https://example.test/x", json={"k": "v"}) + assert response.status_code == HTTPStatus.OK + assert handler.calls == 2 # noqa: PLR2004 + assert len(sleeper.calls) == 1 + + +def test_default_budget_is_fresh_per_instance() -> None: + r1 = Retry() + r2 = Retry() + assert r1.budget is not r2.budget + + +def test_explicit_budget_shared_across_retry_instances() -> None: + shared = RetryBudget(ttl=10.0, min_retries_per_sec=1.0, percent_can_retry=0.0) + r1 = Retry(budget=shared) + r2 = Retry(budget=shared) + assert r1.budget is r2.budget + + +def test_emits_giving_up_log_event(caplog: pytest.LogCaptureFixture) -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE] * 2) + client = _client(handler, retry=Retry(_sleep=sleeper, base_delay=0.01, max_attempts=2)) + with caplog.at_level(logging.WARNING, logger="httpware.retry"), pytest.raises(ServiceUnavailableError): + client.get("https://example.test/x") + assert any("retry gave up" in r.getMessage() for r in caplog.records) + + +def test_emits_budget_refused_log_event(caplog: pytest.LogCaptureFixture) -> None: + sleeper = _SleepRecorder() + handler = _ResponseSequence([HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.OK]) + client = _client(handler, retry=Retry(_sleep=sleeper, budget=_zero_budget(), max_attempts=3)) + with caplog.at_level(logging.WARNING, logger="httpware.retry"), pytest.raises(RetryBudgetExhaustedError): + client.get("https://example.test/x") + assert any("budget refused" in r.getMessage() for r in caplog.records) + + +def test_is_streaming_body_sync_predicates() -> None: + assert _is_streaming_body_sync(None) is False + assert _is_streaming_body_sync(b"bytes") is False + assert _is_streaming_body_sync("str") is False + assert _is_streaming_body_sync({"k": "v"}) is False + assert _is_streaming_body_sync([1, 2]) is False + assert _is_streaming_body_sync((1, 2)) is False + assert _is_streaming_body_sync(iter([1, 2])) is True + assert _is_streaming_body_sync(x for x in range(3)) is True # generator diff --git a/tests/test_threading_with_shared_budget.py b/tests/test_threading_with_shared_budget.py new file mode 100644 index 0000000..4800396 --- /dev/null +++ b/tests/test_threading_with_shared_budget.py @@ -0,0 +1,79 @@ +"""Demonstrates that a single RetryBudget can be shared across a sync Client and an AsyncClient. + +The lock added in Task B2 makes ``RetryBudget`` thread-safe so sync threads and an asyncio +event loop can deposit/withdraw concurrently without corrupting the internal deques. +""" + +import asyncio +import contextlib +import threading +from http import HTTPStatus + +import httpx2 + +from httpware import AsyncClient, AsyncRetry, Client, Retry +from httpware.middleware.resilience.budget import RetryBudget + + +_N_SYNC_THREADS = 4 +_N_OPS_PER_THREAD = 50 +_N_ASYNC_TASKS = 20 + + +def _failing_handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(HTTPStatus.SERVICE_UNAVAILABLE, request=request) + + +def _sync_worker(sync_client: Client) -> None: + for _ in range(_N_OPS_PER_THREAD): + with contextlib.suppress(Exception): + sync_client.get("https://example.test/x") + + +async def _safe_get(async_client: AsyncClient) -> None: + with contextlib.suppress(Exception): + await async_client.get("https://example.test/x") + + +async def _drive_async_side(budget: RetryBudget) -> None: + transport = httpx2.MockTransport(_failing_handler) + async_client = AsyncClient( + httpx2_client=httpx2.AsyncClient(transport=transport), + middleware=[ + AsyncRetry( + budget=budget, + max_attempts=2, + base_delay=0.0001, + max_delay=0.001, + _sleep=asyncio.sleep, + ), + ], + ) + async with async_client: + await asyncio.gather(*[_safe_get(async_client) for _ in range(_N_ASYNC_TASKS)]) + + +def test_shared_budget_across_sync_threads_and_async_loop() -> None: + budget = RetryBudget(ttl=60.0, min_retries_per_sec=1000.0, percent_can_retry=0.5) + + sync_transport = httpx2.MockTransport(_failing_handler) + sync_client = Client( + httpx2_client=httpx2.Client(transport=sync_transport), + middleware=[Retry(budget=budget, max_attempts=2, base_delay=0.0001, max_delay=0.001)], + ) + + threads = [threading.Thread(target=_sync_worker, args=(sync_client,)) for _ in range(_N_SYNC_THREADS)] + for t in threads: + t.start() + + asyncio.run(_drive_async_side(budget)) + + for t in threads: + t.join() + + # The lock kept the budget's internal deques consistent — no IndexError, no corruption. + # No specific count assertion: the test passes if it completes without an exception + # from the budget itself. Add a smoke check that the budget recorded SOME activity: + assert len(budget._deposits) > 0 # noqa: SLF001 + + sync_client.close()