diff --git a/docs/guides/integrations.mdx b/docs/guides/integrations.mdx index efe008741..c0f4a9a07 100644 --- a/docs/guides/integrations.mdx +++ b/docs/guides/integrations.mdx @@ -73,6 +73,70 @@ async with hud.eval(eval) as ctx: await ctx.submit(msg.content or "") ``` +### Chat Completions (Single-Call Runner) + +If you want HUD to handle the chat tool loop for a scenario task, use +`hud.run_scenario_chat(...)`: + +```python +import os +from openai import AsyncOpenAI +import hud + +env = hud.Environment("trivia") +task = env("initials", company="HUD") + +client = AsyncOpenAI( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +result = await hud.run_scenario_chat( + client=client, + model="gpt-4o", + task=task, + api="chat_completions", # or "responses" / "auto" +) + +print(result.answer) +print(result.reward) +print(result.trace_id) +``` + +### Interactive Scenario Chat (Turn-by-Turn) + +Use `hud.run_scenario_chat_interactive(...)` when you want to send multiple +user turns before final evaluation: + +```python +import os +from openai import AsyncOpenAI +import hud + +env = hud.Environment("trivia") + +client = AsyncOpenAI( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.run_scenario_chat_interactive( + client=client, + model="gpt-4o", + env=env, + scenario="initials", + args={"company": "HUD"}, +) as chat: + first = await chat.send("Start with your initial investigation.") + follow_up = await chat.send("Now provide a concise final answer.") + result = await chat.finish() # submits + evaluates + +print(first.answer) +print(follow_up.answer) +print(result.reward) +print(result.trace_id) +``` + ### Responses API ```python @@ -111,6 +175,112 @@ Requires: `pip install openai-agents` --- +## Serve Scenarios as an HTTP Endpoint + +If you want external agents to run your scenarios without the HUD SDK, use +`env.serve_as_agent()`. 
It starts a local OpenAI-compatible server — any +OpenAI client in any language can connect. + +### Server (`04_scenario_server.py`) + +```python +import os +import hud +from openai import AsyncOpenAI + +env = hud.Environment(os.environ["HUD_ENV_NAME"]) +env.connect_hub(os.environ["HUD_ENV_NAME"]) + +env.serve_as_agent( + client=AsyncOpenAI( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"], + ), + model="gpt-4o", + port=8321, +) +``` + +The server exposes: + +| Endpoint | Purpose | +|---|---| +| `GET /scenarios` | List available scenarios and their required args | +| `GET /v1/lifecycle-tools` | List scenario lifecycle tool schemas | +| `POST /v1/lifecycle-tools/call` | Call lifecycle tools (`scenario_list/start/send/finish`) | +| `POST /v1/chat/completions` | Start or continue a session | +| `POST /v1/sessions/{id}/finish` | Submit and evaluate | +| `GET /v1/sessions` | List active sessions | +| `GET /mcp/tools` | MCP-native lifecycle tool list | +| `POST /mcp/tools/call` | MCP-native lifecycle tool execution | + +### Client (`05_scenario_client.py`) + +No HUD SDK needed. Use any standard OpenAI client: + +```python +import httpx +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed") + +# 1. Discover scenarios +scenarios = httpx.get("http://localhost:8321/scenarios").json()["scenarios"] +selected = scenarios[0] + +# 2. First turn — pass scenario name and args in the request body +# (both fields are required for session bootstrap) +first = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "Begin."}], + extra_body={ + "scenario": selected["short_name"], + "scenario_args": {"arg": "value"}, + }, +) +session_id = first.hud["session_id"] # returned in every response + +# 3. 
"""Interactive REPL for scenario chat with optional streaming.

Usage:
    HUD_API_KEY=... HUD_ENV_NAME=... python examples/03_scenario_chat.py
    HUD_API_KEY=... HUD_ENV_NAME=... python examples/03_scenario_chat.py --stream
"""

# Upper bound, in seconds, on one non-streaming turn (see ``send_message``).
TURN_TIMEOUT_SECONDS = 60


def _parser() -> argparse.ArgumentParser:
    """Build the command-line parser.

    ``--env`` and ``--model`` default to the ``HUD_ENV_NAME`` / ``HUD_MODEL``
    environment variables, read when the parser is built.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--stream", action="store_true", help="Enable SSE token streaming.")
    parser.add_argument(
        "--env",
        default=os.getenv("HUD_ENV_NAME"),
        help="HUD environment name (or set HUD_ENV_NAME).",
    )
    parser.add_argument(
        "--model",
        default=os.getenv("HUD_MODEL", "gpt-4o"),
        help="Model name for chat calls.",
    )
    return parser


def _scenario_index(choice: str, count: int) -> int:
    """Translate a 1-based menu entry into a 0-based index.

    Anything that is not a decimal number, or that falls outside
    ``1..count``, selects the first scenario (index 0).
    """
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscripts that int() cannot parse.
    if not choice.isdecimal():
        return 0
    idx = int(choice) - 1
    return idx if 0 <= idx < count else 0


async def main() -> None:
    """Pick a scenario interactively, chat turn by turn, then finish and print the reward.

    Raises:
        ValueError: if no environment name was given via ``--env`` or ``HUD_ENV_NAME``.
    """
    args = _parser().parse_args()
    if not args.env:
        raise ValueError("Provide --env or set HUD_ENV_NAME")

    client = AsyncOpenAI(
        base_url="https://inference.hud.ai",
        api_key=os.environ["HUD_API_KEY"],
    )
    env = hud.Environment(args.env)
    env.connect_hub(args.env)

    async with env:
        scenarios = await env.list_scenarios()
        if not scenarios:
            print("No scenarios found.")
            return

        print("Available scenarios:")
        for i, scenario in enumerate(scenarios, 1):
            req = ", ".join(scenario.required_args) or "(none)"
            print(f" [{i}] {scenario.short_name} - {scenario.description or 'no description'}")
            print(f" required args: {req}")
        print()

        choice = input("Pick a scenario (number, default 1): ").strip()
        chosen = scenarios[_scenario_index(choice, len(scenarios))]

        # Prompt for every declared argument; blank answers are left out.
        scenario_args: dict[str, str] = {}
        for arg in chosen.arguments:
            label = arg.name if arg.required else f"{arg.name} (optional)"
            value = input(f" {label}: ").strip()
            if value:
                scenario_args[arg.name] = value

        print(f"\nRunning: {chosen.short_name}")
        print(f"Streaming: {'on' if args.stream else 'off'}")
        print("Type /done when finished.\n")

        async with hud.run_scenario_chat_interactive(
            client=client,
            model=args.model,
            env=env,
            scenario=chosen.short_name,
            args=scenario_args,
            api="chat_completions",
        ) as chat:
            print(f"Trace: https://hud.ai/trace/{chat.trace_id}\n")

            async def send_message(msg: str) -> None:
                """Send one user turn and print the assistant's reply."""
                if args.stream:
                    print("Assistant: ", end="", flush=True)
                    async for event in chat.send_stream(msg):
                        if event.type == "text_delta":
                            print(event.content, end="", flush=True)
                    print("\n")
                    return
                turn = await asyncio.wait_for(chat.send(msg), timeout=TURN_TIMEOUT_SECONDS)
                print(f"Assistant: {turn.answer}\n")

            await send_message("Begin.")
            while True:
                try:
                    user_input = input("You: ").strip()
                except (EOFError, KeyboardInterrupt):
                    # Ctrl-D / Ctrl-C ends the conversation but still evaluates below.
                    print()
                    break
                if not user_input:
                    continue
                if user_input.lower() in {"/done", "/quit", "/exit"}:
                    break
                await send_message(user_input)

            result = await chat.finish()
            print("---")
            print(f"Reward: {result.reward}")
            print(f"Trace: https://hud.ai/trace/{result.trace_id}")


if __name__ == "__main__":
    asyncio.run(main())
# Route summary printed at startup as a quick reference for callers.
_ROUTES = (
    "GET /scenarios",
    "GET /v1/lifecycle-tools",
    "POST /v1/lifecycle-tools/call",
    "POST /v1/chat/completions (use X-HUD-Session-Id for follow-up turns)",
    "POST /v1/sessions/{id}/finish",
    "GET /v1/sessions",
    "GET /mcp/tools",
    "POST /mcp/tools/call",
)


def main() -> None:
    """Serve the environment named by ``HUD_ENV_NAME`` as an agent endpoint.

    Reads ``HUD_MODEL`` (default ``gpt-4o``), ``HUD_AGENT_PORT`` (default
    ``8321``) and ``HUD_API_KEY`` from the process environment, prints the
    route summary, then hands control to ``env.serve_as_agent``.

    Raises:
        ValueError: if ``HUD_ENV_NAME`` is unset or empty.
        KeyError: if ``HUD_API_KEY`` is unset.
    """
    env_name = os.getenv("HUD_ENV_NAME")
    if not env_name:
        raise ValueError("Set HUD_ENV_NAME to the target HUD environment")

    model = os.getenv("HUD_MODEL", "gpt-4o")
    port = int(os.getenv("HUD_AGENT_PORT", "8321"))

    client = AsyncOpenAI(
        base_url="https://inference.hud.ai",
        api_key=os.environ["HUD_API_KEY"],
    )
    env = hud.Environment(env_name)
    env.connect_hub(env_name)

    print(f"Serving {env_name} on http://localhost:{port}")
    for route in _ROUTES:
        print(route)
    print()

    env.serve_as_agent(client=client, model=model, port=port)


if __name__ == "__main__":
    main()
def _stream_turn(client: OpenAI, *, model: str, session_id: str, prompt: str) -> None:
    """Send one streamed follow-up turn and echo the answer as it arrives.

    The session is addressed through the ``X-HUD-Session-Id`` request header.
    """
    print(f"Prompt: {prompt}")
    print("Answer: ", end="", flush=True)
    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        extra_headers={"X-HUD-Session-Id": session_id},
    )
    for chunk in stream:
        # Skip chunks that carry no choices or no text delta.
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
    print("\n")


def _select_scenario(scenarios: list[dict], requested: str) -> dict:
    """Return the scenario whose ``short_name`` equals *requested*.

    Falls back to the first scenario when there is no match. The "not found"
    notice is printed only when a name was actually requested, so leaving
    ``HUD_SCENARIO`` unset no longer reports a missing scenario named ''.
    """
    for candidate in scenarios:
        if candidate["short_name"] == requested:
            return candidate
    fallback = scenarios[0]
    if requested:
        print(f"Scenario '{requested}' not found; falling back to '{fallback['short_name']}'")
    return fallback


def main() -> None:
    """Run a full external-agent session against a HUD agent server.

    Discovers scenarios, starts a session with placeholder arguments, sends
    one streamed and one non-streamed follow-up, then finishes the session
    and prints the reward and trace URL.

    Raises:
        ValueError: if the server lists no scenarios or returns no session id.
    """
    base_url = os.getenv("HUD_AGENT_URL", "http://localhost:8321")
    model = os.getenv("HUD_MODEL", "gpt-4o")
    scenario_name = os.getenv("HUD_SCENARIO", "")

    print("Discovering scenarios...")
    response = httpx.get(f"{base_url}/scenarios", timeout=120.0)
    response.raise_for_status()
    scenarios = response.json()["scenarios"]
    if not scenarios:
        raise ValueError("No scenarios found on server")

    for scenario in scenarios:
        required = ", ".join(scenario["required_args"]) or "(none)"
        print(f" {scenario['short_name']} - required: {required}")
    print()

    selected = _select_scenario(scenarios, scenario_name)

    # Placeholder values for common argument names; any other required
    # argument gets a generated "example-<name>" value.
    default_args = {
        "id": "example-id",
        "task": "Investigate the issue and summarize findings.",
        "question": "What are the root causes?",
        "input_text": "Example input for scenario setup.",
    }
    scenario_args = {
        arg_name: default_args.get(arg_name, f"example-{arg_name}")
        for arg_name in selected.get("required_args", [])
    }
    print(f"Using scenario: {selected['short_name']}")

    client = OpenAI(
        base_url=f"{base_url}/v1",
        api_key="not-needed",
        timeout=300.0,
    )

    # First turn: scenario name and args travel in the request body.
    first = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Begin and summarize the initial context."}],
        extra_body={
            "scenario": selected["short_name"],
            "scenario_args": scenario_args,
        },
    )
    first_answer = first.choices[0].message.content or ""
    hud_meta = getattr(first, "hud", None) or {}
    session_id = hud_meta.get("session_id", "") if isinstance(hud_meta, dict) else ""
    trace_url = hud_meta.get("trace_url", "") if isinstance(hud_meta, dict) else ""
    if not session_id:
        raise ValueError("No session_id returned from server")

    print(f"Session: {session_id}")
    print(f"Trace: {trace_url}")
    print(f"Answer: {first_answer[:200]}...\n")

    _stream_turn(client, model=model, session_id=session_id, prompt="What are the root causes?")
    # Follow-ups can also use thread_id/conversation_id in request body.
    second = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Give your top 3 recommendations."}],
        extra_body={"thread_id": session_id},
    )
    print(f"Answer: {(second.choices[0].message.content or '')[:200]}...\n")

    finish = httpx.post(f"{base_url}/v1/sessions/{session_id}/finish", timeout=120.0)
    finish.raise_for_status()
    result = finish.json()
    print(f"Reward: {result.get('reward')}")
    print(f"Trace URL: {result.get('trace_url')}")


if __name__ == "__main__":
    main()
ENV_NAME = os.environ["HUD_ENV_NAME"]
TRANSPORT = os.environ.get("MCP_TRANSPORT", "streamable-http")
HOST = os.environ.get("MCP_HOST", "0.0.0.0")
PORT = int(os.environ.get("MCP_PORT", "8765"))
# Idle seconds after which a session is reaped by the background sweep.
SESSION_TTL_SECONDS = int(os.environ.get("HUD_SCENARIO_SESSION_TTL", "1800"))


@dataclass
class _SessionEntry:
    """One live scenario session: the chat handle, its context manager, and idle clock."""

    chat: Any  # object yielded by the context manager's __aenter__
    cm: Any  # the run_scenario_chat_interactive(...) context manager itself
    last_active: float  # time.monotonic() of the last recorded activity

    def touch(self) -> None:
        """Reset the idle clock to now."""
        self.last_active = time.monotonic()

    @property
    def expired(self) -> bool:
        """True once the session has been idle longer than SESSION_TTL_SECONDS."""
        return (time.monotonic() - self.last_active) > SESSION_TTL_SECONDS


mcp = MCPServer("hud-scenario-chat")
env = hud.Environment(ENV_NAME)
env.connect_hub(ENV_NAME)
sessions: dict[str, _SessionEntry] = {}
cleanup_task: asyncio.Task[Any] | None = None


async def _close_entry(entry: _SessionEntry, *, finish: bool = True) -> None:
    """Best-effort teardown of a session that has already left ``sessions``.

    Optionally calls ``chat.finish()``, then exits the session's context
    manager. Errors from either step are suppressed so one broken session
    cannot block cleanup of the others.
    """
    if finish:
        with contextlib.suppress(Exception):
            await entry.chat.finish()
    with contextlib.suppress(Exception):
        await entry.cm.__aexit__(None, None, None)


async def _cleanup_expired_sessions() -> None:
    """Background loop: every 60 seconds, tear down sessions idle past the TTL."""
    while True:
        await asyncio.sleep(60)
        expired_ids = [sid for sid, entry in sessions.items() if entry.expired]
        for sid in expired_ids:
            # A tool call may have removed the session while we awaited.
            entry = sessions.pop(sid, None)
            if entry is not None:
                await _close_entry(entry)


async def _start_session(
    *,
    model: str,
    scenario: str,
    scenario_args: dict[str, Any],
    max_steps: int,
) -> tuple[str, _SessionEntry]:
    """Enter a new interactive scenario chat and register it under a fresh id.

    Returns:
        ``(session_id, entry)``; the entry is already stored in ``sessions``.
    """
    client = AsyncOpenAI(
        base_url=os.environ.get("HUD_INFERENCE_URL", "https://inference.hud.ai"),
        api_key=os.environ["HUD_API_KEY"],
    )
    cm = run_scenario_chat_interactive(
        client=client,
        model=model,
        env=env,
        scenario=scenario,
        args=scenario_args,
        max_steps=max_steps,
    )
    try:
        chat = await cm.__aenter__()
    except Exception:
        # Give the context manager a chance to release anything acquired
        # before the failure, then surface the original error.
        with contextlib.suppress(Exception):
            await cm.__aexit__(None, None, None)
        raise
    session_id = uuid.uuid4().hex[:16]
    entry = _SessionEntry(chat=chat, cm=cm, last_active=time.monotonic())
    sessions[session_id] = entry
    return session_id, entry


def _session_meta(session_id: str, trace_id: str) -> dict[str, str]:
    """Build the metadata dict returned to callers (id aliases plus trace info)."""
    return {
        "session_id": session_id,
        "thread_id": session_id,
        "conversation_id": session_id,
        "trace_id": trace_id,
        "trace_url": f"https://hud.ai/trace/{trace_id}",
    }


@mcp.initialize
async def _initialize() -> None:
    """Enter the environment and start the background expiry sweep."""
    global cleanup_task
    await env.__aenter__()
    cleanup_task = asyncio.create_task(_cleanup_expired_sessions())


@mcp.shutdown
async def _shutdown() -> None:
    """Stop the expiry sweep, tear down every live session, and exit the environment."""
    global cleanup_task
    if cleanup_task is not None:
        cleanup_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await cleanup_task
        cleanup_task = None

    # Snapshot the values: teardown awaits, and the dict must not change under iteration.
    for entry in list(sessions.values()):
        await _close_entry(entry)
    sessions.clear()

    await env.__aexit__(None, None, None)


@mcp.tool()
async def scenario_list() -> dict[str, Any]:
    """List available scenarios and argument metadata."""
    scenarios = await env.list_scenarios()
    return {
        "scenarios": [
            {
                "name": s.name,
                "short_name": s.short_name,
                "description": s.description,
                "required_args": s.required_args,
                "arguments": [
                    {
                        "name": a.name,
                        "type": a.type,
                        "required": a.required,
                        "description": a.description,
                        "default": a.default,
                    }
                    for a in s.arguments
                ],
            }
            for s in scenarios
        ]
    }


@mcp.tool()
async def scenario_start(
    scenario: str,
    scenario_args: dict[str, Any],
    message: str = "Begin.",
    model: str = "gpt-4o",
    max_steps: int = 20,
) -> dict[str, Any]:
    """Start a scenario session (required first-turn bootstrap)."""
    if max_steps < 1:
        raise ValueError("max_steps must be >= 1")
    session_id, entry = await _start_session(
        model=model,
        scenario=scenario,
        scenario_args=scenario_args,
        max_steps=max_steps,
    )
    first_turn = await entry.chat.send(message)
    return {
        "answer": first_turn.answer,
        "tool_calls": first_turn.tool_calls,
        "hud": _session_meta(session_id, entry.chat.trace_id),
    }


@mcp.tool()
async def scenario_send(session_id: str, message: str) -> dict[str, Any]:
    """Send a follow-up turn to an existing session."""
    entry = sessions.get(session_id)
    if entry is None:
        raise ValueError(f"Session not found: {session_id}")
    entry.touch()
    turn = await entry.chat.send(message)
    return {
        "answer": turn.answer,
        "tool_calls": turn.tool_calls,
        "hud": _session_meta(session_id, entry.chat.trace_id),
    }


@mcp.tool()
async def scenario_finish(session_id: str, answer: str | None = None) -> dict[str, Any]:
    """Finish a session and return reward + trace metadata."""
    entry = sessions.get(session_id)
    if entry is None:
        raise ValueError(f"Session not found: {session_id}")
    try:
        result = await entry.chat.finish(answer=answer)
    finally:
        sessions.pop(session_id, None)
        # Exit the context manager as the TTL sweep and shutdown paths do;
        # otherwise the context entered in _start_session is never exited.
        # NOTE(review): assumes __aexit__ after finish() is safe, as those
        # other paths already rely on it - confirm against hud.scenario_chat.
        await _close_entry(entry, finish=False)
    return {
        **_session_meta(session_id, result.trace_id),
        "answer": result.answer,
        "reward": result.reward,
    }


if __name__ == "__main__":
    if TRANSPORT == "stdio":
        mcp.run(transport="stdio")
    elif TRANSPORT in ("streamable-http", "sse", "http"):
        mcp.run(transport=TRANSPORT, host=HOST, port=PORT)
    else:
        raise ValueError(
            f"Unsupported MCP_TRANSPORT='{TRANSPORT}'. "
            "Use one of: stdio, streamable-http, sse, http."
        )
logger = logging.getLogger(__name__)

DEFAULT_SESSION_TTL_SECONDS = 30 * 60  # 30 minutes
SESSION_ID_HEADER = "x-hud-session-id"


def _hud_meta(*, session_id: str, trace_id: str) -> dict[str, Any]:
    """Build the ``hud`` metadata dict (session id aliases plus trace info)."""
    return {
        "session_id": session_id,
        "thread_id": session_id,
        "conversation_id": session_id,
        "trace_id": trace_id,
        "trace_url": f"https://hud.ai/trace/{trace_id}",
    }


class _SessionEntry:
    """One live session: chat handle, its context manager, and an idle clock."""

    __slots__ = ("chat", "cm", "last_active", "session_ttl")

    def __init__(self, chat: Any, cm: Any, session_ttl: int) -> None:
        self.chat = chat
        self.cm = cm
        self.session_ttl = session_ttl
        self.last_active = time.monotonic()

    def touch(self) -> None:
        """Reset the idle clock to now."""
        self.last_active = time.monotonic()

    @property
    def expired(self) -> bool:
        """True once the session has been idle longer than ``session_ttl`` seconds."""
        return (time.monotonic() - self.last_active) > self.session_ttl


async def _cleanup_expired(sessions: dict[str, _SessionEntry]) -> None:
    """Remove sessions that have been idle longer than TTL.

    Each expired session is popped from *sessions*, finished, and its context
    manager exited. Errors from either teardown step are logged at debug
    level and do not stop the sweep.
    """
    expired = [sid for sid, entry in sessions.items() if entry.expired]
    for sid in expired:
        # The awaits below yield to the event loop, so another task may have
        # removed this session already; skip it instead of raising KeyError.
        entry = sessions.pop(sid, None)
        if entry is None:
            continue
        logger.info("Session %s expired after %ds idle — cleaning up", sid, entry.session_ttl)
        try:
            await entry.chat.finish()
        except Exception:
            logger.debug("Error finishing expired session %s", sid, exc_info=True)
        try:
            await entry.cm.__aexit__(None, None, None)
        except Exception:
            logger.debug("Error closing expired session cm %s", sid, exc_info=True)


def serve_agent(
    env: Any,
    *,
    client: Any,
    model: str = "gpt-4o",
    host: str = "0.0.0.0",  # noqa: S104
    port: int = 8000,
    api_key: str | None = None,
    workers: int = 1,
    session_ttl: int = DEFAULT_SESSION_TTL_SECONDS,
) -> None:
    """Start an OpenAI-compatible HTTP server backed by HUD scenarios.

    Args:
        env: An :class:`~hud.environment.Environment` with ``connect_hub()`` called.
        client: An ``AsyncOpenAI`` (or compatible) client for LLM calls.
        model: Default model name for completions.
        host: Bind address.
        port: Bind port.
        api_key: Optional API key to require from callers.
        workers: Number of uvicorn workers. Only ``1`` is usable here: the
            app is passed to uvicorn as an object, and sessions are held
            in-process.
        session_ttl: Seconds of inactivity before a session is automatically
            cleaned up. Defaults to 30 minutes.

    Raises:
        ValueError: if ``session_ttl`` is not positive or ``workers`` exceeds 1.
    """
    # Validate before importing uvicorn so bad arguments fail fast and clearly.
    if session_ttl <= 0:
        raise ValueError("session_ttl must be >= 1")
    if workers > 1:
        # uvicorn.run() exits the process when given an app object (not an
        # import string) together with workers > 1; report it as an error.
        raise ValueError(
            "workers > 1 is not supported: the app is passed to uvicorn as an object "
            "and sessions are per-process"
        )

    import uvicorn

    app = _build_app(
        env=env,
        client=client,
        model=model,
        api_key=api_key,
        session_ttl=session_ttl,
    )
    uvicorn.run(app, host=host, port=port, workers=workers)
None, + thread_id: str | None = None, + conversation_id: str | None = None, + ) -> str | None: + return ( + raw_request.headers.get(SESSION_ID_HEADER) + or session_id + or thread_id + or conversation_id + ) + + def _hud_meta(*, session_id: str, trace_id: str) -> dict[str, Any]: + return { + "session_id": session_id, + "thread_id": session_id, + "conversation_id": session_id, + "trace_id": trace_id, + "trace_url": f"https://hud.ai/trace/{trace_id}", + } + + def _tool_schemas() -> list[dict[str, Any]]: + return [ + { + "name": "scenario_list", + "description": "List available scenarios and required args.", + "input_schema": {"type": "object", "properties": {}, "additionalProperties": False}, + }, + { + "name": "scenario_start", + "description": "Create a session by starting a scenario with args and an optional first message.", + "input_schema": { + "type": "object", + "properties": { + "scenario": {"type": "string"}, + "scenario_args": {"type": "object"}, + "message": {"type": "string"}, + "model": {"type": "string"}, + "max_steps": {"type": "integer", "minimum": 1, "maximum": 100}, + }, + "required": ["scenario", "scenario_args"], + "additionalProperties": True, + }, + }, + { + "name": "scenario_send", + "description": "Send a follow-up user message to an existing scenario session.", + "input_schema": { + "type": "object", + "properties": { + "session_id": {"type": "string"}, + "thread_id": {"type": "string"}, + "conversation_id": {"type": "string"}, + "message": {"type": "string"}, + }, + "required": ["message"], + "additionalProperties": False, + }, + }, + { + "name": "scenario_finish", + "description": "Finish an active session and return reward/trace metadata.", + "input_schema": { + "type": "object", + "properties": { + "session_id": {"type": "string"}, + "thread_id": {"type": "string"}, + "conversation_id": {"type": "string"}, + "answer": {"type": "string"}, + }, + "additionalProperties": False, + }, + }, + ] + + async def _start_session( + *, + scenario: str, 
+ scenario_args: dict[str, Any], + model_name: str, + max_steps: int, + ) -> tuple[str, _SessionEntry]: + from hud.scenario_chat import run_scenario_chat_interactive + + session_id = uuid.uuid4().hex[:16] + cm = run_scenario_chat_interactive( + client=client, + model=model_name, + env=env, + scenario=scenario, + args=scenario_args, + max_steps=max_steps, + ) + try: + chat = await cm.__aenter__() + except Exception: + with contextlib.suppress(Exception): + await cm.__aexit__(None, None, None) + raise + entry = _SessionEntry(chat=chat, cm=cm, session_ttl=session_ttl) + sessions[session_id] = entry + logger.info("Session %s created for scenario %s", session_id, scenario) + return session_id, entry + + async def _finish_session( + *, + session_id: str, + answer: str | None = None, + ) -> dict[str, Any]: + entry = sessions.get(session_id) + if entry is None: + raise HTTPException(404, f"Session {session_id} not found") + + try: + result = await entry.chat.finish(answer=answer) + finally: + # Always remove session from active registry; chat.finish handles cleanup. 
+ sessions.pop(session_id, None) + + return { + "session_id": session_id, + "thread_id": session_id, + "conversation_id": session_id, + "answer": result.answer, + "reward": result.reward, + "trace_id": result.trace_id, + "trace_url": f"https://hud.ai/trace/{result.trace_id}", + } + + async def _send_turn( + *, + session_id: str, + message: str, + ) -> dict[str, Any]: + entry = sessions.get(session_id) + if entry is None: + raise HTTPException(404, f"Session {session_id} not found") + entry.touch() + turn = await entry.chat.send(message) + return { + "session_id": session_id, + "thread_id": session_id, + "conversation_id": session_id, + "answer": turn.answer, + "trace_id": entry.chat.trace_id, + "trace_url": f"https://hud.ai/trace/{entry.chat.trace_id}", + "tool_calls": turn.tool_calls, + } + + @app.middleware("http") + async def _check_auth(request: Request, call_next: Any) -> Any: + if api_key and request.url.path not in ("/", "/health"): + auth = request.headers.get("authorization", "") + if not auth.startswith("Bearer ") or auth[7:] != api_key: + return JSONResponse( + status_code=401, + content={"error": "Invalid or missing API key"}, + ) + return await call_next(request) + + @app.get("/health") + async def health() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/scenarios") + async def list_scenarios_route() -> dict[str, Any]: + scenarios = await env.list_scenarios() + return { + "scenarios": [ + { + "name": s.name, + "short_name": s.short_name, + "description": s.description, + "required_args": s.required_args, + "arguments": [ + { + "name": a.name, + "type": a.type, + "required": a.required, + "description": a.description, + "default": a.default, + } + for a in s.arguments + ], + } + for s in scenarios + ] + } + + @app.get("/v1/models") + async def list_models() -> dict[str, Any]: + return { + "object": "list", + "data": [{"id": model, "object": "model", "owned_by": "hud"}], + } + + @app.get("/v1/lifecycle-tools") + async def lifecycle_tools() -> 
@app.post("/v1/lifecycle-tools/call")
async def lifecycle_tools_call(request: LifecycleToolCallRequest, raw: Request) -> dict[str, Any]:
    """Dispatch one lifecycle tool call by name.

    Supported tools are ``scenario_list``, ``scenario_start``,
    ``scenario_send`` and ``scenario_finish``.

    Args:
        request: Parsed body; ``name`` selects the tool and ``arguments``
            carries its inputs.
        raw: The underlying HTTP request, passed to ``_extract_session_id``
            together with the ids found in ``arguments``.

    Returns:
        The tool-specific payload: the scenario listing, an
        ``answer``/``tool_calls``/``hud`` dict for start and send, or the
        result of ``_finish_session`` for finish.

    Raises:
        HTTPException: 400 for an unknown tool or for missing or malformed
            arguments. Errors raised by the delegated helpers propagate.
    """
    tool_name = request.name
    args = request.arguments

    def _require_session_id(tool: str) -> str:
        # Shared by send/finish: resolve the session id or fail with 400.
        resolved = _extract_session_id(
            raw,
            session_id=args.get("session_id"),
            thread_id=args.get("thread_id"),
            conversation_id=args.get("conversation_id"),
        )
        if not resolved:
            raise HTTPException(400, f"{tool} requires session_id/thread_id/conversation_id")
        return resolved

    if tool_name == "scenario_list":
        return await list_scenarios_route()

    if tool_name == "scenario_start":
        scenario = args.get("scenario")
        scenario_args = args.get("scenario_args")
        if not isinstance(scenario, str) or not scenario:
            raise HTTPException(400, "scenario_start requires 'scenario' (string)")
        if not isinstance(scenario_args, dict):
            raise HTTPException(400, "scenario_start requires 'scenario_args' (object)")

        # An explicit JSON null is treated like an absent key; str(None)
        # would otherwise send the literal text "None" downstream.
        raw_message = args.get("message")
        message = "Begin." if raw_message is None else str(raw_message)

        raw_max_steps = args.get("max_steps")
        try:
            max_steps = 20 if raw_max_steps is None else int(raw_max_steps)
        except (TypeError, ValueError):
            # A bad client value is a client error, not a server crash.
            raise HTTPException(400, "scenario_start 'max_steps' must be an integer") from None

        model_name = str(args.get("model") or model)

        new_session_id, entry = await _start_session(
            scenario=scenario,
            scenario_args=scenario_args,
            model_name=model_name,
            max_steps=max_steps,
        )
        turn = await entry.chat.send(message)
        return {
            "answer": turn.answer,
            "tool_calls": turn.tool_calls,
            "hud": _hud_meta(session_id=new_session_id, trace_id=entry.chat.trace_id),
        }

    if tool_name == "scenario_send":
        session_id = _require_session_id("scenario_send")
        raw_message = args.get("message")
        message = "" if raw_message is None else str(raw_message)
        if not message:
            raise HTTPException(400, "scenario_send requires a non-empty 'message'")
        result = await _send_turn(session_id=session_id, message=message)
        return {
            "answer": result["answer"],
            "tool_calls": result["tool_calls"],
            "hud": _hud_meta(session_id=session_id, trace_id=result["trace_id"]),
        }

    if tool_name == "scenario_finish":
        session_id = _require_session_id("scenario_finish")
        answer = args.get("answer")
        if answer is not None and not isinstance(answer, str):
            raise HTTPException(400, "scenario_finish 'answer' must be a string when provided")
        return await _finish_session(session_id=session_id, answer=answer)

    raise HTTPException(400, f"Unknown lifecycle tool: {tool_name}")
+ + completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" + created = int(time.time()) + + if session_id and session_id in sessions: + entry = sessions[session_id] + entry.touch() + + if request.stream: + return StreamingResponse( + _stream_sse( + chat=entry.chat, + message=last_user_msg, + completion_id=completion_id, + created=created, + model_name=request.model, + session_id=session_id, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-HUD-Session-Id": session_id, + "X-HUD-Thread-Id": session_id, + "X-HUD-Trace-Id": entry.chat.trace_id, + }, + ) + + turn = await entry.chat.send(last_user_msg) + return _completion_response( + completion_id=completion_id, + created=created, + model_name=request.model, + content=turn.answer, + session_id=session_id, + trace_id=entry.chat.trace_id, + ) + if session_id: + raise HTTPException(404, f"Session {session_id} not found") + + if not request.scenario: + raise HTTPException(400, "scenario is required for the first turn") + if request.scenario_args is None: + raise HTTPException(400, "scenario_args is required for the first turn") + + session_id, entry = await _start_session( + scenario=request.scenario, + scenario_args=request.scenario_args, + model_name=request.model, + max_steps=request.max_steps, + ) + chat = entry.chat + + if request.stream: + return StreamingResponse( + _stream_sse( + chat=chat, + message=last_user_msg, + completion_id=completion_id, + created=created, + model_name=request.model, + session_id=session_id, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-HUD-Session-Id": session_id, + "X-HUD-Thread-Id": session_id, + "X-HUD-Trace-Id": chat.trace_id, + }, + ) + + turn = await chat.send(last_user_msg) + return _completion_response( + completion_id=completion_id, + created=created, + model_name=request.model, + content=turn.answer, + session_id=session_id, + trace_id=chat.trace_id, + ) + + @app.post("/v1/sessions/{session_id}/finish") + 
async def _stream_sse(
    *,
    chat: Any,
    message: str,
    completion_id: str,
    created: int,
    model_name: str,
    session_id: str,
) -> AsyncIterator[str]:
    """Stream one scenario chat turn as OpenAI-style SSE frames.

    Frames are emitted in this order: an opening chunk with empty content,
    one chunk per ``text_delta`` event from ``chat.send_stream(message)``
    (other event types are dropped), a chunk with ``finish_reason="stop"``,
    a chunk with no choices that carries the ``hud`` metadata, and finally
    the ``[DONE]`` marker.
    """
    # Fields shared by every JSON frame; insertion order fixes key order.
    envelope: dict[str, Any] = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model_name,
    }

    def _frame(payload: dict[str, Any]) -> str:
        return f"data: {json.dumps(payload)}\n\n"

    def _choice_frame(delta: dict[str, str], finish_reason: str | None = None) -> str:
        choice = {"index": 0, "delta": delta, "finish_reason": finish_reason}
        return _frame({**envelope, "choices": [choice]})

    yield _choice_frame({"content": ""})

    async for event in chat.send_stream(message):
        if event.type != "text_delta":
            continue
        yield _choice_frame({"content": event.content})

    yield _choice_frame({}, finish_reason="stop")

    # trace_id is read only after the stream has been fully consumed.
    hud_meta = {
        "session_id": session_id,
        "trace_id": chat.trace_id,
        "trace_url": f"https://hud.ai/trace/{chat.trace_id}",
    }
    yield _frame({**envelope, "choices": [], "hud": hud_meta})
    yield "data: [DONE]\n\n"
str, + content: str, + session_id: str, + trace_id: str, +) -> dict[str, Any]: + return { + "id": completion_id, + "object": "chat.completion", + "created": created, + "model": model_name, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "hud": { + **_hud_meta(session_id=session_id, trace_id=trace_id), + }, + } diff --git a/hud/environment/__init__.py b/hud/environment/__init__.py index 731f18d1c..331875404 100644 --- a/hud/environment/__init__.py +++ b/hud/environment/__init__.py @@ -29,7 +29,7 @@ from hud.environment.mock import MockMixin, generate_mock_value from hud.environment.router import ConflictResolution, MCPRouter, ToolRouter from hud.environment.scenarios import ScenarioHandle, ScenarioMixin, ScenarioSession -from hud.environment.types import EnvConfig +from hud.environment.types import EnvConfig, ScenarioArg, ScenarioInfo from hud.environment.utils import ToolFormat, format_result, parse_tool_call, parse_tool_calls __all__ = [ @@ -39,6 +39,8 @@ "Connector", "EnvConfig", "Environment", + "ScenarioArg", + "ScenarioInfo", "MCPRouter", "MockMixin", "ScenarioHandle", diff --git a/hud/environment/connection.py b/hud/environment/connection.py index 99fb50250..527b204e1 100644 --- a/hud/environment/connection.py +++ b/hud/environment/connection.py @@ -119,6 +119,16 @@ async def connect(self) -> None: 2. httpx auto-instrumentation can inject trace headers """ from fastmcp.client import Client as FastMCPClient + import uuid + + # Hub transports carry an Environment-Id header that pins requests to a pod. + # Refresh it on each connect so sessions don't reconnect to cleaned-up pods. 
async def list_scenarios(self) -> list[ScenarioInfo]:
    """Describe every scenario exposed through this environment's prompts.

    One :class:`ScenarioInfo` is built per prompt returned by
    :meth:`list_prompts`. Argument metadata is taken from the prompt's
    ``meta["arguments"]`` list when that is present; otherwise it falls back
    to the prompt's own ``arguments``, which are all typed as ``"string"``.
    Malformed metadata entries (non-dicts, or dicts without a name) are
    skipped.

    Returns:
        The scenario descriptions, in the order the prompts were listed.
    """
    infos: list[ScenarioInfo] = []
    for prompt in await self.list_prompts():
        # Drop everything up to the first ":"; names without one are kept whole.
        _, colon, tail = prompt.name.partition(":")
        short_name = tail if colon else prompt.name

        description = (prompt.description or "").removeprefix("[Setup] ")

        meta = getattr(prompt, "meta", None)
        declared = meta.get("arguments") if isinstance(meta, dict) else None

        arguments: list[ScenarioArg]
        if isinstance(declared, list):
            arguments = [
                ScenarioArg(
                    name=item["name"],
                    type=item.get("type", "string"),
                    required=item.get("required", True),
                    description=item.get("description"),
                    default=item.get("default"),
                )
                for item in declared
                if isinstance(item, dict) and item.get("name")
            ]
        elif prompt.arguments:
            arguments = [
                ScenarioArg(
                    name=prompt_arg.name,
                    type="string",
                    # An unspecified ``required`` flag counts as required.
                    required=True if prompt_arg.required is None else prompt_arg.required,
                    description=prompt_arg.description,
                )
                for prompt_arg in prompt.arguments
            ]
        else:
            arguments = []

        infos.append(
            ScenarioInfo(
                name=prompt.name,
                short_name=short_name,
                description=description or None,
                arguments=arguments,
            )
        )
    return infos
+ + External agents connect with a standard OpenAI client -- no HUD SDK + needed on the caller side:: + + # Server + env = Environment("my-env") + env.connect_hub("my-env") + env.serve_as_agent(client=AsyncOpenAI(...), port=8000) + + # External agent (any language, any framework) + client = OpenAI(base_url="http://host:8000") + client.chat.completions.create( + model="gpt-4o", + messages=[...], + extra_body={"scenario": "investigate", "scenario_args": {...}}, + ) + + Args: + client: ``AsyncOpenAI`` (or compatible) for LLM calls. + model: Default model identifier. + host: Bind address. + port: Bind port. + api_key: Optional bearer token to require from callers. + """ + from hud.agent_server import serve_agent + + serve_agent(self, client=client, model=model, host=host, port=port, api_key=api_key) + # ========================================================================= # Properties # ========================================================================= diff --git a/hud/environment/tests/test_connection.py b/hud/environment/tests/test_connection.py index 1d24b4090..75ce7d71c 100644 --- a/hud/environment/tests/test_connection.py +++ b/hud/environment/tests/test_connection.py @@ -140,6 +140,48 @@ async def test_connect_creates_client(self) -> None: # Client is now set assert connector.client is mock_client + def test_copy_preserves_environment_id_header(self) -> None: + """copy() does not rotate Environment-Id headers.""" + transport = MagicMock() + transport.headers = { + "Environment-Name": "browser", + "Environment-Id": "fixed-id", + } + connector = Connector( + transport=transport, + config=ConnectionConfig(), + name="hub", + connection_type=ConnectionType.REMOTE, + ) + + copied = connector.copy() + assert copied._transport.headers["Environment-Id"] == "fixed-id" + + @pytest.mark.asyncio + async def test_connect_rotates_environment_id_header(self) -> None: + """connect() rotates Environment-Id before creating the client.""" + transport = MagicMock() + 
transport.headers = { + "Environment-Name": "browser", + "Environment-Id": "old-id", + } + connector = Connector( + transport=transport, + config=ConnectionConfig(), + name="hub", + connection_type=ConnectionType.REMOTE, + ) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.is_connected = MagicMock(return_value=True) + + with patch("fastmcp.client.Client", return_value=mock_client): + await connector.connect() + + assert transport.headers["Environment-Id"] != "old-id" + assert connector.client is mock_client + @pytest.mark.asyncio async def test_disconnect_clears_client(self) -> None: """disconnect() closes client and clears state.""" diff --git a/hud/environment/tests/test_scenarios.py b/hud/environment/tests/test_scenarios.py index dc91dc33a..e50649751 100644 --- a/hud/environment/tests/test_scenarios.py +++ b/hud/environment/tests/test_scenarios.py @@ -1866,3 +1866,102 @@ async def headless(): assert prompt.meta is not None assert prompt.meta.get("exclude_tools") == ["browser_*"] assert prompt.meta.get("exclude_sources") == ["hub"] + + +class TestListScenarios: + """Tests for Environment.list_scenarios().""" + + @pytest.mark.asyncio + async def test_list_scenarios_basic(self) -> None: + """list_scenarios returns ScenarioInfo with name and args.""" + env = Environment("test-env") + + @env.scenario("analyze") + async def analyze(item_id: str, description: str = "unknown"): + yield f"Analyzing {item_id}" + yield 1.0 + + async with env: + scenarios = await env.list_scenarios() + + assert len(scenarios) == 1 + s = scenarios[0] + assert s.short_name == "analyze" + assert s.name == "test-env:analyze" + assert len(s.arguments) == 2 + + item_arg = next(a for a in s.arguments if a.name == "item_id") + assert item_arg.required is True + assert item_arg.type == "string" + + desc_arg = next(a for a in s.arguments if a.name == "description") + assert desc_arg.required is False + + @pytest.mark.asyncio + async def 
test_list_scenarios_required_args(self) -> None: + """required_args property returns only required argument names.""" + env = Environment("test-env") + + @env.scenario("checkout") + async def checkout(user_id: str, coupon: str = ""): + yield "Go" + yield 1.0 + + async with env: + scenarios = await env.list_scenarios() + + assert scenarios[0].required_args == ["user_id"] + + @pytest.mark.asyncio + async def test_list_scenarios_to_openai_tool(self) -> None: + """to_openai_tool produces valid OpenAI function tool definition.""" + env = Environment("test-env") + + @env.scenario("search") + async def search(query: str): + yield f"Searching {query}" + yield 1.0 + + async with env: + scenarios = await env.list_scenarios() + + tool_def = scenarios[0].to_openai_tool() + assert tool_def["type"] == "function" + assert tool_def["function"]["name"] == "search" + params = tool_def["function"]["parameters"] + assert "query" in params["properties"] + assert "query" in params["required"] + + @pytest.mark.asyncio + async def test_list_scenarios_empty(self) -> None: + """list_scenarios returns empty list when no scenarios registered.""" + env = Environment("test-env") + + async with env: + scenarios = await env.list_scenarios() + + assert scenarios == [] + + @pytest.mark.asyncio + async def test_list_scenarios_skips_malformed_meta_arguments(self) -> None: + """list_scenarios should skip malformed metadata argument entries.""" + env = Environment("test-env") + + @env.scenario("demo") + async def demo_scenario(valid: str): + yield "Prompt" + yield 1.0 + + prompt = env._prompt_manager._prompts["test-env:demo"] + assert prompt.meta is not None + prompt.meta["arguments"] = [ + {"name": "valid", "type": "string", "required": True}, + {"type": "string", "required": True}, + "bad-entry", + ] + + async with env: + scenarios = await env.list_scenarios() + + assert len(scenarios) == 1 + assert [arg.name for arg in scenarios[0].arguments] == ["valid"] diff --git a/hud/environment/types.py 
@dataclass
class ScenarioArg:
    """Describes one argument accepted by a scenario."""

    name: str
    type: str = "string"
    required: bool = True
    description: str | None = None
    default: Any = None


@dataclass
class ScenarioInfo:
    """Describes a scenario: its names, description and argument list."""

    name: str
    short_name: str
    description: str | None = None
    arguments: list[ScenarioArg] = field(default_factory=list)

    @property
    def required_args(self) -> list[str]:
        """Names of the arguments flagged as required, in declaration order."""
        return [arg.name for arg in self.arguments if arg.required]

    def to_openai_tool(self) -> dict[str, Any]:
        """Render this scenario as an OpenAI function-tool definition.

        Each argument becomes a property carrying its ``type`` and, when it
        is non-empty, its ``description``. The function is named after
        ``short_name``, and a generic description is used when the scenario
        has none.
        """

        def _schema(arg: ScenarioArg) -> dict[str, Any]:
            schema: dict[str, Any] = {"type": arg.type}
            if arg.description:
                schema["description"] = arg.description
            return schema

        parameters = {
            "type": "object",
            "properties": {arg.name: _schema(arg) for arg in self.arguments},
            "required": self.required_args,
        }
        summary = self.description or f"Run scenario {self.short_name}"
        return {
            "type": "function",
            "function": {
                "name": self.short_name,
                "description": summary,
                "parameters": parameters,
            },
        }
b/hud/eval/context.py @@ -13,12 +13,14 @@ import logging import uuid from contextlib import contextmanager +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Self from hud.environment import Environment from hud.settings import settings from hud.shared import make_request from hud.telemetry import flush, instrument +from hud.telemetry.exporter import queue_span if TYPE_CHECKING: from collections.abc import Generator @@ -85,6 +87,27 @@ def get_current_api_key() -> str | None: return _current_api_key.get() +def _now_iso() -> str: + return datetime.now(UTC).isoformat().replace("+00:00", "Z") + + +def _normalize_trace_id(trace_id: str) -> str: + return trace_id.replace("-", "")[:32].ljust(32, "0") + + +def _serialize_span_value(value: Any) -> Any: + if value is None or isinstance(value, dict | list | str | int | float | bool): + return value + if hasattr(value, "model_dump"): + try: + dumped = value.model_dump(mode="json", exclude_none=True) + if isinstance(dumped, dict | list | str | int | float | bool): + return dumped + except Exception: + pass + return str(value) + + # ============================================================================= # EvalContext # ============================================================================= @@ -700,12 +723,14 @@ async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolRe """Execute a tool with automatic telemetry recording. Overrides Environment._execute_tool to record MCP spans for the eval context. - Instrumentation is disabled when connected to a remote HUD server (telemetry is - recorded server-side in that case). + Hub-connected envs emit canonical spans here (remote server doesn't emit them). + Local envs use @instrument. V4 HUD server envs skip (server-side telemetry). 
async def _execute_tool_hub(
    self, name: str, arguments: dict[str, Any]
) -> MCPToolResult:
    """Run a tool on a hub-connected environment and emit an ``mcp.tool_call`` span.

    The call is delegated to the parent class's ``_execute_tool``. When
    tracing is enabled and a trace id is set, a span describing the call is
    queued afterwards, whether the call succeeded or raised.

    Args:
        name: Tool name.
        arguments: Tool arguments, recorded verbatim in the span request.

    Returns:
        The result of the parent ``_execute_tool`` call.

    Raises:
        Exception: Whatever the parent call raises; it is recorded on the
            span and re-raised unchanged.
    """
    start_time = _now_iso()
    result: MCPToolResult | None = None
    error_message: str | None = None

    try:
        result = await super()._execute_tool(name, arguments)
        return result
    except Exception as e:
        error_message = f"{type(e).__name__}: {e}"
        raise
    finally:
        if self._trace_enabled and self.trace_id:
            # An exception escaping a ``finally`` block replaces the pending
            # return value or exception, so telemetry failures are logged
            # here rather than allowed to change the tool call's outcome.
            try:
                end_time = _now_iso()
                span: dict[str, Any] = {
                    "name": "mcp.tool_call",
                    "trace_id": _normalize_trace_id(self.trace_id),
                    "span_id": uuid.uuid4().hex[:16],
                    "parent_span_id": None,
                    "start_time": start_time,
                    "end_time": end_time,
                    "status_code": "ERROR" if error_message else "OK",
                    "status_message": error_message,
                    "internal_type": "mcp-tool",
                    "attributes": {
                        "task_run_id": self.trace_id,
                        "job_id": self.job_id,
                        "category": "mcp",
                        "type": "CLIENT",
                        # Use tools/call envelope for compatibility with existing trace UI
                        # while preserving canonical span name/internal_type.
                        "request": {
                            "method": "tools/call",
                            "name": name,
                            "params": {"name": name, "arguments": arguments},
                        },
                        "result": _serialize_span_value(result),
                        "start_timestamp": start_time,
                        "end_timestamp": end_time,
                    },
                    "exceptions": [{"message": error_message}] if error_message else None,
                }
                queue_span(span)
            except Exception:
                logging.getLogger(__name__).debug(
                    "Failed to emit mcp.tool_call span for tool %s", name, exc_info=True
                )
@dataclass
class ScenarioChatResult:
    """Result returned by :func:`run_scenario_chat`."""

    # Final answer that was submitted for evaluation.
    answer: str
    # Reward read from the eval context after submission.
    reward: float | None
    # Evaluation result read from the eval context after submission.
    evaluation_result: EvaluationResult | None
    # Conversation messages accumulated by the session.
    messages: list[dict[str, Any]] = field(default_factory=list)
    # Tool-call records (id, name, arguments) accumulated by the session.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Trace id of the eval context; empty string when unset.
    trace_id: str = ""


@dataclass
class ScenarioChatTurnResult:
    """Result of a single interactive turn."""

    # Assistant reply text for this turn.
    answer: str
    # Tool-call records made while producing this turn's reply.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class ChatEvent:
    """A single event yielded by :meth:`ScenarioChatSession.send_stream`.

    Event types:

    - ``text_delta``: Incremental text token from the model.
    - ``tool_call``: A tool invocation (name + args resolved after streaming).
    - ``tool_result``: The result returned by the tool.
    - ``turn_complete``: Signals the turn is done; ``content`` holds the full answer.
    """

    # Discriminator selecting one of the four event kinds listed above.
    type: Literal["text_delta", "tool_call", "tool_result", "turn_complete"]
    # Text payload; its meaning depends on ``type``.
    content: str = ""
    # Tool name, set on tool-related events.
    tool_name: str | None = None
    # Parsed tool arguments when they form a dict, otherwise None.
    tool_args: dict[str, Any] | None = None
    # Identifier of the tool call this event refers to.
    tool_call_id: str | None = None
async def __aenter__(self) -> "ScenarioChatSession":
    """Enter the eval context and seed the conversation with the scenario prompt.

    The optional system prompt (the explicit one, else the one supplied by
    the eval context) is added first, followed by the scenario prompt as a
    user message.

    Returns:
        This session, marked as entered.

    Raises:
        ValueError: If the scenario produced an empty prompt. The eval
            context is exited before the error propagates.
    """
    self._eval_cm = run_eval(self.task, trace=self.trace, api_key=self.api_key)
    self._ctx = await self._eval_cm.__aenter__()
    try:
        self.trace_id = self._ctx.trace_id

        if not self._ctx.prompt:
            raise ValueError(
                f"Scenario '{self.task.scenario}' returned an empty prompt. "
                "Ensure setup yields a non-empty instruction."
            )
        self._scenario_prompt = cast("str", self._ctx.prompt)

        self.system_prompt = self.system_prompt or self._ctx.system_prompt
        if self.system_prompt:
            self.messages.append({"role": "system", "content": self.system_prompt})
        self.messages.append({"role": "user", "content": self._scenario_prompt})
    except BaseException as exc:
        # When __aenter__ raises, ``async with`` never calls __aexit__, so the
        # already-entered eval context must be closed here. The handle is
        # cleared first so a later manual __aexit__ cannot exit it twice.
        eval_cm, self._eval_cm = self._eval_cm, None
        await eval_cm.__aexit__(type(exc), exc, exc.__traceback__)
        raise
    self._entered = True
    return self
def _merged_extra_headers(self) -> dict[str, str]:
    """Build the extra HTTP headers for a model request.

    Caller-supplied ``extra_headers`` from ``completion_kwargs`` are copied
    with keys and values stringified (and ignored unless they form a dict),
    then ``Trace-Id`` is set from the session's trace id when there is one.
    """
    supplied = self.completion_kwargs.get("extra_headers")
    merged: dict[str, str] = (
        {str(key): str(value) for key, value in supplied.items()}
        if isinstance(supplied, dict)
        else {}
    )
    if self.trace_id:
        merged["Trace-Id"] = self.trace_id
    return merged


def _chat_request_kwargs(
    self, *, messages: list[dict[str, Any]], stream: bool = False
) -> dict[str, Any]:
    """Assemble keyword arguments for a chat-completions request.

    Starts from a copy of ``completion_kwargs`` without ``extra_headers``,
    then overrides ``model``, ``messages`` and ``tools``. ``stream`` is added
    only when requested, and ``extra_headers`` only when the merged headers
    are non-empty.
    """
    kwargs: dict[str, Any] = {
        key: value for key, value in self.completion_kwargs.items() if key != "extra_headers"
    }
    kwargs["model"] = self.model
    kwargs["messages"] = messages
    kwargs["tools"] = self._ctx.as_openai_chat_tools()
    if stream:
        kwargs["stream"] = True

    headers = self._merged_extra_headers()
    if headers:
        kwargs["extra_headers"] = headers
    return kwargs
async def send_stream(self, message: str) -> AsyncIterator[ChatEvent]:
    """Send one user message and yield :class:`ChatEvent` objects for the turn.

    With the chat-completions API, events are relayed as they are produced
    by the streaming tool loop. With the ``responses`` API there is no
    incremental streaming: the turn runs to completion and is then replayed
    as one ``text_delta`` (only if the answer is non-empty), one
    ``tool_call`` per recorded tool call, and a closing ``turn_complete``.

    Raises:
        RuntimeError: If the session has not been entered or is already
            finished.
    """
    if not self._entered:
        raise RuntimeError("Session is not active. Use 'async with' first.")
    if self._finished:
        raise RuntimeError("Session already finished.")

    if self.api != "responses":
        async for event in self._stream_chat_completions(message):
            yield event
        return

    # Non-streaming fallback: replay the completed turn as events.
    turn = await self._send_responses(message)
    if turn.answer:
        yield ChatEvent(type="text_delta", content=turn.answer)
    for call in turn.tool_calls:
        call_args = call.get("arguments")
        yield ChatEvent(
            type="tool_call",
            tool_name=call.get("name"),
            tool_args=call_args if isinstance(call_args, dict) else None,
            tool_call_id=call.get("id"),
        )
    yield ChatEvent(type="turn_complete", content=turn.answer)
ChatEvent(type="text_delta", content=delta.content) + + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in pending_tool_calls: + pending_tool_calls[idx] = { + "id": "", + "name": "", + "arguments": "", + } + entry = pending_tool_calls[idx] + if tc_delta.id: + entry["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + entry["name"] += tc_delta.function.name + if tc_delta.function.arguments: + entry["arguments"] += tc_delta.function.arguments + + content = "".join(content_parts) + assistant_msg: dict[str, Any] = {"role": "assistant", "content": content} + + if pending_tool_calls: + assistant_msg["tool_calls"] = [ + { + "id": pending_tool_calls[idx]["id"], + "type": "function", + "function": { + "name": pending_tool_calls[idx]["name"], + "arguments": pending_tool_calls[idx]["arguments"], + }, + } + for idx in sorted(pending_tool_calls) + ] + + full_messages.append(assistant_msg) + self.messages.append(assistant_msg) + + if not pending_tool_calls: + self.last_answer = content or self.last_answer + yield ChatEvent(type="turn_complete", content=content) + return + + for idx in sorted(pending_tool_calls): + tc = pending_tool_calls[idx] + parsed_args = _parse_tool_arguments(tc["arguments"]) + record = {"id": tc["id"], "name": tc["name"], "arguments": parsed_args} + self.tool_calls.append(record) + + yield ChatEvent( + type="tool_call", + tool_name=tc["name"], + tool_args=parsed_args if isinstance(parsed_args, dict) else None, + tool_call_id=tc["id"], + ) + + openai_tc = { + "id": tc["id"], + "type": "function", + "function": {"name": tc["name"], "arguments": tc["arguments"]}, + } + tool_result = await self._ctx.call_tool(openai_tc) + normalized = _normalize_message(tool_result) + full_messages.append(normalized) + self.messages.append(normalized) + + result_content = normalized.get("content", str(tool_result)) + yield ChatEvent( + type="tool_result", + content=result_content, + tool_name=tc["name"], + 
tool_call_id=tc["id"], + ) + + yield ChatEvent(type="turn_complete", content=self.last_answer) + + async def _send_responses(self, message: str) -> ScenarioChatTurnResult: + from hud.telemetry.exporter import queue_span + + self.messages.append({"role": "user", "content": message}) + response_input: Any = message + tool_calls_this_turn: list[dict[str, Any]] = [] + if not self._previous_response_id: + response_input = [ + {"role": "user", "content": self._scenario_prompt}, + {"role": "user", "content": message}, + ] + + if self.trace and self.trace_id: + queue_span(_make_user_message_span(self.trace_id, message)) + + for _ in range(self.max_steps): + request_kwargs: dict[str, Any] = dict(self.completion_kwargs) + request_kwargs.pop("extra_headers", None) + request_kwargs.update( + { + "model": self.model, + "input": response_input, + "tools": self._ctx.as_openai_responses_tools(), + } + ) + if self.system_prompt: + request_kwargs.setdefault("instructions", self.system_prompt) + if self._previous_response_id: + request_kwargs["previous_response_id"] = self._previous_response_id + + merged_headers = self._merged_extra_headers() + if merged_headers: + request_kwargs["extra_headers"] = merged_headers + response = await self.client.responses.create(**request_kwargs) + self._previous_response_id = cast("str | None", getattr(response, "id", None)) + output_text = cast("str", getattr(response, "output_text", "") or "") + self.messages.append( + { + "role": "assistant", + "content": output_text, + "response_id": self._previous_response_id, + } + ) + + function_calls = [ + item + for item in (getattr(response, "output", []) or []) + if getattr(item, "type", None) == "function_call" + ] + if not function_calls: + self.last_answer = output_text or self.last_answer + return ScenarioChatTurnResult(answer=output_text, tool_calls=tool_calls_this_turn) + + tool_outputs: list[Any] = [] + for function_call in function_calls: + record = { + "id": getattr(function_call, "id", None), + 
def to_state(self) -> dict[str, Any]:
    """Serialize session state for persistence across HTTP requests.

    Returns a JSON-serializable dict that can be stored and later
    passed to :meth:`from_state` to restore the session.

    The mutable containers (``messages``, ``tool_calls`` and
    ``completion_kwargs``) are shallow-copied so that turns sent after the
    snapshot is taken do not silently alter an already-captured state.
    """
    return {
        "messages": list(self.messages),
        "tool_calls": list(self.tool_calls),
        "last_answer": self.last_answer,
        "trace_id": self.trace_id,
        "model": self.model,
        "api": self.api,
        "max_steps": self.max_steps,
        "system_prompt": self.system_prompt,
        "completion_kwargs": dict(self.completion_kwargs),
        "scenario_prompt": self._scenario_prompt,
        "previous_response_id": self._previous_response_id,
    }
def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any] | str:
    """Decode tool-call arguments into a dict when possible.

    A string is parsed as JSON and returned only if it decodes to a dict.
    In every other case (non-string input, invalid JSON, or JSON that is not
    an object) the input is returned unchanged.
    """
    if not isinstance(raw_arguments, str):
        # Non-string values are handed back untouched, dict or not.
        return raw_arguments

    try:
        decoded = json.loads(raw_arguments)
    except (json.JSONDecodeError, TypeError):
        return raw_arguments

    if isinstance(decoded, dict):
        return decoded
    return raw_arguments
def run_scenario_chat_interactive(
    *,
    client: AsyncOpenAI,
    model: str,
    task: Task | None = None,
    env: Environment | None = None,
    scenario: str | None = None,
    args: dict[str, Any] | None = None,
    api: ScenarioChatApi = "auto",
    max_steps: int = 20,
    system_prompt: str | None = None,
    trace: bool = True,
    api_key: str | None = None,
    completion_kwargs: dict[str, Any] | None = None,
) -> ScenarioChatSession:
    """Validate the options and build an interactive scenario chat session.

    Example:
        ```python
        async with hud.run_scenario_chat_interactive(...) as chat:
            await chat.send("first turn")
            await chat.send("follow-up")
            result = await chat.finish()
        ```

    Raises:
        ValueError: If ``max_steps`` is not positive, ``api`` is not a
            supported value, or the task cannot be resolved.
    """
    supported_apis = ("auto", "chat_completions", "responses")

    if max_steps <= 0:
        raise ValueError("max_steps must be >= 1")
    if api not in supported_apis:
        raise ValueError("api must be one of: auto, chat_completions, responses")

    # Accept either a ready-made task or the (env, scenario, args) triple.
    session_options: dict[str, Any] = {
        "client": client,
        "model": model,
        "task": _resolve_task(task=task, env=env, scenario=scenario, args=args),
        "api": api,
        "max_steps": max_steps,
        "system_prompt": system_prompt,
        "trace": trace,
        "api_key": api_key,
        "completion_kwargs": completion_kwargs,
    }
    return ScenarioChatSession(**session_options)
class _FakeEnv:
    """Stand-in environment that records context-manager entry and exit."""

    def __init__(self) -> None:
        self.entered = False
        self.exited = False

    async def __aenter__(self) -> _FakeEnv:
        self.entered = True
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb  # exception details are irrelevant to the fake
        self.exited = True

    async def list_scenarios(self) -> list[Any]:
        """Return a single hard-coded scenario with one required argument."""
        return [
            SimpleNamespace(
                name="demo:investigate",
                short_name="investigate",
                description="Investigate an issue",
                required_args=["ticket_id"],
                arguments=[
                    SimpleNamespace(
                        name="ticket_id",
                        type="string",
                        required=True,
                        description="Ticket identifier",
                        default=None,
                    )
                ],
            )
        ]
def _fake_runner_factory(calls: list[dict[str, Any]]) -> Any:
    """Return a fake interactive runner that logs each invocation into ``calls``.

    The returned async context manager records the model, scenario, args and
    max_steps it was called with, then yields a fresh ``_FakeChat``.
    """

    @asynccontextmanager
    async def _runner(
        *,
        client: Any,
        model: str,
        env: Any,
        scenario: str,
        args: dict[str, Any],
        max_steps: int,
    ) -> AsyncIterator[_FakeChat]:
        del client, env  # accepted only to match the runner's signature
        calls.append(dict(model=model, scenario=scenario, args=args, max_steps=max_steps))
        yield _FakeChat()

    return _runner
def test_lifecycle_tool_calls_and_mcp_surface(monkeypatch: Any) -> None:
    """Run a start/send/finish cycle across the REST and MCP tool surfaces."""
    http, _, _ = _make_client(monkeypatch)
    required_tools = {"scenario_list", "scenario_start", "scenario_send", "scenario_finish"}

    with http:
        # Both listings must advertise the same lifecycle tools.
        rest_listing = http.get("/v1/lifecycle-tools")
        assert rest_listing.status_code == 200
        names = {tool["name"] for tool in rest_listing.json()["tools"]}
        assert required_tools <= names

        mcp_listing = http.get("/mcp/tools")
        assert mcp_listing.status_code == 200
        mcp_names = {tool["name"] for tool in mcp_listing.json()["tools"]}
        assert names == mcp_names

        # Start through the REST surface...
        start_request = {
            "name": "scenario_start",
            "arguments": {
                "scenario": "investigate",
                "scenario_args": {"ticket_id": "T-3"},
                "message": "Begin with context",
            },
        }
        started = http.post("/v1/lifecycle-tools/call", json=start_request)
        assert started.status_code == 200
        session_id = started.json()["hud"]["session_id"]

        # ...then continue and finish through the MCP surface, using the
        # thread_id / conversation_id aliases for the session id.
        send_request = {
            "name": "scenario_send",
            "arguments": {"thread_id": session_id, "message": "Follow up"},
        }
        sent = http.post("/mcp/tools/call", json=send_request)
        assert sent.status_code == 200
        assert sent.json()["answer"] == "echo:Follow up"

        finish_request = {
            "name": "scenario_finish",
            "arguments": {"conversation_id": session_id, "answer": "Final answer"},
        }
        finished = http.post("/mcp/tools/call", json=finish_request)
        assert finished.status_code == 200
        finish_body = finished.json()
        assert finish_body["answer"] == "Final answer"
        assert finish_body["reward"] == 0.75
client, _, env = _make_client(monkeypatch) + + with client: + first = client.post( + "/v1/chat/completions", + json={ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Begin"}], + "scenario": "investigate", + "scenario_args": {"ticket_id": "T-4"}, + }, + ) + session_id = first.json()["hud"]["session_id"] + + sessions = client.get("/v1/sessions") + assert sessions.status_code == 200 + assert sessions.json()["sessions"][0]["thread_id"] == session_id + assert sessions.json()["sessions"][0]["conversation_id"] == session_id + + done = client.post(f"/v1/sessions/{session_id}/finish") + assert done.status_code == 200 + assert done.json()["session_id"] == session_id + assert done.json()["thread_id"] == session_id + assert done.json()["conversation_id"] == session_id + + assert env.entered is True + assert env.exited is True diff --git a/hud/tests/test_init_module.py b/hud/tests/test_init_module.py index 607dbfae3..0b123adf3 100644 --- a/hud/tests/test_init_module.py +++ b/hud/tests/test_init_module.py @@ -21,10 +21,18 @@ def test_all_exports(self): import hud expected = [ + "ChatEvent", "Environment", "EvalContext", + "ScenarioArg", + "ScenarioChatResult", + "ScenarioChatSession", + "ScenarioChatTurnResult", + "ScenarioInfo", "eval", "instrument", + "run_scenario_chat", + "run_scenario_chat_interactive", "trace", # Deprecated alias for eval ] diff --git a/hud/tests/test_scenario_chat.py b/hud/tests/test_scenario_chat.py new file mode 100644 index 000000000..dc1fb36ce --- /dev/null +++ b/hud/tests/test_scenario_chat.py @@ -0,0 +1,494 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from types import SimpleNamespace +from typing import Any + +import pytest + +from hud.scenario_chat import ChatEvent, ScenarioChatSession, run_scenario_chat_interactive +from hud.tools.types import EvaluationResult + + +class FakeCtx: + def __init__(self) -> None: + self.prompt = 
class _AsyncChunkIter:
    """Async iterator that replays a fixed list of chunks (fake OpenAI stream)."""

    def __init__(self, chunks: list[Any]) -> None:
        self._chunks = chunks
        self._idx = 0

    def __aiter__(self) -> _AsyncChunkIter:
        return self

    async def __anext__(self) -> Any:
        position = self._idx
        if position < len(self._chunks):
            self._idx = position + 1
            return self._chunks[position]
        # Every chunk has been handed out.
        raise StopAsyncIteration
def _build_stream_response(
    *, content: str = "", tool_calls: list[dict[str, Any]] | None = None
) -> _AsyncChunkIter:
    """Build a mock streaming response (async iterable of chunks).

    Text is emitted one character per chunk. Each tool call is emitted as a
    header chunk carrying its id and name, followed by its JSON-encoded
    arguments one character per chunk.
    """
    chunks: list[Any] = [_make_stream_chunk(content=char) for char in content or ""]

    for position, call in enumerate(tool_calls or []):
        header = SimpleNamespace(
            index=position,
            id=call["id"],
            function=SimpleNamespace(name=call["name"], arguments=""),
        )
        chunks.append(_make_stream_chunk(tool_calls=[header]))
        # Argument fragments repeat the index but carry no id or name.
        chunks.extend(
            _make_stream_chunk(
                tool_calls=[
                    SimpleNamespace(
                        index=position,
                        id=None,
                        function=SimpleNamespace(name=None, arguments=char),
                    )
                ]
            )
            for char in json.dumps(call["arguments"])
        )

    return _AsyncChunkIter(chunks)
test_run_scenario_chat_interactive_chat_completions( + monkeypatch: pytest.MonkeyPatch, +) -> None: + holder: dict[str, Any] = {} + monkeypatch.setattr( + "hud.scenario_chat.run_eval", _fake_run_eval_factory(holder, reward=0.9, content="done") + ) + tool_call = SimpleNamespace( + id="call_1", + function=SimpleNamespace(name="lookup", arguments='{"query":"x"}'), + ) + client = _make_client( + chat=[ + _chat_response(content="", tool_calls=[tool_call]), + _chat_response(content="Analysis complete."), + _chat_response(content="Root cause identified."), + ] + ) + + async with run_scenario_chat_interactive( + client=client, + model="gpt-4o", + task=SimpleNamespace(scenario="demo"), + api="chat_completions", + ) as chat: + first = await chat.send("Begin") + second = await chat.send("Give me the root cause") + result = await chat.finish() + + assert first.answer == "Analysis complete." + assert second.answer == "Root cause identified." + assert result.answer == "Root cause identified." + assert result.reward == 0.9 + assert holder["ctx"].submitted == "Root cause identified." + # Scenario setup prompt should be injected before first user turn. 
@pytest.mark.asyncio
async def test_run_scenario_chat_interactive_responses(monkeypatch: pytest.MonkeyPatch) -> None:
    """A Responses-API turn runs its tool call and surfaces the final text."""
    holder: dict[str, Any] = {}
    monkeypatch.setattr(
        "hud.scenario_chat.run_eval", _fake_run_eval_factory(holder, reward=0.4, content="ok")
    )

    # First response asks for a tool call; the second carries the answer.
    lookup_call = SimpleNamespace(
        type="function_call",
        id="fc-1",
        name="lookup",
        arguments='{"query":"a"}',
    )
    scripted = [
        SimpleNamespace(id="resp-1", output_text="", output=[lookup_call]),
        SimpleNamespace(id="resp-2", output_text="First response", output=[]),
    ]
    client = _make_client(responses=scripted)

    session = run_scenario_chat_interactive(
        client=client,
        model="gpt-4o",
        task=SimpleNamespace(scenario="demo"),
        api="responses",
    )
    async with session as chat:
        turn = await chat.send("Analyze this")
        result = await chat.finish()

    assert turn.answer == "First response"
    assert result.answer == "First response"
    assert result.reward == 0.4
@pytest.mark.asyncio
async def test_send_stream_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
    """send_stream yields tool_call and tool_result events during tool loop."""
    holder: dict[str, Any] = {}
    monkeypatch.setattr(
        "hud.scenario_chat.run_eval", _fake_run_eval_factory(holder, reward=0.8, content="done")
    )

    # First stream requests a tool; the second delivers the final text.
    scripted = [
        _build_stream_response(
            tool_calls=[{"id": "call_1", "name": "lookup", "arguments": {"query": "x"}}]
        ),
        _build_stream_response(content="Found it"),
    ]
    client = _make_client(chat=scripted)

    session = run_scenario_chat_interactive(
        client=client,
        model="gpt-4o",
        task=SimpleNamespace(scenario="demo"),
        api="chat_completions",
    )
    async with session as chat:
        events: list[ChatEvent] = [event async for event in chat.send_stream("Search for x")]
        await chat.finish()

    seen_types = [event.type for event in events]
    for expected_type in ("tool_call", "tool_result", "turn_complete"):
        assert expected_type in seen_types

    tc_event = next(event for event in events if event.type == "tool_call")
    assert tc_event.tool_name == "lookup"
    assert tc_event.tool_call_id == "call_1"

    tr_event = next(event for event in events if event.type == "tool_result")
    assert tr_event.tool_name == "lookup"
    assert tr_event.content == "ok"

    assert events[-1].content == "Found it"
@pytest.mark.asyncio
async def test_trace_headers_forwarded_and_merged_stream(monkeypatch: pytest.MonkeyPatch) -> None:
    """Streaming requests carry the caller's headers plus the trace id."""
    holder: dict[str, Any] = {}
    monkeypatch.setattr(
        "hud.scenario_chat.run_eval", _fake_run_eval_factory(holder, reward=1.0, content="ok")
    )
    client = _make_client(chat=[_build_stream_response(content="hello")])

    session = run_scenario_chat_interactive(
        client=client,
        model="gpt-4o",
        task=SimpleNamespace(scenario="demo"),
        api="chat_completions",
        completion_kwargs={"extra_headers": {"x-custom": "1"}},
    )
    async with session as chat:
        # Drain the stream; only the outgoing request is inspected below.
        async for _event in chat.send_stream("Hi"):
            continue
        await chat.finish()

    first_request = client.chat.completions.calls[0]
    sent_headers = first_request["extra_headers"]
    assert sent_headers["Trace-Id"] == "trace-123"
    assert sent_headers["x-custom"] == "1"
    assert first_request["stream"] is True
@pytest.mark.asyncio
async def test_no_user_span_when_trace_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``trace=False`` no user-message span is queued for export."""
    holder: dict[str, Any] = {}
    monkeypatch.setattr(
        "hud.scenario_chat.run_eval", _fake_run_eval_factory(holder, reward=1.0, content="ok")
    )

    queued_spans: list[Any] = []

    def fake_queue_span(_span: Any) -> None:
        queued_spans.append(_span)

    monkeypatch.setattr("hud.telemetry.exporter.queue_span", fake_queue_span)
    client = _make_client(chat=[_chat_response(content="ok")])

    session = run_scenario_chat_interactive(
        client=client,
        model="gpt-4o",
        task=SimpleNamespace(scenario="demo"),
        api="chat_completions",
        trace=False,
    )
    async with session as chat:
        await chat.send("Hi")
        await chat.finish()

    assert len(queued_spans) == 0
@pytest.mark.asyncio
async def test_from_state_finish(monkeypatch: pytest.MonkeyPatch) -> None:
    """A restored session can call finish() to submit and get results."""
    holder: dict[str, Any] = {}
    monkeypatch.setattr(
        "hud.scenario_chat.run_eval", _fake_run_eval_factory(holder, reward=0.7, content="ok")
    )
    client = _make_client(chat=[_chat_response(content="answer")])

    # Capture state from a live session after a single turn.
    live_session = run_scenario_chat_interactive(
        client=client,
        model="gpt-4o",
        task=SimpleNamespace(scenario="demo"),
        api="chat_completions",
    )
    async with live_session as chat:
        await chat.send("Go")
        state = chat.to_state()

    # Restore against a fresh context and finish there.
    ctx2 = FakeCtx()
    restored = ScenarioChatSession.from_state(state, client=_make_client(), ctx=ctx2)
    result = await restored.finish("final")

    assert result.answer == "final"
    assert ctx2.submitted == "final"