Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 241 additions & 66 deletions src/pollux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@
resolve_persistent_cache,
stream_interaction,
)
from pollux.providers.base import CloseableProvider
from pollux.providers.base import (
CloseableProvider,
ProviderReadiness,
ReadinessProvider,
)
from pollux.retry import RetryPolicy
from pollux.source import Source

Expand Down Expand Up @@ -253,24 +257,20 @@ async def interact(
config=cfg,
)
"""
requirements = _build_requirements(
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
provider = _get_provider(config)
try:
return await execute_interaction(
environment, input, requirements, config, provider
async with Session(config) as session:
return await session.interact(
environment,
input,
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
finally:
await _close_provider(provider)


async def stream(
Expand Down Expand Up @@ -325,27 +325,217 @@ async def stream(
elif event.type == "done":
result = event.output
"""
requirements = _build_requirements(
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
provider = _get_provider(config)
try:
async with Session(config) as session:
async for event in session.stream(
environment,
input,
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
):
yield event


class Session:
"""Reusable Pollux runtime for multi-turn agent loops.

The one-shot helpers create and close a provider per call. ``Session`` owns a
single provider instance so clients with many sequential turns can reuse
transport resources while still going through the public interaction APIs.
"""

def __init__(self, config: Config) -> None:
self.config = config
self._provider = _get_provider(config)
self._closed = False

async def __aenter__(self) -> Session: # noqa: PYI034
"""Enter the async context manager."""
return self

async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc: BaseException | None,
_tb: object,
) -> None:
"""Close provider resources when leaving the async context manager."""
await self.aclose()

def _ensure_open(self) -> None:
if self._closed:
raise ConfigurationError(
"Pollux Session is closed",
hint="Create a new Session for additional interactions.",
)

async def interact(
self,
environment: Environment,
input: Input, # noqa: A002 - "input" is the canonical v2 primitive name
*,
output: ResponseSchemaInput | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_tokens: int | None = None,
seed: int | None = None,
reasoning_effort: str | None = None,
reasoning_budget_tokens: int | None = None,
tool_choice: ToolChoice | None = None,
provider_options: dict[str, dict[str, Any]] | None = None,
) -> Output:
"""Run one interaction using the session's provider instance."""
self._ensure_open()
requirements = _build_requirements(
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
return await execute_interaction(
environment, input, requirements, self.config, self._provider
)

async def stream(
self,
environment: Environment,
input: Input, # noqa: A002 - "input" is the canonical v2 primitive name
*,
output: ResponseSchemaInput | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_tokens: int | None = None,
seed: int | None = None,
reasoning_effort: str | None = None,
reasoning_budget_tokens: int | None = None,
tool_choice: ToolChoice | None = None,
provider_options: dict[str, dict[str, Any]] | None = None,
) -> AsyncIterator[Event]:
"""Stream one interaction using the session's provider instance."""
self._ensure_open()
requirements = _build_requirements(
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
async for event in stream_interaction(
environment, input, requirements, config, provider
environment, input, requirements, self.config, self._provider
):
yield event

async def run_many(
self,
prompts: str | Sequence[str | None] | None = None,
*,
sources: Sequence[Source] = (),
environment: Environment | None = None,
instructions: str | None = None,
output: ResponseSchemaInput | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_tokens: int | None = None,
seed: int | None = None,
reasoning_effort: str | None = None,
reasoning_budget_tokens: int | None = None,
tool_choice: ToolChoice | None = None,
tools: Sequence[ToolDeclaration] | None = None,
provider_options: dict[str, dict[str, Any]] | None = None,
) -> OutputCollection:
"""Run source-pattern prompts using the session's provider instance."""
self._ensure_open()
prompt_tuple = (
(prompts,) if isinstance(prompts, (str, type(None))) else tuple(prompts)
)
if environment is not None:
if sources or instructions is not None or tools is not None:
raise ConfigurationError(
"environment cannot be combined with inline instructions/sources/tools",
hint="Put instructions, sources, and tools on the Environment, "
"or drop the environment argument.",
)
resolved_environment = environment
else:
resolved_environment = Environment(
instructions=instructions,
sources=tuple(sources),
tools=tuple(tools) if tools else (),
)
inputs = [Input(content=prompt) for prompt in prompt_tuple]
requirements = _build_requirements(
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
return await execute_interactions(
resolved_environment, inputs, requirements, self.config, self._provider
)

async def check_ready(self) -> ProviderReadiness:
"""Return provider readiness for this session's config."""
self._ensure_open()
if isinstance(self._provider, ReadinessProvider):
return await self._provider.check_ready(model=self.config.model)
return ProviderReadiness(
ready=True,
provider=self.config.provider,
model=self.config.model,
message="Provider has no explicit readiness probe.",
model_verified=False,
)

async def aclose(self) -> None:
"""Close the session's provider resources."""
if self._closed:
return
self._closed = True
await _close_provider(self._provider)


async def check_ready(config: Config) -> ProviderReadiness:
"""Run a fast provider readiness probe and close provider resources."""
provider = _get_provider(config)
try:
if isinstance(provider, ReadinessProvider):
return await provider.check_ready(model=config.model)
return ProviderReadiness(
ready=True,
provider=config.provider,
model=config.model,
message="Provider has no explicit readiness probe.",
model_verified=False,
)
finally:
await _close_provider(provider)


def local_reasoning(*, enabled: bool = False) -> dict[str, dict[str, Any]]:
"""Return local provider options for servers with ``enable_thinking`` support."""
return {"local": {"chat_template_kwargs": {"enable_thinking": enabled}}}


async def defer(
prompts: str | Sequence[str | None] | None = None,
*,
Expand Down Expand Up @@ -475,42 +665,23 @@ async def run_many(
for answer in results.answers:
print(answer)
"""
prompt_tuple = (
(prompts,) if isinstance(prompts, (str, type(None))) else tuple(prompts)
)
if environment is not None:
if sources or instructions is not None or tools is not None:
raise ConfigurationError(
"environment cannot be combined with inline instructions/sources/tools",
hint="Put instructions, sources, and tools on the Environment, "
"or drop the environment argument.",
)
resolved_environment = environment
else:
resolved_environment = Environment(
async with Session(config) as session:
return await session.run_many(
prompts,
sources=sources,
environment=environment,
instructions=instructions,
sources=tuple(sources),
tools=tuple(tools) if tools else (),
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
tools=tools,
provider_options=provider_options,
)
inputs = [Input(content=prompt) for prompt in prompt_tuple]
requirements = _build_requirements(
output=output,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
seed=seed,
reasoning_effort=reasoning_effort,
reasoning_budget_tokens=reasoning_budget_tokens,
tool_choice=tool_choice,
provider_options=provider_options,
)
provider = _get_provider(config)
try:
return await execute_interactions(
resolved_environment, inputs, requirements, config, provider
)
finally:
await _close_provider(provider)


async def prepare_environment(
Expand Down Expand Up @@ -799,8 +970,10 @@ def _resolve_deferred_provider(handle: DeferredHandle) -> Provider:
"OutputRequirements",
"PlanningError",
"PolluxError",
"ProviderReadiness",
"RateLimitError",
"RetryPolicy",
"Session",
"Source",
"SourceError",
"ToolCall",
Expand All @@ -810,10 +983,12 @@ def _resolve_deferred_provider(handle: DeferredHandle) -> Provider:
"ToolDeclaration",
"ToolResult",
"cancel_deferred",
"check_ready",
"collect_deferred",
"defer",
"inspect_deferred",
"interact",
"local_reasoning",
"prepare_environment",
"run",
"run_many",
Expand Down
21 changes: 21 additions & 0 deletions src/pollux/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ class ProviderCapabilities:
file_rejection_hint: str | None = None


@dataclass(frozen=True)
class ProviderReadiness:
"""Provider preflight status for agent loops and health checks."""

ready: bool
provider: str
status_code: int | None = None
message: str | None = None
model: str | None = None
model_verified: bool | None = None


@dataclass(frozen=True)
class ProviderDeferredHandle:
"""Provider-owned handle returned at deferred submission time."""
Expand Down Expand Up @@ -180,6 +192,15 @@ async def validate_request(
...


@runtime_checkable
class ReadinessProvider(Protocol):
"""Optional provider hook for fast readiness probes."""

async def check_ready(self, *, model: str | None = None) -> ProviderReadiness:
"""Return provider readiness without raising for ordinary not-ready states."""
...


@runtime_checkable
class DeferredProvider(Protocol):
"""Lifecycle operations for provider-backed deferred delivery."""
Expand Down
Loading
Loading