Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class AgentConfig[T]:
output_schema: Any | None = None
model_settings: dict[str, Any] | None = None
input_guardrails: list[Any] | None = None
strict_json_schema: bool = True


@dataclass(frozen=True, slots=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def agent(
output_schema: Any | None = None,
max_turns: int | None = None,
branch_param: str | None = None,
strict_json_schema: bool = True,
) -> Any:
"""Method decorator that turns a method into an agent invocation.

Expand Down Expand Up @@ -122,6 +123,7 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> AgentRunResult[Any]:
tools=resolved_tools,
output_schema=output_schema,
model_settings=model_settings,
strict_json_schema=strict_json_schema,
)

run_config = AgentRunConfig(
Expand Down Expand Up @@ -161,6 +163,7 @@ def consensus_agent(
consensus_strategy: ConsensusStrategy,
judge: Callable[..., Any] | None = None,
temperature_spread: tuple[float, float] | None = None,
strict_json_schema: bool = True,
) -> Any:
"""Method decorator for consensus-based multi-run agent invocation."""

Expand Down Expand Up @@ -223,6 +226,7 @@ async def single_run(run_index: int) -> AgentRunResult[Any]:
tools=resolved_tools,
output_schema=output_schema,
model_settings=ms or None,
strict_json_schema=strict_json_schema,
)
run_config = AgentRunConfig(
input=mapped.input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ async def create_and_run(
instructions=agent_config.instructions,
model=model,
tools=agent_config.tools or [],
output_type=AgentOutputSchema(agent_config.output_schema, strict_json_schema=False)
output_type=AgentOutputSchema(
agent_config.output_schema,
strict_json_schema=agent_config.strict_json_schema,
)
if agent_config.output_schema
else None,
model_settings=ms or ModelSettings(),
Expand Down
34 changes: 34 additions & 0 deletions packages/agentic-workflows/tests/unit/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,40 @@ async def spy(config: AgentConfig[str], run_config: AgentRunConfig) -> AgentRunR
assert len(captured_configs) == 1
assert captured_configs[0].model == "test-model-override"

async def test_strict_json_schema_defaults_true_and_opt_out(self, prompts_dir: str) -> None:
(Path(prompts_dir) / "strict_default.md").write_text("Strict default")
(Path(prompts_dir) / "strict_off.md").write_text("Strict off")

@agentic_workflow(prompts_directory=prompts_dir)
class Wf:
def __init__(self, service: AiAgentServiceLocal) -> None:
self._ai_agent_service = service

@agent(output_schema=dict)
async def strict_default(self, input_text: str) -> AgentRunResult[dict]: ...

@agent(output_schema=dict, strict_json_schema=False)
async def strict_off(self, input_text: str) -> AgentRunResult[dict]: ...

service = AiAgentServiceLocal.get_instance()
captured: list[AgentConfig[dict]] = []
original = service.create_and_run

async def spy(
config: AgentConfig[dict], run_config: AgentRunConfig
) -> AgentRunResult[dict]:
captured.append(config)
return await original(config, run_config)

service.create_and_run = spy # type: ignore[assignment]

wf = Wf(service)
await wf.strict_default("x")
await wf.strict_off("x")

assert captured[0].strict_json_schema is True
assert captured[1].strict_json_schema is False


# ---------------------------------------------------------------------------
# @consensus_agent validation tests
Expand Down
154 changes: 154 additions & 0 deletions packages/agentic-workflows/tests/unit/test_service_openai_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Unit tests for AiAgentServiceOpenAICompat strict JSON + OpenRouter params."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, ClassVar
from unittest.mock import patch

import pytest
from agents.model_settings import ModelSettings
from zeroshot_agentic_workflows import AgentConfig, AgentRunConfig
from zeroshot_agentic_workflows import service_openai_compat as svc_mod


@dataclass
class _Schema:
answer: str


class _FakeAgentOutputSchema:
"""Recording fake for AgentOutputSchema(output_type, strict_json_schema=...)."""

instances: ClassVar[list[_FakeAgentOutputSchema]] = []

def __init__(self, output_type: Any, strict_json_schema: bool = True) -> None:
self.output_type = output_type
self.strict_json_schema = strict_json_schema
_FakeAgentOutputSchema.instances.append(self)


class _FakeAgent:
instances: ClassVar[list[_FakeAgent]] = []

def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs
_FakeAgent.instances.append(self)


class _FakeResult:
def __init__(self, output: Any) -> None:
self.final_output = output


class _FakeRunner:
@staticmethod
async def run(_agent: Any, **_kwargs: Any) -> _FakeResult:
return _FakeResult(output=_Schema(answer="ok"))


@pytest.fixture(autouse=True)
def _patch_sdk() -> Any:
_FakeAgentOutputSchema.instances.clear()
_FakeAgent.instances.clear()
with (
patch.object(svc_mod, "AgentOutputSchema", _FakeAgentOutputSchema),
patch.object(svc_mod, "Agent", _FakeAgent),
patch.object(svc_mod, "Runner", _FakeRunner),
patch.object(svc_mod, "OpenAIChatCompletionsModel", lambda **kw: ("model", kw)),
):
yield


def _make_service() -> svc_mod.AiAgentServiceOpenAICompat:
return svc_mod.AiAgentServiceOpenAICompat(
base_url="https://openrouter.ai/api/v1",
api_key="test-key",
default_model="openai/gpt-5",
)


class TestStrictJsonSchema:
async def test_defaults_to_strict_true(self) -> None:
service = _make_service()
config = AgentConfig[_Schema](
name="t",
instructions="hi",
output_schema=_Schema,
)

result = await service.create_and_run(config, AgentRunConfig(input="x"))

assert result.success is True
assert len(_FakeAgentOutputSchema.instances) == 1
assert _FakeAgentOutputSchema.instances[0].strict_json_schema is True
assert _FakeAgentOutputSchema.instances[0].output_type is _Schema

async def test_caller_can_opt_out(self) -> None:
service = _make_service()
config = AgentConfig[_Schema](
name="t",
instructions="hi",
output_schema=_Schema,
strict_json_schema=False,
)

result = await service.create_and_run(config, AgentRunConfig(input="x"))

assert result.success is True
assert len(_FakeAgentOutputSchema.instances) == 1
assert _FakeAgentOutputSchema.instances[0].strict_json_schema is False

async def test_no_output_schema_means_no_output_type(self) -> None:
service = _make_service()
config = AgentConfig[str](name="t", instructions="hi")

result = await service.create_and_run(config, AgentRunConfig(input="x"))

assert result.success is True
assert _FakeAgentOutputSchema.instances == []
assert _FakeAgent.instances[0].kwargs["output_type"] is None


class TestOpenRouterParams:
async def test_extra_body_flows_into_model_settings(self) -> None:
service = _make_service()
extra_body = {
"provider": {
"require_parameters": True,
"allow_fallbacks": False,
"order": ["OpenAI"],
},
"plugins": [{"id": "response-healing"}],
}
config = AgentConfig[_Schema](
name="t",
instructions="hi",
output_schema=_Schema,
model_settings={
"tool_choice": "none",
"extra_body": extra_body,
},
)

result = await service.create_and_run(config, AgentRunConfig(input="x"))

assert result.success is True
ms = _FakeAgent.instances[0].kwargs["model_settings"]
assert isinstance(ms, ModelSettings)
assert ms.extra_body == extra_body
assert ms.tool_choice == "none"

async def test_no_model_settings_yields_default_model_settings(self) -> None:
service = _make_service()
config = AgentConfig[_Schema](
name="t",
instructions="hi",
output_schema=_Schema,
)

await service.create_and_run(config, AgentRunConfig(input="x"))

ms = _FakeAgent.instances[0].kwargs["model_settings"]
assert isinstance(ms, ModelSettings)
assert ms.extra_body is None
16 changes: 8 additions & 8 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.