From d90c1a4ec7b9f07596539b3986bf45dfeefe4282 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 17 Jun 2026 01:55:15 -0700 Subject: [PATCH] feat(api): add Session runtime and provider readiness probes Make multi-turn agent loops first-class without re-creating a provider per call: - Session owns one provider instance across many interact()/stream()/ run_many() turns and closes it on exit; the one-shot helpers now delegate to a single-use Session so behavior stays identical. - ReadinessProvider protocol + ProviderReadiness let a provider expose a fast preflight. Session.check_ready() and the module-level check_ready() probe before packing context; the local provider implements it against /health and /v1/models (verifying the configured model when set). - local_reasoning() helper returns the scoped provider_options for local servers that honor enable_thinking. --- src/pollux/__init__.py | 307 +++++++++++++++---- src/pollux/providers/base.py | 21 ++ src/pollux/providers/local.py | 102 +++++- tests/interaction/test_interact_frontdoor.py | 51 ++- tests/providers/test_local_contract.py | 91 ++++++ 5 files changed, 504 insertions(+), 68 deletions(-) diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index 1808d06b..13c94c49 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -72,7 +72,11 @@ resolve_persistent_cache, stream_interaction, ) -from pollux.providers.base import CloseableProvider +from pollux.providers.base import ( + CloseableProvider, + ProviderReadiness, + ReadinessProvider, +) from pollux.retry import RetryPolicy from pollux.source import Source @@ -253,24 +257,20 @@ async def interact( config=cfg, ) """ - requirements = _build_requirements( - output=output, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - seed=seed, - reasoning_effort=reasoning_effort, - reasoning_budget_tokens=reasoning_budget_tokens, - tool_choice=tool_choice, - provider_options=provider_options, - ) - provider = _get_provider(config) - try: - return await execute_interaction( - environment, input, requirements, config, provider + async with Session(config) as session: + return await session.interact( + environment, + input, + output=output, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning_budget_tokens=reasoning_budget_tokens, + tool_choice=tool_choice, + provider_options=provider_options, ) - finally: - await _close_provider(provider) async def stream( @@ -325,27 +325,217 @@ async def stream( elif event.type == "done": result = event.output """ - requirements = _build_requirements( - output=output, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - seed=seed, - reasoning_effort=reasoning_effort, - reasoning_budget_tokens=reasoning_budget_tokens, - tool_choice=tool_choice, - provider_options=provider_options, - ) - provider = _get_provider(config) - try: + async with Session(config) as session: + async for event in session.stream( + environment, + input, + output=output, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning_budget_tokens=reasoning_budget_tokens, + tool_choice=tool_choice, + provider_options=provider_options, + ): + yield event + + +class Session: + """Reusable Pollux runtime for multi-turn agent loops. + + The one-shot helpers create and close a provider per call. ``Session`` owns a + single provider instance so clients with many sequential turns can reuse + transport resources while still going through the public interaction APIs. + """ + + def __init__(self, config: Config) -> None: + self.config = config + self._provider = _get_provider(config) + self._closed = False + + async def __aenter__(self) -> Session: # noqa: PYI034 + """Enter the async context manager.""" + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | None, + _exc: BaseException | None, + _tb: object, + ) -> None: + """Close provider resources when leaving the async context manager.""" + await self.aclose() + + def _ensure_open(self) -> None: + if self._closed: + raise ConfigurationError( + "Pollux Session is closed", + hint="Create a new Session for additional interactions.", + ) + + async def interact( + self, + environment: Environment, + input: Input, # noqa: A002 - "input" is the canonical v2 primitive name + *, + output: ResponseSchemaInput | None = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, + seed: int | None = None, + reasoning_effort: str | None = None, + reasoning_budget_tokens: int | None = None, + tool_choice: ToolChoice | None = None, + provider_options: dict[str, dict[str, Any]] | None = None, + ) -> Output: + """Run one interaction using the session's provider instance.""" + self._ensure_open() + requirements = _build_requirements( + output=output, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning_budget_tokens=reasoning_budget_tokens, + tool_choice=tool_choice, + provider_options=provider_options, + ) + return await execute_interaction( + environment, input, requirements, self.config, self._provider + ) + + async def stream( + self, + environment: Environment, + input: Input, # noqa: A002 - "input" is the canonical v2 primitive name + *, + output: ResponseSchemaInput | None = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, + seed: int | None = None, + reasoning_effort: str | None = None, + reasoning_budget_tokens: int | None = None, + tool_choice: ToolChoice | None = None, + provider_options: dict[str, dict[str, Any]] | None = None, + ) -> AsyncIterator[Event]: + """Stream one interaction using the session's provider instance.""" + self._ensure_open() + requirements = _build_requirements( + output=output, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning_budget_tokens=reasoning_budget_tokens, + tool_choice=tool_choice, + provider_options=provider_options, + ) async for event in stream_interaction( - environment, input, requirements, config, provider + environment, input, requirements, self.config, self._provider ): yield event + + async def run_many( + self, + prompts: str | Sequence[str | None] | None = None, + *, + sources: Sequence[Source] = (), + environment: Environment | None = None, + instructions: str | None = None, + output: ResponseSchemaInput | None = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, + seed: int | None = None, + reasoning_effort: str | None = None, + reasoning_budget_tokens: int | None = None, + tool_choice: ToolChoice | None = None, + tools: Sequence[ToolDeclaration] | None = None, + provider_options: dict[str, dict[str, Any]] | None = None, + ) -> OutputCollection: + """Run source-pattern prompts using the session's provider instance.""" + self._ensure_open() + prompt_tuple = ( + (prompts,) if isinstance(prompts, (str, type(None))) else tuple(prompts) + ) + if environment is not None: + if sources or instructions is not None or tools is not None: + raise ConfigurationError( + "environment cannot be combined with inline instructions/sources/tools", + hint="Put instructions, sources, and tools on the Environment, " + "or drop the environment argument.", + ) + resolved_environment = environment + else: + resolved_environment = Environment( + instructions=instructions, + sources=tuple(sources), + tools=tuple(tools) if tools else (), + ) + inputs = [Input(content=prompt) for prompt in prompt_tuple] + requirements = _build_requirements( + output=output, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning_budget_tokens=reasoning_budget_tokens, + tool_choice=tool_choice, + provider_options=provider_options, + ) + return await execute_interactions( + resolved_environment, inputs, requirements, self.config, self._provider + ) + + async def check_ready(self) -> ProviderReadiness: + """Return provider readiness for this session's config.""" + self._ensure_open() + if isinstance(self._provider, ReadinessProvider): + return await self._provider.check_ready(model=self.config.model) + return ProviderReadiness( + ready=True, + provider=self.config.provider, + model=self.config.model, + message="Provider has no explicit readiness probe.", + model_verified=False, + ) + + async def aclose(self) -> None: + """Close the session's provider resources.""" + if self._closed: + return + self._closed = True + await _close_provider(self._provider) + + +async def check_ready(config: Config) -> ProviderReadiness: + """Run a fast provider readiness probe and close provider resources.""" + provider = _get_provider(config) + try: + if isinstance(provider, ReadinessProvider): + return await provider.check_ready(model=config.model) + return ProviderReadiness( + ready=True, + provider=config.provider, + model=config.model, + message="Provider has no explicit readiness probe.", + model_verified=False, + ) finally: await _close_provider(provider) +def local_reasoning(*, enabled: bool = False) -> dict[str, dict[str, Any]]: + """Return local provider options for servers with ``enable_thinking`` support.""" + return {"local": {"chat_template_kwargs": {"enable_thinking": enabled}}} + + async def defer( prompts: str | Sequence[str | None] | None = None, *, @@ -475,42 +665,23 @@ async def run_many( for answer in results.answers: print(answer) """ - prompt_tuple = ( - (prompts,) if isinstance(prompts, (str, type(None))) else tuple(prompts) - ) - if environment is not None: - if sources or instructions is not None or tools is not None: - raise ConfigurationError( - "environment cannot be combined with inline instructions/sources/tools", - hint="Put instructions, sources, and tools on the Environment, " - "or drop the environment argument.", - ) - resolved_environment = environment - else: - resolved_environment = Environment( + async with Session(config) as session: + return await session.run_many( + prompts, + sources=sources, + environment=environment, instructions=instructions, - sources=tuple(sources), - tools=tuple(tools) if tools else (), + output=output, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning_budget_tokens=reasoning_budget_tokens, + tool_choice=tool_choice, + tools=tools, + provider_options=provider_options, ) - inputs = [Input(content=prompt) for prompt in prompt_tuple] - requirements = _build_requirements( - output=output, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - seed=seed, - reasoning_effort=reasoning_effort, - reasoning_budget_tokens=reasoning_budget_tokens, - tool_choice=tool_choice, - provider_options=provider_options, - ) - provider = _get_provider(config) - try: - return await execute_interactions( - resolved_environment, inputs, requirements, config, provider - ) - finally: - await _close_provider(provider) async def prepare_environment( @@ -799,8 +970,10 @@ def _resolve_deferred_provider(handle: DeferredHandle) -> Provider: "OutputRequirements", "PlanningError", "PolluxError", + "ProviderReadiness", "RateLimitError", "RetryPolicy", + "Session", "Source", "SourceError", "ToolCall", @@ -810,10 +983,12 @@ def _resolve_deferred_provider(handle: DeferredHandle) -> Provider: "ToolDeclaration", "ToolResult", "cancel_deferred", + "check_ready", "collect_deferred", "defer", "inspect_deferred", "interact", + "local_reasoning", "prepare_environment", "run", "run_many", diff --git a/src/pollux/providers/base.py b/src/pollux/providers/base.py index 47f0a1cf..495b78ab 100644 --- a/src/pollux/providers/base.py +++ b/src/pollux/providers/base.py @@ -37,6 +37,18 @@ class ProviderCapabilities: file_rejection_hint: str | None = None +@dataclass(frozen=True) +class ProviderReadiness: + """Provider preflight status for agent loops and health checks.""" + + ready: bool + provider: str + status_code: int | None = None + message: str | None = None + model: str | None = None + model_verified: bool | None = None + + @dataclass(frozen=True) class ProviderDeferredHandle: """Provider-owned handle returned at deferred submission time.""" @@ -180,6 +192,15 @@ async def validate_request( ... +@runtime_checkable +class ReadinessProvider(Protocol): + """Optional provider hook for fast readiness probes.""" + + async def check_ready(self, *, model: str | None = None) -> ProviderReadiness: + """Return provider readiness without raising for ordinary not-ready states.""" + ... + + @runtime_checkable class DeferredProvider(Protocol): """Lifecycle operations for provider-backed deferred delivery.""" diff --git a/src/pollux/providers/local.py b/src/pollux/providers/local.py index 8a3ba0a3..e05fb4d6 100644 --- a/src/pollux/providers/local.py +++ b/src/pollux/providers/local.py @@ -47,7 +47,7 @@ serialize_tool_calls, ) from pollux.providers._utils import merge_provider_options, to_strict_schema -from pollux.providers.base import ProviderCapabilities +from pollux.providers.base import ProviderCapabilities, ProviderReadiness from pollux.providers.models import ( Message, ProviderFileAsset, @@ -125,6 +125,72 @@ def _get_client(self) -> httpx.AsyncClient: ) return cast("httpx.AsyncClient", self._client) + async def check_ready(self, *, model: str | None = None) -> ProviderReadiness: + """Probe a local OpenAI-compatible server without sending a prompt. + + Pollux checks the server and, when *model* is provided, verifies that the + configured model is listed by ``/v1/models`` before returning ready. + Ordinary connection failures and non-2xx responses are returned as + not-ready statuses so agent harnesses can fail before packing context. + """ + client = self._get_client() + try: + health_response = await client.get("../health") + health_ready = health_response.status_code < 400 + health_message = _readiness_message(health_response) + if model is None: + return ProviderReadiness( + ready=health_ready, + provider="local", + status_code=health_response.status_code, + message=health_message, + model=None, + model_verified=None, + ) + + response = await client.get("models") + if response.status_code >= 400: + return ProviderReadiness( + ready=False, + provider="local", + status_code=response.status_code, + message=_readiness_message(response) or health_message, + model=model, + model_verified=False, + ) + models = _model_ids(response) + if model in models: + return ProviderReadiness( + ready=True, + provider="local", + status_code=response.status_code, + message=health_message or _readiness_message(response), + model=model, + model_verified=True, + ) + return ProviderReadiness( + ready=False, + provider="local", + status_code=response.status_code, + message=f"Model {model!r} was not listed by the local server.", + model=model, + model_verified=False, + ) + except asyncio.CancelledError: + raise + except Exception as exc: + hint = _hint_for_local_error(exc, base_url=self._base_url) + return ProviderReadiness( + ready=False, + provider="local", + status_code=getattr( + getattr(exc, "response", None), "status_code", None + ), + message=hint or str(exc) or "Local provider readiness probe failed.", + model=model, + model_verified=False if model is not None else None, + ) + async def validate_request( self, snapshot: EnvironmentSnapshot, @@ -621,6 +687,40 @@ def _parse_response( ) +def _readiness_message(response: httpx.Response) -> str | None: + """Extract a compact human-readable readiness message.""" + try: + data = response.json() + except Exception: + text = response.text.strip() + return text or None + if isinstance(data, dict): + status = data.get("status") + if isinstance(status, str) and status: + return status + message = extract_error_message(response) + return message if message else None + return None + + +def _model_ids(response: httpx.Response) -> set[str]: + """Extract model ids from an OpenAI-compatible models response.""" + try: + data = response.json() + except Exception: + return set() + if not isinstance(data, dict): + return set() + raw_models = data.get("data") + if not isinstance(raw_models, list): + return set() + ids: set[str] = set() + for item in raw_models: + if isinstance(item, Mapping) and isinstance(item.get("id"), str): + ids.add(cast("str", item["id"])) + return ids + + def _raise_sse_error_if_present( data: Mapping[str, Any], *, tools_present: bool ) -> None: diff --git a/tests/interaction/test_interact_frontdoor.py b/tests/interaction/test_interact_frontdoor.py index c118f673..bde1404d 100644 --- a/tests/interaction/test_interact_frontdoor.py +++ b/tests/interaction/test_interact_frontdoor.py @@ -8,7 +8,7 @@ import pollux from pollux import Environment, Input, Output, ToolDeclaration, ToolResult, interact from pollux.config import Config -from pollux.providers.base import ProviderCapabilities +from pollux.providers.base import ProviderCapabilities, ProviderReadiness from pollux.providers.models import ProviderResponse from pollux.providers.models import ToolCall as ProviderToolCall from tests.conftest import ANTHROPIC_MODEL, FakeProvider @@ -120,3 +120,52 @@ async def test_agent_loop_continues_from_tool_results(monkeypatch): ) assert final.text == "It is sunny." assert not final.tool_calls + + +@pytest.mark.asyncio +async def test_session_reuses_provider_and_closes(monkeypatch): + class CloseableScriptedProvider(ScriptedProvider): + closed: bool = False + + async def aclose(self) -> None: + self.closed = True + + provider = CloseableScriptedProvider( + script=[ + ProviderResponse(text="one", usage={"total_tokens": 1}), + ProviderResponse(text="two", usage={"total_tokens": 1}), + ] + ) + monkeypatch.setattr(pollux, "_get_provider", lambda _config: provider) + + async with pollux.Session(_cfg()) as session: + first = await session.interact(Environment(), Input("Q1")) + second = await session.interact(Environment(), Input("Q2")) + + assert first.text == "one" + assert second.text == "two" + assert provider.generate_calls == 2 + assert provider.closed is True + + +@pytest.mark.asyncio +async def test_session_check_ready_uses_provider_probe(monkeypatch): + class ReadyProvider(ScriptedProvider): + async def check_ready(self, *, model: str | None = None) -> ProviderReadiness: + return ProviderReadiness( + ready=True, provider="anthropic", model=model, message="ok" + ) + + provider = ReadyProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: provider) + + readiness = await pollux.check_ready(_cfg()) + + assert readiness.ready is True + assert readiness.model == ANTHROPIC_MODEL + + +def test_local_reasoning_returns_scoped_provider_options() -> None: + assert pollux.local_reasoning(enabled=False) == { + "local": {"chat_template_kwargs": {"enable_thinking": False}} + } diff --git a/tests/providers/test_local_contract.py b/tests/providers/test_local_contract.py index 09bdc495..b8c7279e 100644 --- a/tests/providers/test_local_contract.py +++ b/tests/providers/test_local_contract.py @@ -48,11 +48,14 @@ def __init__( payload: Any = None, status_code: int = 200, error_body: Any = None, + get_payloads: dict[str, tuple[int, Any]] | None = None, ) -> None: self.last_json: dict[str, Any] | None = None self.closed = False self._status_code = status_code self._error_body = error_body + self.get_paths: list[str] = [] + self._get_payloads = get_payloads or {} self._payload = payload or { "id": "chatcmpl_local_1", "choices": [ @@ -79,6 +82,12 @@ async def post(self, path: str, json: dict[str, Any]) -> Any: ) return httpx.Response(self._status_code, json=self._payload, request=request) + async def get(self, path: str) -> Any: + self.get_paths.append(path) + status_code, payload = self._get_payloads.get(path, (404, {"status": "no"})) + request = httpx.Request("GET", f"{_LOCAL_BASE_URL}{path}") + return httpx.Response(status_code, json=payload, request=request) + async def aclose(self) -> None: self.closed = True @@ -681,6 +690,88 @@ async def test_local_generate_classifies_tool_call_parse_http_error() -> None: assert exc.value.error_category == "tool_call_parse" +@pytest.mark.asyncio +async def test_local_check_ready_uses_health_endpoint() -> None: + fake = _FakeLocalClient( + get_payloads={ + "../health": (200, {"status": "ok"}), + "models": (200, {"data": [{"id": LOCAL_MODEL}]}), + } + ) + provider = _make_local_provider(fake) + + readiness = await provider.check_ready(model=LOCAL_MODEL) + + assert readiness.ready is True + assert readiness.provider == "local" + assert readiness.status_code == 200 + assert readiness.message == "ok" + assert readiness.model == LOCAL_MODEL + assert readiness.model_verified is True + assert fake.get_paths == ["../health", "models"] + + +@pytest.mark.asyncio +async def test_local_check_ready_falls_back_to_models() -> None: + fake = _FakeLocalClient( + get_payloads={ + "../health": (404, {"error": {"message": "missing"}}), + "models": (200, {"data": [{"id": LOCAL_MODEL}]}), + } + ) + provider = _make_local_provider(fake) + + readiness = await provider.check_ready(model=LOCAL_MODEL) + + assert readiness.ready is True + assert readiness.status_code == 200 + assert readiness.model_verified is True + assert fake.get_paths == ["../health", "models"] + + +@pytest.mark.asyncio +async def test_local_check_ready_reports_missing_model() -> None: + fake = _FakeLocalClient( + get_payloads={ + "../health": (404, {"error": {"message": "missing"}}), + "models": (200, {"data": [{"id": "other-model"}]}), + } + ) + provider = _make_local_provider(fake) + + readiness = await provider.check_ready(model=LOCAL_MODEL) + + assert readiness.ready is False + assert readiness.model == LOCAL_MODEL + assert readiness.model_verified is False + assert "not listed" in (readiness.message or "") + + +@pytest.mark.asyncio +async def test_local_check_ready_skips_model_probe_when_model_unset() -> None: + fake = _FakeLocalClient(get_payloads={"../health": (200, {"status": "ok"})}) + provider = _make_local_provider(fake) + + readiness = await provider.check_ready(model=None) + + assert readiness.ready is True + assert readiness.model is None + assert readiness.model_verified is None + assert fake.get_paths == ["../health"] + + +@pytest.mark.asyncio +async def test_local_check_ready_requires_model_verification() -> None: + fake = _FakeLocalClient(get_payloads={"../health": (200, {"status": "ok"})}) + provider = _make_local_provider(fake) + + readiness = await provider.check_ready(model=LOCAL_MODEL) + + assert readiness.ready is False + assert readiness.model_verified is False + assert fake.get_paths == ["../health", "models"] + + @pytest.mark.asyncio async def test_local_aclose_closes_client() -> None: """aclose must close the underlying httpx client and be idempotent."""