Skip to content
1,006 changes: 1,006 additions & 0 deletions docs/superpowers/plans/2026-06-01-auth-coercion-plan.md

Large diffs are not rendered by default.

372 changes: 372 additions & 0 deletions docs/superpowers/specs/2026-06-01-auth-coercion-design.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/httpware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""httpware — resilience-first async HTTP client framework for Python."""

from httpware._internal.auth import AuthValue
from httpware.client import AsyncClient
from httpware.config import ClientConfig, Limits, Timeout
from httpware.decoders import ResponseDecoder
Expand Down Expand Up @@ -33,6 +34,7 @@
__all__ = [
"STATUS_TO_EXCEPTION",
"AsyncClient",
"AuthValue",
"BadRequestError",
"ClientConfig",
"ClientError",
Expand Down
75 changes: 75 additions & 0 deletions src/httpware/_internal/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Normalize the `auth=` value of AsyncClient into a Middleware (or None)."""

import inspect
from collections.abc import Awaitable, Callable
from typing import TypeAlias

from httpware.middleware import Middleware, before_request
from httpware.request import Request


_MIDDLEWARE_ARITY = 2

AuthValue: TypeAlias = str | Callable[[], str | Awaitable[str]] | Middleware | None


def _normalize_auth(value: AuthValue) -> Middleware | None:
"""Coerce an `auth=` value into a Middleware.

- `None` → returns `None` (no auth middleware injected).
- `str` → returns a middleware that sets `Authorization: Bearer <str>`
on every request (skipping if Authorization is already present).
- `Callable[[], str | Awaitable[str]]` (zero-arg) → returns a middleware
that calls the provider per request (awaiting if it returns an
awaitable) and sets `Authorization: Bearer <result>` (skip-if-present).
- `Middleware` (two-arg `__call__(request, next)`) → returned unchanged.
- Any other callable shape → raises `TypeError` naming `auth=`.
"""
if value is None:
return None
if isinstance(value, str):
return _bearer(value)
if not callable(value):
msg = f"`auth=` must be a string, zero-arg callable, Middleware, or None; got {type(value).__name__}"
raise TypeError(msg)
n_params = len(inspect.signature(value).parameters)
if n_params == 0:
return _bearer_from_provider(value) # ty: ignore[invalid-argument-type]
if n_params == _MIDDLEWARE_ARITY:
return value # ty: ignore[invalid-return-type]
msg = f"`auth=` callable must take 0 args (token provider) or 2 args (Middleware); got {n_params}"
raise TypeError(msg)


def _bearer(token: str) -> Middleware:
"""Middleware that sets `Authorization: Bearer <token>` (skip-if-present)."""

@before_request
async def _add_static_bearer(request: Request) -> Request:
if _has_authorization(request):
return request
return request.with_header("Authorization", f"Bearer {token}")

return _add_static_bearer


def _bearer_from_provider(
provider: Callable[[], str | Awaitable[str]],
) -> Middleware:
"""Middleware that calls `provider()` per request and sets the header."""

@before_request
async def _add_dynamic_bearer(request: Request) -> Request:
if _has_authorization(request):
return request
token = provider()
if inspect.isawaitable(token):
token = await token
return request.with_header("Authorization", f"Bearer {token}")

return _add_dynamic_bearer


def _has_authorization(request: Request) -> bool:
"""Case-insensitive check for an existing Authorization header."""
return any(k.lower() == "authorization" for k in request.headers)
51 changes: 45 additions & 6 deletions src/httpware/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing
from collections.abc import Mapping, Sequence

from httpware._internal.auth import AuthValue, _normalize_auth
from httpware._internal.chain import compose
from httpware.config import ClientConfig, Limits, Timeout
from httpware.decoders import ResponseDecoder
Expand Down Expand Up @@ -52,6 +53,8 @@ class AsyncClient:
_transport: Transport
_dispatch: Next
_owns_transport: bool
_user_middleware: tuple[Middleware, ...]
_auth: AuthValue

def __init__(
self,
Expand All @@ -64,12 +67,19 @@ def __init__(
transport: Transport | None = None,
decoder: ResponseDecoder | None = None,
middleware: Sequence[Middleware] | None = None,
auth: AuthValue = None,
) -> None:
normalized_timeout = _normalize_timeout(timeout)
resolved_limits = limits or Limits()
resolved_transport: Transport = transport or Httpx2Transport(limits=resolved_limits, timeout=normalized_timeout)
resolved_decoder = decoder or PydanticDecoder()
resolved_middleware = tuple(middleware) if middleware is not None else ()
resolved_user_middleware: tuple[Middleware, ...] = tuple(middleware) if middleware is not None else ()
resolved_auth_middleware = _normalize_auth(auth)
composed_middleware: tuple[Middleware, ...] = (
resolved_user_middleware
if resolved_auth_middleware is None
else (*resolved_user_middleware, resolved_auth_middleware)
)

self._config = ClientConfig(
base_url=base_url,
Expand All @@ -78,11 +88,13 @@ def __init__(
timeout=normalized_timeout,
limits=resolved_limits,
decoder=resolved_decoder,
middleware=resolved_middleware,
middleware=composed_middleware,
)
self._transport = resolved_transport
self._dispatch = compose(resolved_middleware, resolved_transport)
self._dispatch = compose(composed_middleware, resolved_transport)
self._owns_transport = True
self._user_middleware = resolved_user_middleware
self._auth = auth

@classmethod
def from_url(cls, base_url: str, **kwargs: object) -> "AsyncClient":
Expand Down Expand Up @@ -582,6 +594,7 @@ def with_options(
timeout: Timeout | float | None = _UNSET,
decoder: ResponseDecoder | None = _UNSET,
middleware: Sequence[Middleware] | None = _UNSET,
auth: AuthValue | object = _UNSET,
) -> "AsyncClient":
"""Return a new AsyncClient sharing the same transport with overridden config.

Expand All @@ -603,18 +616,44 @@ def with_options(
changes["timeout"] = _normalize_timeout(timeout)
if decoder is not _UNSET:
changes["decoder"] = decoder or PydanticDecoder()

new_user_middleware = self._user_middleware
if middleware is not _UNSET:
changes["middleware"] = tuple(middleware) if middleware is not None else ()
new_user_middleware = tuple(middleware) if middleware is not None else ()

new_auth: AuthValue = self._auth
if auth is not _UNSET:
new_auth = auth # ty: ignore[invalid-assignment]

new_auth_middleware = _normalize_auth(new_auth)
new_composed: tuple[Middleware, ...] = (
new_user_middleware if new_auth_middleware is None else (*new_user_middleware, new_auth_middleware)
)
changes["middleware"] = new_composed

new_config = dataclasses.replace(self._config, **changes)
return AsyncClient._from_view(new_config, self._transport)
return AsyncClient._from_view(
new_config,
self._transport,
user_middleware=new_user_middleware,
auth=new_auth,
)

@classmethod
def _from_view(cls, config: ClientConfig, transport: Transport) -> "AsyncClient":
def _from_view(
cls,
config: ClientConfig,
transport: Transport,
*,
user_middleware: tuple[Middleware, ...],
auth: AuthValue,
) -> "AsyncClient":
"""Construct a view sharing an existing transport. Bypasses __init__."""
client = cls.__new__(cls)
client._config = config # noqa: SLF001
client._transport = transport # noqa: SLF001
client._dispatch = compose(config.middleware, transport) # noqa: SLF001
client._owns_transport = False # noqa: SLF001
client._user_middleware = user_middleware # noqa: SLF001
client._auth = auth # noqa: SLF001
return client
34 changes: 34 additions & 0 deletions tests/test_client_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,37 @@ def test_construction_does_not_create_httpx2_client() -> None:
# Httpx2Transport stores `_client` lazily; until first call, _client is None.
# The attribute is private; we check it via getattr to keep the test resilient.
assert getattr(client._transport, "_client", "missing") is None


def test_init_no_auth_means_no_auth_middleware() -> None:
transport = RecordedTransport()
client = AsyncClient(transport=transport)
assert client._config.middleware == ()
assert client._auth is None
assert client._user_middleware == ()


def test_init_with_string_auth_appends_bearer_middleware() -> None:
transport = RecordedTransport()
client = AsyncClient(transport=transport, auth="tok")
assert len(client._config.middleware) == 1
assert isinstance(client._config.middleware[0], Middleware)
assert client._auth == "tok"
assert client._user_middleware == ()


def test_init_with_user_middleware_plus_auth() -> None:
class _M:
async def __call__(self, request, next) -> Response: # noqa: A002, ANN001
return await next(request)

m1 = _M()
m2 = _M()
transport = RecordedTransport()
client = AsyncClient(transport=transport, middleware=[m1, m2], auth="tok")
_expected_len = 3
assert len(client._config.middleware) == _expected_len
assert client._config.middleware[0] is m1
assert client._config.middleware[1] is m2
# The third entry is the auth middleware; identity-test that user_middleware excludes it.
assert client._user_middleware == (m1, m2)
37 changes: 37 additions & 0 deletions tests/test_client_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,40 @@ async def test_per_call_timeout_propagates_to_request_extensions() -> None:
await client.get("/foo", timeout=2.5)
assert transport.last_request is not None
assert "timeout" in transport.last_request.extensions


async def test_string_auth_sends_authorization_header() -> None:
transport = RecordedTransport(default=Response(status=200, headers={}, content=b"", url="/", elapsed=0.0))
client = AsyncClient(transport=transport, auth="tok")

await client.get("/foo")

assert transport.last_request is not None
assert transport.last_request.headers["Authorization"] == "Bearer tok"


async def test_per_call_authorization_header_wins_over_auth_param() -> None:
transport = RecordedTransport(default=Response(status=200, headers={}, content=b"", url="/", elapsed=0.0))
client = AsyncClient(transport=transport, auth="default-tok")

await client.get("/foo", headers={"Authorization": "Bearer override"})

assert transport.last_request is not None
assert transport.last_request.headers["Authorization"] == "Bearer override"


async def test_callable_auth_calls_provider_per_request() -> None:
transport = RecordedTransport(default=Response(status=200, headers={}, content=b"", url="/", elapsed=0.0))
calls = 0

def _provider() -> str:
nonlocal calls
calls += 1
return f"tok-{calls}"

client = AsyncClient(transport=transport, auth=_provider)

await client.get("/a")
await client.get("/b")

assert calls == 2 # noqa: PLR2004
53 changes: 53 additions & 0 deletions tests/test_client_middleware_wiring.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests for AsyncClient middleware wiring through compose() and with_options."""

from collections.abc import Mapping

from httpware import AsyncClient, RecordedTransport
from httpware.middleware import Middleware, Next
from httpware.request import Request
Expand Down Expand Up @@ -112,3 +114,54 @@ def decode(self, content: bytes, model: type) -> object: # pragma: no cover #
client = AsyncClient(transport=transport)
view = client.with_options(decoder=new_decoder)
assert view._config.decoder is new_decoder # noqa: SLF001


async def test_auth_runs_inside_user_middleware() -> None:
transport = RecordedTransport(default=Response(status=200, headers={}, content=b"", url="/", elapsed=0.0))

user_seen_headers: list[Mapping[str, str]] = []

class _UserOuter:
async def __call__(self, request: Request, next: Next) -> Response: # noqa: A002
user_seen_headers.append(dict(request.headers))
return await next(request)

client = AsyncClient(transport=transport, middleware=[_UserOuter()], auth="tok")
await client.get("/foo")

# User middleware saw the request BEFORE auth header was applied.
assert "Authorization" not in user_seen_headers[0]
# Transport saw the request WITH the auth header.
assert transport.last_request is not None
assert transport.last_request.headers["Authorization"] == "Bearer tok"


async def test_with_options_auth_replaces_auth_middleware() -> None:
transport = RecordedTransport(default=Response(status=200, headers={}, content=b"", url="/", elapsed=0.0))
client = AsyncClient(transport=transport, auth="parent")
view = client.with_options(auth="view")

await view.get("/foo")
assert transport.last_request is not None
assert transport.last_request.headers["Authorization"] == "Bearer view"

await client.get("/foo")
assert transport.last_request is not None
assert transport.last_request.headers["Authorization"] == "Bearer parent"


async def test_with_options_middleware_keeps_existing_auth() -> None:
transport = RecordedTransport(default=Response(status=200, headers={}, content=b"", url="/", elapsed=0.0))

class _M:
async def __call__(self, request: Request, next: Next) -> Response: # noqa: A002
return await next(request)

m1 = _M()
m2 = _M()
client = AsyncClient(transport=transport, auth="tok", middleware=[m1])
view = client.with_options(middleware=[m2])

await view.get("/foo")
assert transport.last_request is not None
assert transport.last_request.headers["Authorization"] == "Bearer tok"
Loading
Loading