diff --git a/docs/changelog.mdx b/docs/changelog.mdx
index e99ec3366..7ac4fb8a8 100644
--- a/docs/changelog.mdx
+++ b/docs/changelog.mdx
@@ -25,7 +25,7 @@ description: "Product updates and release notes for HUD SDK and Platform."
- **`hud sync env`** — sync local environment configs with collision detection (replaces `hud link`).
- **`hud eval` accepts Python files** — run evaluations directly from `.py` files and directories containing `Task` objects.
- **Chat class** — manage multi-turn agent conversations from a single SDK abstraction.
-- **GPT-5 support** — `ResponseAgent` defaults to `gpt-5`, with ToolSearch tool support.
+- **GPT-5 support** — auto-response classification defaults to `gpt-5`, with ToolSearch tool support.
- **Citations** — citation support for Claude, Gemini, and OpenAI responses in chat and agent traces.
### Platform
diff --git a/docs/reference/agents.mdx b/docs/reference/agents.mdx
index f21c276f9..8ca828b2d 100644
--- a/docs/reference/agents.mdx
+++ b/docs/reference/agents.mdx
@@ -42,7 +42,7 @@ Abstract base class for all MCP-enabled agents. Handles the agent loop, MCP clie
|-----------|------|-------------|---------|
| `mcp_client` | `AgentMCPClient` | MCP client for server connections | `None` |
| `auto_trace` | `bool` | Enable automatic tracing spans | `True` |
-| `auto_respond` | `bool` | Use ResponseAgent to decide when to stop/continue | `False` |
+| `auto_respond` | `bool` | Use response automation to decide when to stop/continue | `False` |
| `verbose` | `bool` | Verbose console logs for development | `False` |
**Base Config** (shared by all agents):
@@ -63,9 +63,6 @@ async def run(ctx: EvalContext, max_steps: int = 10) -> Trace:
async def call_tools(tool_call: MCPToolCall | list[MCPToolCall]) -> list[MCPToolResult]:
"""Execute tool calls through MCP client."""
-
-def get_available_tools() -> list[types.Tool]:
- """Get filtered list of available tools."""
```
## Pre-built Agents
@@ -251,7 +248,7 @@ result = await agent.run(task, max_steps=20)
### Auto-Respond Mode
-When `auto_respond=True`, the agent uses a ResponseAgent to decide whether to continue or stop after each model response:
+When `auto_respond=True`, the agent uses response automation to decide whether to continue or stop after each model response:
```python
agent = ClaudeAgent.create(
diff --git a/docs/reference/cli/eval.mdx b/docs/reference/cli/eval.mdx
index 13903c044..d79f2596b 100644
--- a/docs/reference/cli/eval.mdx
+++ b/docs/reference/cli/eval.mdx
@@ -79,7 +79,7 @@ hud eval [SOURCE] [AGENT] [OPTIONS]
- Use ResponseAgent to decide when to stop/continue. Default: True for `--full`.
+ Use response automation to decide when to stop/continue. Default: True for `--full`.
### Taskset Association
diff --git a/docs/reference/types.mdx b/docs/reference/types.mdx
index f3d8e091d..bbd5bfad8 100644
--- a/docs/reference/types.mdx
+++ b/docs/reference/types.mdx
@@ -111,12 +111,12 @@ print(result.reward, result.done)
| `trace` | `list[TraceStep]` | Execution trace steps |
| `messages` | `list[Any]` | Final conversation state |
-## InferenceResult
+## AgentResponse
Returned by agent `get_response()` methods. Represents the result of a single LLM inference call.
```python
-from hud.types import InferenceResult
+from hud.types import AgentResponse
```
| Field | Type | Description |
@@ -129,8 +129,6 @@ from hud.types import InferenceResult
| `info` | `dict[str, Any]` | Provider-specific metadata |
| `isError` | `bool` | Error flag |
-> **Note:** `AgentResponse` is available as a backwards-compatible alias for `InferenceResult`.
-
## AgentType
Enum of supported agent types.
diff --git a/hud/agents/__init__.py b/hud/agents/__init__.py
index 27ca0b327..b17f59bb5 100644
--- a/hud/agents/__init__.py
+++ b/hud/agents/__init__.py
@@ -1,95 +1,15 @@
from __future__ import annotations
-import sys
-from types import ModuleType
-from typing import Any
-
-from .base import CategorizedTools, MCPAgent
+from .base import MCPAgent
from .claude import ClaudeAgent
+from .gateway import create_agent
from .openai import OpenAIAgent
from .openai_compatible import OpenAIChatAgent
__all__ = [
- "CategorizedTools",
"ClaudeAgent",
"MCPAgent",
"OpenAIAgent",
"OpenAIChatAgent",
"create_agent",
]
-
-
-def _install_openai_chat_compat_module() -> None:
- module_name = f"{__name__}.openai_chat"
- if module_name in sys.modules:
- return
-
- module: Any = ModuleType(module_name, "Compatibility module for OpenAIChatAgent.")
- module.OpenAIChatAgent = OpenAIChatAgent
- module.__all__ = ["OpenAIChatAgent"]
- sys.modules[module_name] = module
-
-
-_install_openai_chat_compat_module()
-
-
-def create_agent(model: str, **kwargs: Any) -> MCPAgent:
- """Create an agent for a gateway model.
-
- This routes ALL requests through the HUD gateway. For direct API access
- (using your own API keys), use the agent classes directly.
-
- Args:
- model: Model name (e.g., "gpt-5.4", "claude-sonnet-4-6").
- **kwargs: Additional params passed to agent.create().
-
- Returns:
- Configured MCPAgent instance with gateway routing.
-
- Example:
- ```python
- # Gateway routing (recommended)
- agent = create_agent("gpt-5.4")
- agent = create_agent("claude-sonnet-4-6", temperature=0.7)
-
- # Direct API access (use agent classes)
- from hud.agents.claude import ClaudeAgent
-
- agent = ClaudeAgent.create(model="claude-sonnet-4-6")
- ```
- """
- from hud.agents.gateway import build_gateway_client
- from hud.agents.resolver import resolve_cls
-
- # Resolve class and gateway info
- agent_cls, gateway_info = resolve_cls(model)
-
- # Get model name from gateway info or use input
- model_id = model
- if gateway_info:
- model_id = gateway_info.get("model_name") or model
-
- # Determine provider: from gateway info, or infer from agent class
- if gateway_info:
- provider = gateway_info["provider"]["name"]
- else:
- provider = "openai"
- if agent_cls.__name__ == "ClaudeAgent":
- provider = "anthropic"
- elif agent_cls.__name__ == "GeminiAgent":
- provider = "gemini"
-
- client = build_gateway_client(provider)
-
- # Set up kwargs
- kwargs.setdefault("model", model_id)
-
- # Use correct client key based on agent type
- if agent_cls == OpenAIChatAgent:
- kwargs.setdefault("openai_client", client)
- else:
- # Claude and other agents use model_client and validate_api_key
- kwargs.setdefault("model_client", client)
- kwargs.setdefault("validate_api_key", False)
-
- return agent_cls.create(**kwargs)
diff --git a/hud/agents/base.py b/hud/agents/base.py
index 9e2581c1f..75fb7345f 100644
--- a/hud/agents/base.py
+++ b/hud/agents/base.py
@@ -3,543 +3,188 @@
from __future__ import annotations
import asyncio
-import json
import logging
-import re
from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
+from dataclasses import dataclass
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
-import mcp.types as types
-
-from hud.tools.types import Citation
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult, Trace
-from hud.utils.hud_console import HUDConsole
-
-from .types import BaseCreateParams
+from hud.agents.misc import auto_respond
+from hud.types import AgentResponse, Trace
if TYPE_CHECKING:
- from hud.environment import Environment
- from hud.eval.context import EvalContext
+ import mcp.types as types
+ from hud.agents.tools import AgentTools
+ from hud.agents.tools.base import CallTool, ToolClient
+ from hud.agents.types import AgentConfig
+ProviderMessageT = TypeVar("ProviderMessageT")
logger = logging.getLogger(__name__)
-@dataclass
-class CategorizedTools:
- """Result of filtering tools for model-facing schemas."""
-
- generic: list[types.Tool] = field(default_factory=list)
- """MCP tools exposed through generic function calling."""
+@dataclass(frozen=True)
+class AgentContext:
+ """Prompt messages plus optional MCP tool access for one agent run."""
- skipped: list[tuple[types.Tool, str]] = field(default_factory=list)
- """Tools intentionally hidden from generic function calling."""
+ messages: list[types.PromptMessage]
+ tool_client: ToolClient | None = None
-class MCPAgent(ABC):
+class MCPAgent(ABC, Generic[ProviderMessageT]):
"""
- Base class for MCP-enabled agents.
-
- Agents interact with MCP servers through an EvalContext:
- - run(ctx): Main entry point - takes EvalContext from hud.eval()
- - ctx.call_tool(): Used internally for all tool execution
- - ctx.submit(): Called automatically with agent's final response
-
- Subclasses implement provider-specific formatting and response fetching
- by overriding: `get_system_messages`, `get_response`, `format_blocks`,
- and `format_tool_results`.
- """
-
- metadata: ClassVar[dict[str, Any] | None] = None
- required_tools: ClassVar[list[str]] = [] # Tools that must be available
- config_cls: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig
-
- @classmethod
- @abstractmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for this agent.
-
- Subclasses must implement this to return their corresponding AgentType enum value.
- This is used for provider-specific configuration and routing.
-
- Returns:
- AgentType enum value for this agent
- """
- raise NotImplementedError
-
- def categorize_tools(self, tools: list[types.Tool] | None = None) -> CategorizedTools:
- """Return the MCP tools that should be exposed as generic function tools."""
- if tools is None:
- tools = self.get_available_tools()
-
- return CategorizedTools(generic=list(tools))
+ Base class for agents that interact with HUD MCP-backed environments.
- def __init__(self, params: BaseCreateParams | None = None, **kwargs: Any) -> None:
- if params is None:
- import warnings
+ Agent instances are intended to be run-scoped: create a fresh agent for each
+ independent evaluation or task run. Provider implementations may keep
+ conversation IDs, continuation cursors, and prepared tool state on the
+ instance during a run.
- warnings.warn(
- f"Passing kwargs to {self.__class__.__name__}() is deprecated. "
- f"Use {self.__class__.__name__}.create(...) instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- CreateParams = type(
- f"{self.config_cls.__name__}CreateParams",
- (BaseCreateParams, self.config_cls),
- {"__module__": self.config_cls.__module__},
- )
- params = CreateParams(**kwargs)
-
- config_kwargs = {
- k: getattr(params, k) for k in self.config_cls.model_fields if hasattr(params, k)
- }
- self.config = self.config_cls(**config_kwargs)
+ Agents interact with environments through per-run tools and tool handlers supplied
+ by the caller.
- # Store execution context (EvalContext/Environment); agents use ctx.call_tool().
- self.ctx: EvalContext | Environment | None = params.ctx
-
- self.model_name: str = getattr(params, "model_name", "MCPAgent")
- self.model: str = getattr(params, "model", None) or "unknown"
- self.auto_respond = params.auto_respond
+ Subclasses implement provider-specific message formatting, response fetching,
+ and tool result rendering.
+ """
- self.console = HUDConsole(logger=logger)
+ def __init__(self, config: AgentConfig) -> None:
+ self.config = config
- if params.verbose:
- self.console.set_verbose(True)
+ self.model_name: str = self.config.model_name
+ self.model: str = self.config.model
self.system_prompt = self.config.system_prompt
- self._available_tools: list[types.Tool] | None = None
- self._categorized_tools: CategorizedTools = CategorizedTools()
- self._initialized: bool = False
-
- @classmethod
- def create(cls, **kwargs: Any) -> MCPAgent:
- """
- Factory method to create an agent with typed parameters.
- """
- CreateParams = type(
- f"{cls.config_cls.__name__}CreateParams",
- (BaseCreateParams, cls.config_cls),
- {"__module__": cls.config_cls.__module__},
- )
- return cls(params=CreateParams(**kwargs))
-
- async def _initialize_from_ctx(self, ctx: EvalContext) -> None:
- """Initialize agent from EvalContext - discovers tools and sets up state.
-
- The agent uses ctx.call_tool() directly
- for tool execution (no EnvironmentClient wrapper needed).
- """
- from hud.eval.context import EvalContext
-
- if not isinstance(ctx, EvalContext):
- raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}")
-
- # Refresh tools from connections, then get filtered list for agent
- await ctx.list_tools()
- self._available_tools = ctx.as_tools()
-
- # Validate required tools are present
- available_tool_names = {t.name for t in self._available_tools}
- missing_tools = [tool for tool in self.required_tools if tool not in available_tool_names]
- if missing_tools:
- raise ValueError(
- f"Required tools are missing: {missing_tools}. "
- f"Available tools: {sorted(available_tool_names)}"
- )
-
- self._categorized_tools = self.categorize_tools()
-
- # Show tool discovery table (visible at INFO level)
- self.console.format_tool_discovery(
- tools=self._available_tools,
- skipped=self._categorized_tools.skipped,
- )
+ self.enable_citations: bool = False
- for tool, reason in self._categorized_tools.skipped:
- logger.debug("Skipping tool %s: %s", tool.name, reason)
+ self.auto_respond: bool = config.auto_respond
- # Call hook for subclass-specific initialization (e.g., tool format conversion)
- self._on_tools_ready()
-
- self._initialized = True
-
- def _on_tools_ready(self) -> None:
- """Hook called after tools are discovered and validated.
-
- Subclasses can override this to perform provider-specific setup,
- such as converting MCP tools to the provider's format.
+ @classmethod
+ def create(cls, **kwargs: object) -> MCPAgent[ProviderMessageT]:
+ raise NotImplementedError(f"{cls.__name__}.create() must be implemented by subclasses")
- Called by _initialize_from_ctx() after _available_tools is populated.
- """
- return # Default no-op - subclasses override for provider-specific setup
+ @cached_property
+ @abstractmethod
+ def tools(self) -> AgentTools[Any, Any]:
+ """Provider-specific tool container used by the shared run loop."""
+ raise NotImplementedError
async def run(
self,
- ctx: EvalContext,
+ ctx: AgentContext,
*,
max_steps: int = 10,
) -> Trace:
"""
- Run the agent on the given evaluation context.
-
- The agent uses ctx.prompt as the task and ctx.call_tool() for tool execution.
- Automatically calls ctx.submit() with the final answer.
+ Run the agent loop with prepared messages and optional tools.
Args:
- ctx: EvalContext from hud.eval() - contains prompt and tools
+ ctx: Prompt messages and optional environment client
max_steps: Maximum number of agent steps (-1 for infinite)
- Returns:
- Trace with done, content, isError fields
-
- Example:
- ```python
- async with hud.eval(task) as ctx:
- agent = ClaudeAgent.create()
- await agent.run(ctx)
- # ctx.reward is set by the scenario's evaluate phase
- ```
- """
- from hud.eval.context import EvalContext
-
- if not isinstance(ctx, EvalContext):
- raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}")
-
- if not ctx.prompt:
- if ctx.has_scenario:
- # Scenario was specified but prompt is still empty
- # (e.g., scenario returned empty string, or edge case not caught in scenarios.py)
- scenario = ctx._task.scenario if ctx._task else "unknown"
- raise ValueError(
- f"ctx.prompt is not set.\n\n"
- f"Scenario '{scenario}' was specified but returned an empty prompt.\n"
- f"Check that the scenario's setup function returns a non-empty string."
- )
- else:
- # No scenario specified at all
- raise ValueError(
- "ctx.prompt is not set.\n\n"
- "No scenario was specified in your task file.\n"
- "Add a 'scenario' field to your task so scenario setup can produce a prompt."
- )
-
- # Store context for tool calls
- self.ctx = ctx
-
- # Initialize tools from context
- if not self._initialized:
- await self._initialize_from_ctx(ctx)
-
- try:
- # Build initial context
- conversation: list[dict[str, str]] | None = getattr(ctx, "conversation", None)
-
- if conversation:
- # Multi-turn: build alternating role messages
- initial_messages = await self._build_conversation_messages(conversation)
- else:
- # Single-turn: single user message from prompt
- initial_messages = await self.format_message(ctx.prompt)
-
- result = await self._run_context(initial_messages, max_steps=max_steps)
-
- # Propagate error state to context for platform visibility
- if result.isError and hasattr(ctx, "error"):
- error_msg = result.info.get("error") if result.info else result.content
- ctx.error = Exception(str(error_msg)) if error_msg else Exception("Agent error")
-
- # Submit final answer to context (only if scenario is running)
- if result.content and ctx.has_scenario:
- if result.citations:
- await ctx.submit(
- {
- "content": result.content,
- "citations": result.citations,
- }
- )
- else:
- await ctx.submit(result.content)
-
- return result
-
- except Exception as e:
- logger.exception("Error while running agent:")
- # Propagate error to context for platform visibility
- if hasattr(ctx, "error"):
- ctx.error = e
- return Trace(
- reward=0.0,
- done=True,
- content=f"Agent failed with error: {e}",
- isError=True,
- info={"error": str(e)},
- )
- finally:
- # Cleanup auto-created resources
- await self._cleanup()
-
- def _map_role(self, role: str) -> str:
- """Map a canonical role name to the provider-specific role.
-
- Override in subclasses where the provider uses different role names.
- Default passes through (works for OpenAI and Claude which use "assistant").
- """
- return role
-
- async def _build_conversation_messages(self, conversation: list[dict[str, str]]) -> list[Any]:
- """Build provider-formatted messages from a conversation history."""
- result: list[Any] = []
- for msg in conversation:
- role = self._map_role(msg.get("role", "user"))
- content = msg.get("content", "")
- formatted = await self.format_message(content)
- for fm in formatted:
- if isinstance(fm, dict):
- fm["role"] = role
- elif hasattr(fm, "role"):
- fm.role = role # type: ignore[attr-defined]
- result.extend(formatted)
- return result
-
- async def _run_context(self, initial_messages: list[Any], *, max_steps: int = 10) -> Trace:
- """
- Run the agent with pre-built messages. This is the core agent loop.
-
- Args:
- initial_messages: Provider-formatted messages (from format_message or conversation)
- max_steps: Maximum number of steps (-1 for infinite)
-
Returns:
Trace with reward, done, content fields and trace steps
"""
- final_response: InferenceResult | None = None
- error = None
-
- messages: list[Any] = []
+ tool_handler: CallTool | None = None
+ if ctx.tool_client is not None:
+ self.tools.prepare(
+ model=self.model,
+ tools=ctx.tool_client.tools,
+ hosted_tools=self.config.hosted_tools,
+ tool_metadata=ctx.tool_client.tool_metadata,
+ )
+ tool_handler = ctx.tool_client.tool_handler
+ messages: list[ProviderMessageT] = []
try:
- messages = await self.get_system_messages()
- messages.extend(initial_messages)
- self.console.debug(f"Messages: {messages}")
+ messages = await self.format_messages(ctx.messages)
+ logger.debug("Messages: %s", messages)
step_count = 0
while max_steps == -1 or step_count < max_steps:
step_count += 1
if max_steps == -1:
- self.console.debug(f"Step {step_count} (unlimited)")
+ logger.debug("Step %s (unlimited)", step_count)
else:
- self.console.debug(f"Step {step_count}/{max_steps}")
+ logger.debug("Step %s/%s", step_count, max_steps)
try:
# 1. Get model response
response = await self.get_response(messages)
- self.console.debug(f"Agent:\n{response}")
+ logger.debug("Agent:\n%s", response)
- # Check if we should stop
if response.done or not response.tool_calls:
- # Use auto_respond to decide whether to stop
- decision: Literal["STOP", "CONTINUE"] = "STOP"
- if self.auto_respond and response.content:
- try:
- from hud.agents.misc import ResponseAgent
-
- response_agent = ResponseAgent()
- decision = await response_agent.determine_response(response.content)
- except Exception as e:
- self.console.warning_log(f"Auto-respond failed: {e}")
- if decision == "STOP":
- if (
- getattr(self.ctx, "scenario_enable_citations", False)
- and not response.citations
- ):
- recovered = self._recover_citations_from_content(response)
- if recovered:
- self.console.info_log(
- "Recovered citations from JSON answer payload"
- )
- else:
- self.console.warning_log(
- "Citations required by scenario but missing in final response" # noqa: E501
- )
- self.console.debug("Stopping execution")
- final_response = response
- break
- else:
- self.console.debug("Continuing execution")
- messages.extend(await self.format_message(decision))
+ if follow_up := await auto_respond(
+ response.content,
+ enabled=self.auto_respond,
+ ):
+ logger.debug("Continuing execution")
+ messages.extend(await self.format_messages([follow_up]))
continue
- # 2. Execute tools
- tool_calls = response.tool_calls
- tool_results = await self.call_tools(tool_calls)
-
- # 3. Format tool results and add to messages
- tool_messages = await self.format_tool_results(tool_calls, tool_results)
- messages.extend(tool_messages)
-
- if logger.isEnabledFor(logging.INFO):
- self.console.format_step(
- step=step_count,
- max_steps=max_steps,
- tool_calls=tool_calls,
- tool_results=tool_results,
+ logger.debug("Stopping execution")
+ return Trace(
+ done=True,
+ messages=messages,
+ content=response.content,
+ isError=response.isError,
+ citations=response.citations,
)
+ # 2. Execute tools
+ tool_messages = await self.tools.execute(
+ tool_handler,
+ response.tool_calls,
+ )
+
+ messages.extend(cast("list[ProviderMessageT]", tool_messages))
+
except Exception as e:
- self.console.error_log(f"Step failed: {e}")
- error = str(e)
- break
+ logger.exception("Step failed")
+ return Trace(
+ done=True,
+ messages=messages,
+ content=str(e),
+ isError=True,
+ info={"error": str(e)},
+ )
except KeyboardInterrupt:
- self.console.warning_log("Agent execution interrupted by user")
- error = "Interrupted by user"
+ logger.warning("Agent execution interrupted by user")
+ return Trace(
+ done=True,
+ messages=messages,
+ content="Interrupted by user",
+ isError=True,
+ info={"error": "Interrupted by user"},
+ )
except asyncio.CancelledError:
- self.console.warning_log("Agent execution cancelled")
- error = "Cancelled"
+ logger.warning("Agent execution cancelled")
+ return Trace(
+ done=True,
+ messages=messages,
+ content="Cancelled",
+ isError=True,
+ info={"error": "Cancelled"},
+ )
except Exception as e:
- self.console.error_log(f"Unexpected error: {e}")
- error = str(e)
-
- # Build result
- if error is not None or (
- final_response and hasattr(final_response, "isError") and final_response.isError
- ):
- is_error = True
- else:
- is_error = False
-
- # Use ctx.reward if already set (e.g., from scenario evaluate), otherwise 0.0
- reward = 0.0
- if self.ctx is not None:
- ctx_reward = getattr(self.ctx, "reward", None)
- if ctx_reward is not None:
- reward = ctx_reward
+ logger.exception("Unexpected error")
+ return Trace(
+ done=True,
+ messages=messages,
+ content=str(e),
+ isError=True,
+ info={"error": str(e)},
+ )
return Trace(
- reward=reward,
done=True,
messages=messages,
- content=final_response.content if final_response else error,
- isError=is_error,
- citations=final_response.citations if final_response else [],
- info={"error": error} if error else {},
)
- def _recover_citations_from_content(self, response: InferenceResult) -> bool:
- """Try to extract citations from model content when native citations are missing.
-
- Handles two cases: raw JSON content and fenced ```json blocks.
- """
- raw = response.content or ""
- if not raw:
- return False
-
- # Try raw content first, then try extracting from fenced block.
- for text in dict.fromkeys([raw, self._extract_fenced_json(raw) or ""]):
- if not text:
- continue
- try:
- parsed = json.loads(text)
- except (json.JSONDecodeError, TypeError):
- continue
- if not isinstance(parsed, dict):
- continue
-
- raw_citations = parsed.get("citations")
- if not isinstance(raw_citations, list) or not raw_citations:
- continue
-
- normalized: list[Citation] = [
- c
- for cit in raw_citations
- if isinstance(cit, dict) and (c := self._normalize_citation(cit)) is not None
- ]
- if not normalized:
- continue
-
- content = parsed.get("content")
- if isinstance(content, str) and content.strip():
- response.content = content
- response.citations = [c.model_dump(exclude={"provider_data"}) for c in normalized]
- return True
-
- return False
-
- @staticmethod
- def _extract_fenced_json(value: str) -> str | None:
- """Extract JSON content from a fenced code block."""
- match = re.search(r"```(?:json)?\s*\n(.*?)```", value, re.DOTALL)
- return match.group(1).strip() if match else None
-
- @staticmethod
- def _normalize_citation(cit: dict[str, Any]) -> Citation | None:
- """Normalize a citation dict to canonical Citation shape.
-
- Maps common key aliases to canonical names and validates via Citation.
- Returns None only if construction fails (e.g. extra-forbid violation).
- """
- source = cit.get("source") or cit.get("document") or ""
- try:
- return Citation(
- type=cit.get("type", "document_citation"),
- text=cit.get("text") or cit.get("cited_text", ""),
- source=str(source),
- title=cit.get("title") or cit.get("document_title"),
- start_index=cit.get("start_index", cit.get("start_char_index")),
- end_index=cit.get("end_index", cit.get("end_char_index")),
- )
- except Exception:
- return None
-
- async def call_tools(
- self, tool_call: MCPToolCall | list[MCPToolCall] | None = None
- ) -> list[MCPToolResult]:
- """
- Call tools through the bound EvalContext.
-
- Args:
- tool_call: MCPToolCall or list of MCPToolCall
-
- Returns:
- List of MCPToolResult
- """
- if tool_call is None:
- return []
-
- if isinstance(tool_call, MCPToolCall):
- tool_call = [tool_call]
-
- if self.ctx is None:
- raise ValueError("Agent not bound to context - call run(ctx) first")
-
- results: list[MCPToolResult] = []
- for tc in tool_call:
- try:
- self.console.debug(f"Calling tool: {tc}")
- result = await self.ctx.call_tool(tc)
- results.append(MCPToolResult(content=result.content, isError=result.isError))
- except TimeoutError as e:
- self.console.error_log(f"Tool execution timed out: {e}")
- raise
- except Exception as e:
- self.console.error_log(f"Tool execution failed: {e}")
- results.append(_format_error_result(str(e)))
- return results
-
@abstractmethod
- async def get_system_messages(self) -> list[types.ContentBlock]:
- """
- Get the system prompt.
- """
- raise NotImplementedError
-
- @abstractmethod
- async def get_response(self, messages: list[Any]) -> InferenceResult:
+ async def get_response(self, messages: list[ProviderMessageT]) -> AgentResponse:
"""
Get response from the model including any tool calls.
@@ -553,114 +198,6 @@ async def get_response(self, messages: list[Any]) -> InferenceResult:
raise NotImplementedError
@abstractmethod
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]:
- """
- Format a list of content blocks into a list of messages.
- """
+ async def format_messages(self, messages: list[types.PromptMessage]) -> list[ProviderMessageT]:
+ """Format MCP prompt messages into provider messages."""
raise NotImplementedError
-
- @abstractmethod
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[Any]:
- """
- Format tool results into messages for the model.
-
- Args:
- tool_calls: List of MCPToolCall objects that were executed
- tool_results: List of MCPToolResult objects from tool execution
-
- Returns:
- List of formatted messages to append to conversation
- """
- raise NotImplementedError
-
- async def format_message(
- self,
- message: str
- | list[str]
- | types.ContentBlock
- | list[types.ContentBlock]
- | list[str | types.ContentBlock],
- ) -> list[Any]: # maybe type messages as list[types.ContentBlock]
- """
- Convencience function.
-
- Format a single content message into a list of messages for the model.
- """
- blocks: list[types.ContentBlock] = []
- if not isinstance(message, list):
- message = [message]
-
- for m in message:
- if isinstance(m, str):
- blocks.append(types.TextContent(text=m, type="text"))
- elif isinstance(m, types.ContentBlock):
- blocks.append(m)
- else:
- raise ValueError(f"Invalid message type: {type(m)}")
-
- return await self.format_blocks(blocks)
-
- def get_available_tools(self) -> list[types.Tool]:
- """Get list of available MCP tools for LLM use (excludes lifecycle tools)."""
- if self._available_tools is None:
- raise RuntimeError(
- "Tools have not been initialized. Call initialize() before accessing available tools." # noqa: E501
- )
- return self._available_tools
-
- def get_tool_schemas(self) -> list[dict]:
- """Get tool schemas in a format suitable for the model.
-
- Uses categorized tools so that skipped tools are excluded from schemas
- automatically. Falls back to get_available_tools() if called before
- categorization.
- """
- if self._initialized:
- tools = list(self._categorized_tools.generic)
- else:
- tools = self.get_available_tools()
-
- schemas = []
- for tool in tools:
- schema = {
- "name": tool.name,
- "description": tool.description,
- }
- if tool.inputSchema:
- schema["parameters"] = tool.inputSchema
- schemas.append(schema)
- return schemas
-
- async def _filter_messages(
- self,
- message_list: list[types.ContentBlock],
- include_types: list[
- Literal["text", "image", "audio", "resource_link", "embedded_resource"]
- ],
- ) -> list[types.ContentBlock]:
- """
- Filter a list of messages and return only the messages of the given types.
-
- Args:
- message_list: The list of messages to filter
- include_types: List of types to include (None = all types)
-
- Returns:
- List of messages in provider-specific format
- """
- return [message for message in message_list if message.type in include_types]
-
- async def _cleanup(self) -> None:
- """Cleanup resources."""
- # Clear context reference
- self.ctx = None
-
-
-def _format_error_result(error_message: str) -> MCPToolResult:
- return MCPToolResult(content=text_to_blocks(error_message), isError=True)
-
-
-def text_to_blocks(text: str) -> list[types.ContentBlock]:
- return [types.TextContent(text=text, type="text")]
diff --git a/hud/agents/claude/__init__.py b/hud/agents/claude/__init__.py
index ce90d2178..5d1c41a60 100644
--- a/hud/agents/claude/__init__.py
+++ b/hud/agents/claude/__init__.py
@@ -6,11 +6,6 @@
AsyncAnthropic,
AsyncAnthropicBedrock,
ClaudeAgent,
- base64_to_content_block,
- document_to_content_block,
- text_document_block,
- text_to_content_block,
- tool_use_content_block,
)
from .tools import ClaudeToolSearchTool, ClaudeWebFetchTool, ClaudeWebSearchTool
@@ -21,9 +16,4 @@
"ClaudeToolSearchTool",
"ClaudeWebFetchTool",
"ClaudeWebSearchTool",
- "base64_to_content_block",
- "document_to_content_block",
- "text_document_block",
- "text_to_content_block",
- "tool_use_content_block",
]
diff --git a/hud/agents/claude/agent.py b/hud/agents/claude/agent.py
index 71698d9d1..1d5274de4 100644
--- a/hud/agents/claude/agent.py
+++ b/hud/agents/claude/agent.py
@@ -5,50 +5,44 @@
import copy
import json
import logging
-from inspect import cleandoc
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
+from functools import cached_property
+from typing import TYPE_CHECKING, Literal, cast
-import mcp.types as types
+import mcp.types as mcp_types
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, Omit
from anthropic.types import CacheControlEphemeralParam
from anthropic.types.beta import (
BetaBase64ImageSourceParam,
BetaBase64PDFSourceParam,
- BetaContentBlockParam,
BetaImageBlockParam,
+ BetaMessage,
BetaMessageParam,
- BetaPlainTextSourceParam,
BetaRequestDocumentBlockParam,
+ BetaTextBlock,
BetaTextBlockParam,
- BetaToolParam,
- BetaToolResultBlockParam,
+ BetaToolChoiceAutoParam,
BetaToolUnionParam,
)
+from hud.agents import gateway
from hud.agents.base import MCPAgent
-from hud.agents.tools import (
- EnvironmentCapability,
- call_agent_tools,
- capabilities_metadata_from_context,
- discover_environment_capabilities,
- select_hosted_tools,
-)
-from hud.agents.types import ClaudeConfig, ClaudeCreateParams
+from hud.agents.types import ClaudeConfig
from hud.settings import settings
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult
-from hud.utils.hud_console import HUDConsole
+from hud.tools.types import Citation
+from hud.types import AgentResponse, MCPToolCall
from hud.utils.types import with_signature
-from .tools import ClaudeHostedTool, ClaudeTool, ClaudeToolSearchTool, claude_tools
-from .tools.settings import claude_tool_settings
+from .tools import ClaudeAgentTools
if TYPE_CHECKING:
- from collections.abc import Sequence
+ import mcp.types as types
+ from anthropic.types.beta import BetaTextCitation
logger = logging.getLogger(__name__)
+ClaudeImageMediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"]
-class ClaudeAgent(MCPAgent):
+class ClaudeAgent(MCPAgent[BetaMessageParam]):
"""
Claude agent that uses MCP servers for tool execution.
@@ -56,33 +50,21 @@ class ClaudeAgent(MCPAgent):
tools through MCP servers instead of direct implementation.
"""
- metadata: ClassVar[dict[str, Any] | None] = {
- "display_width": claude_tool_settings.COMPUTER_WIDTH,
- "display_height": claude_tool_settings.COMPUTER_HEIGHT,
- }
- config_cls: ClassVar[type[BaseAgentConfig]] = ClaudeConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for Claude."""
- return AgentType.CLAUDE
-
- @with_signature(ClaudeCreateParams)
+ @with_signature(ClaudeConfig)
@classmethod
- def create(cls, **kwargs: Any) -> ClaudeAgent: # pyright: ignore[reportIncompatibleMethodOverride]
- return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value]
+ def create(cls, **kwargs: object) -> ClaudeAgent: # pyright: ignore[reportIncompatibleMethodOverride]
+ return cls(ClaudeConfig.model_validate(kwargs))
- def __init__(self, params: ClaudeCreateParams | None = None, **kwargs: Any) -> None:
- super().__init__(params, **kwargs)
+ def __init__(self, config: ClaudeConfig | None = None) -> None:
+ config = config or ClaudeConfig()
+ super().__init__(config)
self.config: ClaudeConfig
model_client = self.config.model_client
if model_client is None:
# Default to HUD gateway when HUD_API_KEY is available
if settings.api_key:
- from hud.agents.gateway import build_gateway_client
-
- model_client = build_gateway_client("anthropic")
+ model_client = gateway.build_gateway_client("anthropic")
elif settings.anthropic_api_key:
model_client = AsyncAnthropic(api_key=settings.anthropic_api_key)
else:
@@ -95,511 +77,262 @@ def __init__(self, params: ClaudeCreateParams | None = None, **kwargs: Any) -> N
" access"
)
- self.anthropic_client: AsyncAnthropic | AsyncAnthropicBedrock = model_client
- self.max_tokens = self.config.max_tokens
- self.use_computer_beta = self.config.use_computer_beta
- self.hud_console = HUDConsole(logger=logger)
-
- # these will be initialized in _convert_tools_for_claude
- self.has_computer_tool = False
- self.tool_mapping: dict[str, str] = {}
- self.claude_tools: list[BetaToolUnionParam] = []
- self._claude_native_tools: dict[str, ClaudeTool] = {}
- self._environment_capabilities: dict[str, EnvironmentCapability] = {}
- self._required_betas: set[str] = set()
- self._tool_search_threshold: int | None = None
-
- def _on_tools_ready(self) -> None:
- """Build Claude-specific tool mappings after tools are discovered."""
- self._convert_tools_for_claude()
-
- def _discover_environment_capabilities(
- self, tools: list[types.Tool]
- ) -> dict[str, EnvironmentCapability]:
- return discover_environment_capabilities(
- tools,
- env_metadata=capabilities_metadata_from_context(self.ctx),
- name_fallbacks=claude_tools.name_fallbacks,
+ self.anthropic_client: AsyncAnthropic | AsyncAnthropicBedrock = cast(
+ "AsyncAnthropic | AsyncAnthropicBedrock", model_client
)
+ self.max_tokens = self.config.max_tokens
- async def get_system_messages(self) -> list[types.ContentBlock]:
- """No system messages for Claude because applied in get_response"""
- return []
-
- def _result_from_response_blocks(self, response_blocks: list[Any]) -> InferenceResult:
- """Extract text/tool calls/citations from Anthropic response blocks."""
- result = InferenceResult(content="", tool_calls=[], done=True)
- text_content = ""
- thinking_content = ""
- citations: list[dict[str, Any]] = []
-
- for block in response_blocks:
- block_type = getattr(block, "type", None)
- if block_type == "tool_use":
- block_input = getattr(block, "input", {})
- mcp_name = self.tool_mapping.get(
- getattr(block, "name", ""),
- getattr(block, "name", ""),
- )
- arguments = block_input if isinstance(block_input, dict) else block_input.__dict__
- tool_call = MCPToolCall(
- id=getattr(block, "id", ""),
- name=mcp_name,
- arguments=arguments,
- )
- result.tool_calls.append(tool_call)
- result.done = False
- elif block_type == "text":
- text = getattr(block, "text", "") or ""
- text_content += text
- block_citations = getattr(block, "citations", None) or []
- for cit in block_citations:
- cit_dict = {
- "type": "document_citation",
- "text": getattr(cit, "cited_text", "") or "",
- "source": (
- str(idx)
- if (idx := getattr(cit, "document_index", None)) is not None
- else getattr(cit, "document_title", "") or ""
- ),
- "title": getattr(cit, "document_title", None),
- "start_index": getattr(cit, "start_char_index", None),
- "end_index": getattr(cit, "end_char_index", None),
- }
- normalized = self._normalize_citation(cit_dict)
- if normalized is not None:
- citations.append(normalized.model_dump(exclude={"provider_data"}))
- elif block_type == "thinking":
- thinking = getattr(block, "thinking", "") or ""
- if thinking:
- if thinking_content:
- thinking_content += "\n"
- thinking_content += thinking
-
- result.content = text_content
- result.citations = citations
- if thinking_content:
- result.reasoning = thinking_content
- return result
-
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[BetaMessageParam]:
- """Format messages for Claude."""
- # Convert MCP content types to Anthropic content types
- anthropic_blocks: list[BetaContentBlockParam] = []
-
- for block in blocks:
- if isinstance(block, types.TextContent):
- # Only include fields that Anthropic expects
- anthropic_blocks.append(
- BetaTextBlockParam(
- type="text",
- text=block.text,
- )
- )
- elif isinstance(block, types.ImageContent):
- # Convert MCP ImageContent to Anthropic format
- anthropic_blocks.append(
- BetaImageBlockParam(
+ @cached_property
+ def tools(self) -> ClaudeAgentTools:
+ return ClaudeAgentTools()
+
+ async def format_messages(self, messages: list[types.PromptMessage]) -> list[BetaMessageParam]:
+ """Format MCP prompt messages for Claude."""
+ formatted: list[BetaMessageParam] = []
+ for message in messages:
+ match message.content:
+ case mcp_types.TextContent():
+ content = BetaTextBlockParam(type="text", text=message.content.text)
+ case mcp_types.ImageContent():
+ content = BetaImageBlockParam(
type="image",
source=BetaBase64ImageSourceParam(
type="base64",
- media_type=cast(
- "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']",
- block.mimeType,
- ),
- data=block.data,
+ media_type=cast("ClaudeImageMediaType", message.content.mimeType),
+ data=message.content.data,
+ ),
+ )
+ case mcp_types.EmbeddedResource(
+ resource=mcp_types.BlobResourceContents(mimeType="application/pdf") as resource
+ ):
+ content = BetaRequestDocumentBlockParam(
+ type="document",
+ source=BetaBase64PDFSourceParam(
+ type="base64",
+ media_type="application/pdf",
+ data=resource.blob,
),
)
+ case _:
+ raise ValueError(f"Unknown content block type: {type(message.content)}")
+ formatted.append(
+ BetaMessageParam(
+ role=message.role,
+ content=[content],
)
- else:
- raise ValueError(f"Unknown content block type: {type(block)}")
-
- return [BetaMessageParam(role="user", content=anthropic_blocks)]
-
- @staticmethod
- def _extract_invalid_tool_json(exc: Exception) -> str | None:
- """Extract malformed tool JSON payload from Anthropic stream errors.
-
- Returns None when the exception is unrelated to tool JSON parsing.
- """
- message = str(exc)
- parse_error_prefix = "Unable to parse tool parameter JSON from model."
- if parse_error_prefix not in message:
- return None
-
- marker = "JSON: "
- marker_index = message.find(marker)
- if marker_index == -1:
- return ""
-
- return message[marker_index + len(marker) :].strip()
-
- @staticmethod
- def _build_invalid_tool_json_retry_message(invalid_json: str) -> BetaMessageParam:
- """Build a user message prompting the model to re-emit valid tool JSON."""
- wrapped = json.dumps({"INVALID_JSON": invalid_json}, ensure_ascii=True)
- retry_text = (
- "Your previous tool-call arguments were invalid JSON and could not be parsed.\n"
- "Retry the same intended tool call once with valid JSON arguments only.\n"
- "Ensure all strings are quoted and all arrays/objects are valid JSON.\n"
- f"Malformed payload (wrapped): {wrapped}"
- )
- return BetaMessageParam(
- role="user",
- content=[text_to_content_block(retry_text)],
- )
+ )
+ return formatted
- async def get_response(self, messages: list[BetaMessageParam]) -> InferenceResult:
+ async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse:
"""Get response from Claude including any tool calls."""
- messages_cached = self._add_prompt_caching(messages)
# Betas are collected during provider tool conversion.
# Only pass betas when non-empty; an empty list can produce an empty
# anthropic-beta header which the API rejects.
- betas: list[str] | Omit = list(self._required_betas) if self._required_betas else Omit()
+ betas: list[str] | Omit = (
+ list(self.tools.required_betas) if self.tools.required_betas else Omit()
+ )
+ tool_choice = BetaToolChoiceAutoParam(type="auto", disable_parallel_tool_use=True)
- effective_tools: list[BetaToolUnionParam] = list(self.claude_tools)
- if self._tool_search_threshold is not None:
- generic_count = sum(
- 1 for t in effective_tools if isinstance(t, dict) and "input_schema" in t
- )
- if generic_count > self._tool_search_threshold:
+ effective_tools: list[BetaToolUnionParam] = list(self.tools.params)
+ if self.tools.tool_search_threshold is not None:
+ generic_count = sum(1 for t in effective_tools if "input_schema" in t)
+ if generic_count > self.tools.tool_search_threshold:
logger.debug(
"tool_search: %d generic tools > threshold %d, applying defer_loading",
generic_count,
- self._tool_search_threshold,
+ self.tools.tool_search_threshold,
)
effective_tools = [
- {**t, "defer_loading": True}
- if isinstance(t, dict) and "input_schema" in t
- else t
+ {**t, "defer_loading": True} if "input_schema" in t else t
for t in effective_tools
]
- # Bedrock doesn't support .stream() - use create(stream=True) instead
- if isinstance(self.anthropic_client, AsyncAnthropicBedrock):
+ client = self.anthropic_client
+ response: BetaMessage | None = None
+ is_bedrock = isinstance(client, AsyncAnthropicBedrock)
+ invalid_json_failures = 0
+
+ for _ in range(1 if is_bedrock else 3):
+ messages_cached: list[BetaMessageParam] = copy.deepcopy(messages)
+ cache_control = CacheControlEphemeralParam(type="ephemeral")
+ if messages_cached and messages_cached[-1].get("role") == "user":
+ content = messages_cached[-1]["content"]
+ if isinstance(content, list):
+ for block in content:
+ if isinstance(block, dict) and block["type"] not in (
+ "redacted_thinking",
+ "thinking",
+ ):
+ cast("dict[str, object]", block)["cache_control"] = cache_control
+
try:
- response = await self.anthropic_client.beta.messages.create(
- model=self.config.model,
- system=self.system_prompt if self.system_prompt is not None else Omit(),
- max_tokens=self.max_tokens,
- messages=messages_cached,
- tools=effective_tools,
- tool_choice={"type": "auto", "disable_parallel_tool_use": True},
- betas=betas,
- )
- messages.append(BetaMessageParam(role="assistant", content=response.content))
- except ModuleNotFoundError:
- raise ValueError(
- "boto3 is required for AWS Bedrock. Use `pip install hud[bedrock]`"
- ) from None
- else:
- # Regular Anthropic client supports .stream()
- response = None
- invalid_json_failures = 0
- for _ in range(3):
- messages_cached = self._add_prompt_caching(messages)
- try:
- async with self.anthropic_client.beta.messages.stream(
+ if isinstance(client, AsyncAnthropicBedrock):
+ response = await client.beta.messages.create(
model=self.config.model,
system=self.system_prompt if self.system_prompt is not None else Omit(),
max_tokens=self.max_tokens,
messages=messages_cached,
tools=effective_tools,
- tool_choice={"type": "auto", "disable_parallel_tool_use": True},
+ tool_choice=tool_choice,
+ betas=betas,
+ )
+ else:
+ async with client.beta.messages.stream(
+ model=self.config.model,
+ system=self.system_prompt if self.system_prompt is not None else Omit(),
+ max_tokens=self.max_tokens,
+ messages=messages_cached,
+ tools=effective_tools,
+ tool_choice=tool_choice,
betas=betas,
) as stream:
- # allow backend to accumulate message content
async for _ in stream:
pass
- # get final message
response = await stream.get_final_message()
- messages.append(
- BetaMessageParam(
- role="assistant",
- content=response.content,
- )
- )
- break
- except ValueError as exc:
- invalid_json = self._extract_invalid_tool_json(exc)
- is_retryable = invalid_json is not None
- if not is_retryable:
- raise
-
- invalid_json_failures += 1
- if invalid_json_failures == 1:
- logger.warning(
- "Claude returned invalid streamed tool JSON; "
- "retrying same generation once"
- )
- continue
-
- if invalid_json_failures == 2:
- logger.warning(
- "Claude returned invalid streamed tool JSON twice; "
- "retrying once with INVALID_JSON guidance"
- )
- messages.append(self._build_invalid_tool_json_retry_message(invalid_json))
- continue
-
+ messages.append(BetaMessageParam(role="assistant", content=response.content))
+ break
+ except ModuleNotFoundError:
+ if is_bedrock:
+ raise ValueError(
+ "boto3 is required for AWS Bedrock. Use `pip install hud-python[bedrock]`"
+ ) from None
+ raise
+ except ValueError as exc:
+ message = str(exc)
+ if is_bedrock or "Unable to parse tool parameter JSON from model." not in message:
raise
- if response is None:
- raise ValueError("Claude response missing after stream retries")
+ marker = "JSON: "
+ marker_index = message.find(marker)
+ invalid_json = (
+ "" if marker_index == -1 else message[marker_index + len(marker) :].strip()
+ )
- # Process response
- result = self._result_from_response_blocks(list(response.content))
+ invalid_json_failures += 1
+ if invalid_json_failures == 1:
+ logger.warning(
+ "Claude returned invalid streamed tool JSON; retrying same generation once"
+ )
+ continue
+
+ if invalid_json_failures == 2:
+ wrapped = json.dumps({"INVALID_JSON": invalid_json}, ensure_ascii=True)
+ retry_text = (
+ "Your previous tool-call arguments were invalid JSON and could not be "
+ "parsed.\n"
+ "Retry the same intended tool call once with valid JSON arguments only.\n"
+ "Ensure all strings are quoted and all arrays/objects are valid JSON.\n"
+ f"Malformed payload (wrapped): {wrapped}"
+ )
+ logger.warning(
+ "Claude returned invalid streamed tool JSON twice; "
+ "retrying once with INVALID_JSON guidance"
+ )
+ messages.append(
+ BetaMessageParam(
+ role="user",
+ content=[BetaTextBlockParam(type="text", text=retry_text)],
+ )
+ )
+ continue
- return result
+ raise
- async def call_tools(
- self, tool_call: MCPToolCall | list[MCPToolCall] | None = None
- ) -> list[MCPToolResult]:
- """Route Claude provider tools to their backing environment tools."""
- return await call_agent_tools(self, self._claude_native_tools, tool_call)
-
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[BetaMessageParam]:
- """Format tool results into Claude messages.
-
- Handles EmbeddedResource (PDFs), images, and text content.
- """
- citations_enabled = bool(
- getattr(self.ctx, "scenario_enable_citations", False) if self.ctx else False
- )
+ if response is None:
+ raise ValueError("Claude response missing after stream retries")
- # Process each tool result
- user_content: list[BetaToolResultBlockParam | BetaRequestDocumentBlockParam] = []
-
- for tool_call, result in zip(tool_calls, tool_results, strict=True):
- tool_use_id = tool_call.id
- if not tool_use_id:
- self.hud_console.warning(f"No tool_use_id found for {tool_call.name}")
- continue
-
- # Blocks placed inside the tool_result (text, images)
- claude_blocks: list[
- BetaTextBlockParam | BetaImageBlockParam | BetaRequestDocumentBlockParam
- ] = []
- # Citable document blocks placed as siblings after the tool_result
- # so Claude's citation system indexes them properly.
- sibling_docs: list[BetaRequestDocumentBlockParam] = []
-
- if result.isError:
- error_msg = "Tool execution failed"
- for content in result.content:
- if isinstance(content, types.TextContent):
- error_msg = content.text
- break
- claude_blocks.append(text_to_content_block(f"Error: {error_msg}"))
- else:
- for content in result.content:
- if isinstance(content, types.TextContent):
- claude_blocks.append(text_to_content_block(content.text))
- if citations_enabled:
- sibling_docs.append(
- text_document_block(content.text, title=tool_call.name)
- )
- elif isinstance(content, types.ImageContent):
- claude_blocks.append(
- base64_to_content_block(content.data, content.mimeType)
+ result = AgentResponse(content="", tool_calls=[], done=True)
+ text_content = ""
+ thinking_content = ""
+ citations: list[dict[str, object]] = []
+
+ for block in response.content:
+ match block.type:
+ case "tool_use":
+ tool_use = block
+ mcp_name = self.tools.name_map.get(tool_use.name, tool_use.name)
+ result.tool_calls.append(
+ MCPToolCall(
+ id=tool_use.id,
+ name=mcp_name,
+ arguments=dict(tool_use.input),
+ _meta=mcp_types.RequestParams.Meta.model_validate(
+ {"enable_citations": self.enable_citations}
+ ),
)
- elif isinstance(content, types.EmbeddedResource):
- resource = content.resource
- if (
- isinstance(resource, types.BlobResourceContents)
- and resource.mimeType == "application/pdf"
- ):
- claude_blocks.append(
- document_to_content_block(
- base64_data=resource.blob,
- )
- )
- if citations_enabled:
- sibling_docs.append(
- document_to_content_block(
- base64_data=resource.blob,
- enable_citations=True,
- )
- )
-
- user_content.append(tool_use_content_block(tool_use_id, claude_blocks))
- user_content.extend(sibling_docs)
-
- return [
- BetaMessageParam(
- role="user",
- content=user_content,
- )
- ]
-
- async def create_user_message(self, text: str) -> BetaMessageParam:
- """Create a user message in Claude's format."""
- return BetaMessageParam(role="user", content=text)
-
- def _convert_tools_for_claude(self) -> None:
- """Convert MCP tools to Claude API tools."""
- self.has_computer_tool = False
- self.tool_mapping: dict[str, str] = {}
- self.claude_tools: list[BetaToolUnionParam] = []
- self._claude_native_tools = {}
- self._required_betas: set[str] = set()
- self._tool_search_threshold = None
-
- categorized = self._categorized_tools
-
- capabilities = self._discover_environment_capabilities(self.get_available_tools())
- self._environment_capabilities = capabilities
- provider_backing_tools: set[str] = set()
-
- for capability in capabilities.values():
- if capability.name not in claude_tools.capabilities:
- continue
- claude_tool = claude_tools.tool_for_capability(capability, self.model)
- if claude_tool is None:
- continue
- provider_backing_tools.add(capability.tool_name)
- provider_name = getattr(claude_tool, "provider_name", claude_tool.name)
- self._claude_native_tools[provider_name] = claude_tool
- self.tool_mapping[provider_name] = provider_name
- self.claude_tools.append(claude_tool.to_params())
- if claude_tool.required_beta:
- self._required_betas.add(claude_tool.required_beta)
- if claude_tool.capability == "computer":
- self.has_computer_tool = True
- logger.debug(
- "Activated Claude %s capability from env tool %s",
- capability.name,
- capability.tool_name,
- )
+ )
+ result.done = False
+ case "text":
+ text = cast("BetaTextBlock", block)
+ text_content += text.text
+ for citation in text.citations or []:
+ normalized = self._citation(citation)
+ citations.append(normalized.model_dump(exclude={"provider_data"}))
+ case "thinking":
+ thinking = block
+ if thinking.thinking:
+ if thinking_content:
+ thinking_content += "\n"
+ thinking_content += thinking.thinking
+ case _:
+ continue
- configured_hosted = select_hosted_tools(
- self.config.hosted_tools,
- tool_type=ClaudeHostedTool,
- model=self.model,
- )
- for hosted_tool in configured_hosted:
- self.claude_tools.append(hosted_tool.to_params()) # type: ignore[arg-type]
- required_beta = getattr(hosted_tool, "required_beta", None)
- if required_beta:
- self._required_betas.add(required_beta)
- if isinstance(hosted_tool, ClaudeToolSearchTool):
- self._tool_search_threshold = hosted_tool.threshold
-
- # Process generic tools
- for tool in categorized.generic:
- if tool.name in provider_backing_tools:
- continue
- if tool.description is None or tool.inputSchema is None:
- raise ValueError(
- cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema.
- Add these by:
- 1. Adding a docstring to your @mcp.tool decorated function for the description
- 2. Using pydantic Field() annotations on function parameters for the schema
- """)
- )
+ result.content = text_content
+ result.citations = citations
+ if thinking_content:
+ result.reasoning = thinking_content
- claude_tool = BetaToolParam(
- name=tool.name,
- description=tool.description,
- input_schema=tool.inputSchema,
- eager_input_streaming=True,
- )
- self.tool_mapping[tool.name] = tool.name
- self.claude_tools.append(claude_tool)
+ return result
- # Log actual tools being used
- tool_names = sorted(self.tool_mapping.keys())
- self.console.info(
- f"Agent initialized with {len(tool_names)} tools: {', '.join(tool_names)}"
+ @staticmethod
+ def _citation(citation: BetaTextCitation) -> Citation:
+ match citation.type:
+ case "char_location":
+ char_location = citation
+ citation_type = "document_citation"
+ text = char_location.cited_text
+ source = str(char_location.document_index)
+ title = char_location.document_title
+ start_index = char_location.start_char_index
+ end_index = char_location.end_char_index
+ case "page_location":
+ page_location = citation
+ citation_type = "document_citation"
+ text = page_location.cited_text
+ source = str(page_location.document_index)
+ title = page_location.document_title
+ start_index = None
+ end_index = None
+ case "content_block_location":
+ block_location = citation
+ citation_type = "document_citation"
+ text = block_location.cited_text
+ source = str(block_location.document_index)
+ title = block_location.document_title
+ start_index = block_location.start_block_index
+ end_index = block_location.end_block_index
+ case "search_result_location":
+ search_result = citation
+ citation_type = "search_result_location"
+ text = search_result.cited_text
+ source = search_result.source
+ title = search_result.title
+ start_index = search_result.start_block_index
+ end_index = search_result.end_block_index
+ case "web_search_result_location":
+ web_result = citation
+ citation_type = "web_search_result_location"
+ text = web_result.cited_text
+ source = web_result.url
+ title = web_result.title
+ start_index = None
+ end_index = None
+
+ return Citation(
+ type=citation_type,
+ text=text,
+ source=source,
+ title=title,
+ start_index=start_index,
+ end_index=end_index,
)
-
- def _add_prompt_caching(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]:
- """Add prompt caching to messages."""
- messages_cached = copy.deepcopy(messages)
- cache_control = CacheControlEphemeralParam(type="ephemeral")
-
- # Mark last user message with cache control
- if (
- messages_cached
- and isinstance(messages_cached[-1], dict)
- and messages_cached[-1].get("role") == "user"
- ):
- last_content = messages_cached[-1]["content"]
- # Content is formatted to be list of ContentBlock in format_blocks and format_message
- if isinstance(last_content, list):
- for block in last_content:
- # Only add cache control to dict-like block types that support it
- if isinstance(block, dict):
- match block["type"]:
- case "redacted_thinking" | "thinking":
- pass
- case _:
- block["cache_control"] = cache_control
-
- return messages_cached
-
-
-def base64_to_content_block(
- base64: str,
- media_type: str = "image/png",
-) -> BetaImageBlockParam:
- """Convert base64 image to Claude content block."""
- return BetaImageBlockParam(
- type="image",
- source=BetaBase64ImageSourceParam(
- type="base64",
- media_type=cast(
- "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']",
- media_type,
- ),
- data=base64,
- ),
- )
-
-
-def text_to_content_block(text: str) -> BetaTextBlockParam:
- """Convert text to Claude content block."""
- return {"type": "text", "text": text}
-
-
-def text_document_block(text: str, *, title: str | None = None) -> BetaRequestDocumentBlockParam:
- """Wrap plain text as a citable document block."""
- block = BetaRequestDocumentBlockParam(
- type="document",
- source=BetaPlainTextSourceParam(
- type="text",
- media_type="text/plain",
- data=text,
- ),
- citations={"enabled": True},
- )
- if title:
- block["title"] = title
- return block
-
-
-def document_to_content_block(
- base64_data: str, *, enable_citations: bool = False
-) -> BetaRequestDocumentBlockParam:
- """Convert base64 PDF to Claude document content block."""
- block = BetaRequestDocumentBlockParam(
- type="document",
- source=BetaBase64PDFSourceParam(
- type="base64",
- media_type="application/pdf",
- data=base64_data,
- ),
- )
- if enable_citations:
- block["citations"] = {"enabled": True}
- return block
-
-
-def tool_use_content_block(
- tool_use_id: str,
- content: Sequence[BetaTextBlockParam | BetaImageBlockParam | BetaRequestDocumentBlockParam],
-) -> BetaToolResultBlockParam:
- """Create tool result content block."""
- return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content} # pyright: ignore[reportReturnType]
diff --git a/hud/agents/claude/tools/__init__.py b/hud/agents/claude/tools/__init__.py
index ff341fa43..16796f567 100644
--- a/hud/agents/claude/tools/__init__.py
+++ b/hud/agents/claude/tools/__init__.py
@@ -2,58 +2,63 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, ClassVar
-from hud.agents.tools import AgentToolRegistry
+from anthropic.types.beta import BetaToolUnionParam
-from .base import ClaudeTool
+from hud.agents.tools import AgentTools
+
+from .base import ClaudeFunctionTool, ClaudeTool
from .coding import ClaudeBashTool, ClaudeTextEditorTool
from .computer import ClaudeComputerTool
from .hosted import ClaudeHostedTool, ClaudeToolSearchTool, ClaudeWebFetchTool, ClaudeWebSearchTool
from .memory import ClaudeMemoryTool
+if TYPE_CHECKING:
+ from collections.abc import Mapping
+
+ from hud.agents.tools import AgentTool
+
-@dataclass(frozen=True)
-class ClaudeToolRegistry(AgentToolRegistry[ClaudeTool]):
- """Registry for Claude harness tools."""
+class ClaudeAgentTools(AgentTools[ClaudeTool, BetaToolUnionParam]):
+ """Prepared Claude tool state for a run."""
- tool_classes: tuple[type[ClaudeTool], ...] = (
+ native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = (
ClaudeComputerTool,
ClaudeBashTool,
ClaudeTextEditorTool,
ClaudeMemoryTool,
)
- name_fallbacks: dict[str, tuple[str, ...]] = field(
- default_factory=lambda: {
- "computer": ("computer", "anthropic_computer", "computer_anthropic"),
- "shell": ("bash",),
- "editor": ("edit", "str_replace_based_edit_tool", "text_editor"),
- "memory": ("memory",),
- }
- )
+ function_tool_class = ClaudeFunctionTool
+ name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {
+ "computer": ("computer", "anthropic_computer", "computer_anthropic"),
+ "shell": ("bash",),
+ "editor": ("edit", "str_replace_based_edit_tool", "text_editor"),
+ "memory": ("memory",),
+ }
- @property
- def capabilities(self) -> frozenset[str]:
- return frozenset(cls.capability for cls in self.tool_classes)
-
- @property
- def provider_tool_names(self) -> frozenset[str]:
- return frozenset(cls.name for cls in self.tool_classes)
+ def __init__(self) -> None:
+ super().__init__()
+ self.required_betas: set[str] = set()
+ def prepare(self, **kwargs: Any) -> None:
+ super().prepare(**kwargs)
+ self.required_betas = {
+ required_beta for tool in self.values() if (required_beta := tool.required_beta)
+ }
-claude_tools = ClaudeToolRegistry()
+ @property
+ def tool_search_threshold(self) -> int | None:
+ for hosted_tool in self.hosted_tools:
+ if isinstance(hosted_tool, ClaudeToolSearchTool):
+ return hosted_tool.threshold
+ return None
__all__ = [
- "ClaudeBashTool",
- "ClaudeComputerTool",
+ "ClaudeAgentTools",
"ClaudeHostedTool",
- "ClaudeMemoryTool",
- "ClaudeTextEditorTool",
- "ClaudeTool",
- "ClaudeToolRegistry",
"ClaudeToolSearchTool",
"ClaudeWebFetchTool",
"ClaudeWebSearchTool",
- "claude_tools",
]
diff --git a/hud/agents/claude/tools/base.py b/hud/agents/claude/tools/base.py
index ee4b4820e..0cd353cad 100644
--- a/hud/agents/claude/tools/base.py
+++ b/hud/agents/claude/tools/base.py
@@ -2,27 +2,179 @@
from __future__ import annotations
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any
+from dataclasses import dataclass
+from inspect import cleandoc
+from typing import TYPE_CHECKING, Any, Literal, cast
-from hud.agents import tools as _agent_tools
-from hud.agents.tools import AgentTool, AgentToolSpec, CallTool
+import mcp.types as types
+from anthropic.types.beta import (
+ BetaBase64ImageSourceParam,
+ BetaBase64PDFSourceParam,
+ BetaImageBlockParam,
+ BetaMessageParam,
+ BetaPlainTextSourceParam,
+ BetaRequestDocumentBlockParam,
+ BetaTextBlockParam,
+ BetaToolParam,
+ BetaToolResultBlockParam,
+)
+
+from hud.agents.tools import AgentTool, AgentToolSpec
if TYPE_CHECKING:
from anthropic.types.beta import BetaToolUnionParam
- from hud.types import MCPToolResult
+ from hud.types import MCPToolCall, MCPToolResult
else:
BetaToolUnionParam = Any
-ClaudeToolSpec = AgentToolSpec
-call_tool = _agent_tools.call_tool
+ClaudeImageMediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"]
+ClaudeToolResultContent = BetaTextBlockParam | BetaImageBlockParam | BetaRequestDocumentBlockParam
+
+
+@dataclass(frozen=True)
+class ClaudeToolSpec(AgentToolSpec):
+ """Claude provider tool definition."""
+
+ beta: str | None = None
-class ClaudeTool(AgentTool["BetaToolUnionParam"], ABC):
+class ClaudeTool(AgentTool["BetaToolUnionParam"]):
"""Agent-side Claude provider tool backed by an environment tool."""
- @abstractmethod
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- """Execute against the environment tool using the agent-provided caller."""
- ...
+ def __init__(self, *, env_tool_name: str, spec: ClaudeToolSpec) -> None:
+ super().__init__(env_tool_name=env_tool_name, spec=spec)
+ self.spec: ClaudeToolSpec = spec
+
+ @property
+ def required_beta(self) -> str | None:
+ return self.spec.beta
+
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> BetaMessageParam | None:
+ tool_use_id = call.id
+ if not tool_use_id:
+ return None
+
+ result_content = result.content
+ if result.isError:
+ error_msg = next(
+ (
+ content.text
+ for content in result.content
+ if isinstance(content, types.TextContent)
+ ),
+ "Tool execution failed",
+ )
+ result_content = [types.TextContent(type="text", text=f"Error: {error_msg}")]
+
+ claude_blocks: list[ClaudeToolResultContent] = []
+ sibling_docs: list[BetaRequestDocumentBlockParam] = []
+ enable_citations = bool(getattr(call.meta, "enable_citations", False))
+ for content in result_content:
+ citation_doc = None
+ match content:
+ case types.TextContent():
+ block = BetaTextBlockParam(type="text", text=content.text)
+ if enable_citations and not result.isError:
+ citation_doc = BetaRequestDocumentBlockParam(
+ type="document",
+ source=BetaPlainTextSourceParam(
+ type="text",
+ media_type="text/plain",
+ data=content.text,
+ ),
+ title=call.name,
+ citations={"enabled": True},
+ )
+ case types.ImageContent():
+ block = BetaImageBlockParam(
+ type="image",
+ source=BetaBase64ImageSourceParam(
+ type="base64",
+ media_type=cast("ClaudeImageMediaType", content.mimeType),
+ data=content.data,
+ ),
+ )
+ case types.EmbeddedResource(
+ resource=types.BlobResourceContents(mimeType="application/pdf") as resource
+ ):
+ block = BetaRequestDocumentBlockParam(
+ type="document",
+ source=BetaBase64PDFSourceParam(
+ type="base64",
+ media_type="application/pdf",
+ data=resource.blob,
+ ),
+ )
+ if enable_citations and not result.isError:
+ citation_doc = BetaRequestDocumentBlockParam(
+ type="document",
+ source=block["source"],
+ citations={"enabled": True},
+ )
+ case _:
+ raise ValueError(f"Unknown content block type: {type(content)}")
+ claude_blocks.append(block)
+ if citation_doc is not None:
+ sibling_docs.append(citation_doc)
+
+ return BetaMessageParam(
+ role="user",
+ content=[
+ BetaToolResultBlockParam(
+ type="tool_result",
+ tool_use_id=tool_use_id,
+ content=claude_blocks,
+ ),
+ *sibling_docs,
+ ],
+ )
+
+
+class ClaudeFunctionTool(ClaudeTool):
+ """Regular environment tool exposed as a Claude function tool."""
+
+ name = "function"
+ capability = "function"
+
+ def __init__(
+ self,
+ *,
+ env_tool_name: str,
+ description: str,
+ input_schema: dict[str, Any],
+ ) -> None:
+ super().__init__(
+ env_tool_name=env_tool_name,
+ spec=ClaudeToolSpec(api_type="function", api_name=env_tool_name),
+ )
+ self.description = description
+ self.input_schema = input_schema
+
+ @classmethod
+ def from_tool(cls, tool: types.Tool) -> ClaudeFunctionTool:
+ if tool.description is None:
+ raise ValueError(
+ cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema.
+ Add these by:
+ 1. Adding a docstring to your @mcp.tool decorated function for the description
+ 2. Using pydantic Field() annotations on function parameters for the schema
+ """)
+ )
+ return cls(
+ env_tool_name=tool.name,
+ description=tool.description,
+ input_schema=tool.inputSchema,
+ )
+
+ @property
+ def provider_name(self) -> str:
+ return self.env_tool_name
+
+ def to_params(self) -> BetaToolUnionParam:
+ return BetaToolParam(
+ name=self.provider_name,
+ description=self.description,
+ input_schema=self.input_schema,
+ eager_input_streaming=True,
+ )
diff --git a/hud/agents/claude/tools/coding.py b/hud/agents/claude/tools/coding.py
index f9b4331dd..fc66467c8 100644
--- a/hud/agents/claude/tools/coding.py
+++ b/hud/agents/claude/tools/coding.py
@@ -8,7 +8,7 @@
from hud.types import MCPToolResult
-from .base import CallTool, ClaudeTool, ClaudeToolSpec, call_tool
+from .base import ClaudeTool, ClaudeToolSpec
if TYPE_CHECKING:
from anthropic.types.beta import (
@@ -16,6 +16,8 @@
BetaToolTextEditor20250728Param,
)
+ from hud.agents.tools.base import CallTool
+
CLAUDE_BASH_SPEC = ClaudeToolSpec(
api_type="bash_20250124",
@@ -29,30 +31,18 @@
),
)
-CLAUDE_TEXT_EDITOR_SPECS: tuple[ClaudeToolSpec, ...] = (
- ClaudeToolSpec(
- api_type="text_editor_20250728",
- api_name="str_replace_based_edit_tool",
- supported_models=(
- "*claude-opus-4-7*",
- "*claude-opus-4-6*",
- "*claude-sonnet-4-5*",
- "*claude-sonnet-4-6*",
- "*claude-haiku-4-5*",
- ),
+CLAUDE_TEXT_EDITOR_SPEC = ClaudeToolSpec(
+ api_type="text_editor_20250728",
+ api_name="str_replace_based_edit_tool",
+ supported_models=(
+ "*claude-opus-4-7*",
+ "*claude-opus-4-6*",
+ "*claude-sonnet-4-5*",
+ "*claude-sonnet-4-6*",
+ "*claude-haiku-4-5*",
),
)
-CLAUDE_TEXT_EDITOR_SPEC = CLAUDE_TEXT_EDITOR_SPECS[0]
-
-CLAUDE_TEXT_EDITOR_NAMES = {
- "text_editor_20250728": "str_replace_based_edit_tool",
-}
-
-CLAUDE_TEXT_EDITOR_COMMANDS = {
- "text_editor_20250728": frozenset({"view", "create", "str_replace", "insert"}),
-}
-
class ClaudeBashTool(ClaudeTool):
"""Claude bash provider tool backed by an environment shell tool."""
@@ -81,7 +71,7 @@ def to_params(self) -> BetaToolBash20250124Param:
async def execute(
self,
- caller: CallTool,
+ call_tool: CallTool,
arguments: dict[str, Any],
) -> MCPToolResult:
if not arguments.get("restart") and "command" not in arguments:
@@ -94,7 +84,7 @@ async def execute(
],
isError=True,
)
- return await call_tool(caller, self.env_tool_name, arguments)
+ return await super().execute(call_tool, arguments)
class ClaudeTextEditorTool(ClaudeTool):
@@ -105,9 +95,8 @@ class ClaudeTextEditorTool(ClaudeTool):
@classmethod
def default_spec(cls, model: str) -> ClaudeToolSpec | None:
- for spec in CLAUDE_TEXT_EDITOR_SPECS:
- if spec.supports_model(model):
- return spec
+ if CLAUDE_TEXT_EDITOR_SPEC.supports_model(model):
+ return CLAUDE_TEXT_EDITOR_SPEC
return None
def __init__(self, *, env_tool_name: str, spec: ClaudeToolSpec) -> None:
@@ -115,7 +104,7 @@ def __init__(self, *, env_tool_name: str, spec: ClaudeToolSpec) -> None:
@property
def provider_name(self) -> str:
- return CLAUDE_TEXT_EDITOR_NAMES.get(self.spec.api_type, self.spec.api_name)
+ return self.spec.api_name
def to_params(self) -> BetaToolTextEditor20250728Param:
return cast(
@@ -128,25 +117,10 @@ def to_params(self) -> BetaToolTextEditor20250728Param:
async def execute(
self,
- caller: CallTool,
+ call_tool: CallTool,
arguments: dict[str, Any],
) -> MCPToolResult:
- command = arguments.get("command")
- allowed_commands = CLAUDE_TEXT_EDITOR_COMMANDS.get(self.spec.api_type)
- if allowed_commands is not None and command not in allowed_commands:
- return MCPToolResult(
- content=[
- TextContent(
- type="text",
- text=(
- f"{self.spec.api_type} does not support command {command!r}. "
- f"Supported commands: {', '.join(sorted(allowed_commands))}"
- ),
- )
- ],
- isError=True,
- )
- return await call_tool(caller, self.env_tool_name, _claude_editor_arguments(arguments))
+ return await super().execute(call_tool, _claude_editor_arguments(arguments))
def _claude_editor_arguments(arguments: dict[str, Any]) -> dict[str, Any]:
@@ -170,12 +144,3 @@ def _claude_editor_arguments(arguments: dict[str, Any]) -> dict[str, Any]:
}
case _:
return dict(arguments)
-
-
-__all__ = [
- "CLAUDE_BASH_SPEC",
- "CLAUDE_TEXT_EDITOR_SPEC",
- "CLAUDE_TEXT_EDITOR_SPECS",
- "ClaudeBashTool",
- "ClaudeTextEditorTool",
-]
diff --git a/hud/agents/claude/tools/computer.py b/hud/agents/claude/tools/computer.py
index 6953e2fde..7ca775c15 100644
--- a/hud/agents/claude/tools/computer.py
+++ b/hud/agents/claude/tools/computer.py
@@ -9,13 +9,19 @@
import base64
import logging
from io import BytesIO
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, cast
-from mcp.types import ImageContent, TextContent
+from mcp.types import ImageContent
+from hud.agents.tools.computer import (
+ computer_error_result,
+ computer_tool_info,
+ execute_computer_calls,
+ first_image_data,
+)
from hud.types import MCPToolResult
-from .base import CallTool, ClaudeTool, ClaudeToolSpec, call_tool
+from .base import ClaudeTool, ClaudeToolSpec
from .settings import claude_tool_settings
if TYPE_CHECKING:
@@ -25,6 +31,7 @@
)
from hud.agents.tools import EnvironmentCapability
+ from hud.agents.tools.base import CallTool
logger = logging.getLogger(__name__)
@@ -86,8 +93,6 @@
),
)
-_AUTO_SCREENSHOT_OFF_SPECS = {"computer_20251124"}
-
class ClaudeComputerTool(ClaudeTool):
"""Translate Claude native computer calls into environment computer calls."""
@@ -107,62 +112,36 @@ def __init__(
*,
env_tool_name: str,
spec: ClaudeToolSpec,
- model: str,
display_width: int,
display_height: int,
- schema: Literal["hud", "anthropic"],
) -> None:
- super().__init__(env_tool_name=env_tool_name, spec=self._resolve_spec(spec, model))
+ super().__init__(env_tool_name=env_tool_name, spec=spec)
self.display_width = display_width
self.display_height = display_height
- self.schema = schema
@classmethod
def from_capability(
cls,
capability: EnvironmentCapability,
- spec: ClaudeToolSpec,
model: str,
- ) -> ClaudeComputerTool:
- tool = capability.tool
- props = tool.inputSchema.get("properties", {}) if isinstance(tool.inputSchema, dict) else {}
- schema: Literal["hud", "anthropic"] = (
- "anthropic" if {"coordinate", "scroll_direction"} & set(props) else "hud"
- )
-
- metadata_resolution = capability.metadata.get("resolution", {})
- if not isinstance(metadata_resolution, dict):
- metadata_resolution = {}
- resolution = (tool.meta or {}).get("resolution", {}) if tool.meta else {}
- display_width = int(
- metadata_resolution.get("width")
- or resolution.get("width")
- or claude_tool_settings.COMPUTER_WIDTH
- )
- display_height = int(
- metadata_resolution.get("height")
- or resolution.get("height")
- or claude_tool_settings.COMPUTER_HEIGHT
+ ) -> ClaudeComputerTool | None:
+ spec = cls.default_spec(model)
+ if spec is None:
+ return None
+
+ computer_info = computer_tool_info(
+ capability.tool,
+ default_width=claude_tool_settings.COMPUTER_WIDTH,
+ default_height=claude_tool_settings.COMPUTER_HEIGHT,
)
return cls(
env_tool_name=capability.tool_name,
spec=spec,
- model=model,
- display_width=display_width,
- display_height=display_height,
- schema=schema,
+ display_width=computer_info.display_width,
+ display_height=computer_info.display_height,
)
- @staticmethod
- def _resolve_spec(spec: ClaudeToolSpec, model: str) -> ClaudeToolSpec:
- if spec.api_type and spec.api_type.startswith("computer_"):
- return spec
- for candidate in CLAUDE_COMPUTER_SPECS:
- if candidate.supports_model(model):
- return candidate
- return spec
-
def to_params(
self,
) -> BetaToolComputerUse20250124Param | BetaToolComputerUse20251124Param:
@@ -191,47 +170,20 @@ def to_params(
async def execute(
self,
- caller: CallTool,
- arguments: dict[str, Any],
- ) -> MCPToolResult:
- if self.schema == "anthropic":
- return await self._call_env(caller, self._as_anthropic_arguments(arguments))
- return await self._call_env_tool(caller, arguments)
-
- async def _call_env(
- self,
- caller: CallTool,
- arguments: dict[str, Any],
- ) -> MCPToolResult:
- return await call_tool(caller, self.env_tool_name, arguments)
-
- async def _call_env_tool(
- self,
- caller: CallTool,
+ call_tool: CallTool,
arguments: dict[str, Any],
) -> MCPToolResult:
action = arguments.get("action")
if action == "zoom":
- return await self._zoom(caller, arguments)
-
- calls = self._env_calls(arguments)
- result = MCPToolResult(content=[], isError=False)
- for call in calls:
- result = await self._call_env(caller, call)
- if result.isError:
- return result
- return result
-
- def _as_anthropic_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]:
- args = dict(arguments)
- if (
- self.spec.api_type in _AUTO_SCREENSHOT_OFF_SPECS
- and args.get("action") != "screenshot"
- and "take_screenshot_on_click" not in args
- ):
- args["take_screenshot_on_click"] = False
- return args
+ return await self._zoom(call_tool, arguments)
+
+ return await execute_computer_calls(
+ call_tool,
+ env_tool_name=self.env_tool_name,
+ calls=self._env_calls(arguments),
+ ensure_screenshot=False,
+ )
def _env_calls(self, arguments: dict[str, Any]) -> list[dict[str, Any]]:
action = arguments.get("action")
@@ -239,8 +191,10 @@ def _env_calls(self, arguments: dict[str, Any]) -> list[dict[str, Any]]:
text = arguments.get("text")
def xy() -> tuple[int | None, int | None]:
- if isinstance(coordinate, list) and len(coordinate) >= 2:
- return coordinate[0], coordinate[1]
+ if isinstance(coordinate, list):
+ coords = cast("list[Any]", coordinate)
+ if len(coords) >= 2:
+ return int(coords[0]), int(coords[1])
return None, None
if action == "screenshot":
@@ -317,17 +271,21 @@ def xy() -> tuple[int | None, int | None]:
]
if action in ("left_click_drag", "drag"):
start = arguments.get("start_coordinate")
- path = []
- if isinstance(start, list) and len(start) >= 2:
- path.append({"x": start[0], "y": start[1]})
- if isinstance(coordinate, list) and len(coordinate) >= 2:
- if not path:
- return [
- {"action": "mouse_down", "button": "left"},
- {"action": "move", "x": coordinate[0], "y": coordinate[1]},
- {"action": "mouse_up", "button": "left"},
- ]
- path.append({"x": coordinate[0], "y": coordinate[1]})
+ path: list[dict[str, Any]] = []
+ if isinstance(start, list):
+ start_coords = cast("list[Any]", start)
+ if len(start_coords) >= 2:
+ path.append({"x": start_coords[0], "y": start_coords[1]})
+ if isinstance(coordinate, list):
+ end_coords = cast("list[Any]", coordinate)
+ if len(end_coords) >= 2:
+ if not path:
+ return [
+ {"action": "mouse_down", "button": "left"},
+ {"action": "move", "x": end_coords[0], "y": end_coords[1]},
+ {"action": "mouse_up", "button": "left"},
+ ]
+ path.append({"x": end_coords[0], "y": end_coords[1]})
return [{"action": "drag", "path": path, "hold_keys": self._hold_keys(text)}]
if action == "wait":
duration = arguments.get("duration") or 0
@@ -351,28 +309,23 @@ def xy() -> tuple[int | None, int | None]:
async def _zoom(
self,
- caller: CallTool,
+ call_tool: CallTool,
arguments: dict[str, Any],
) -> MCPToolResult:
region = arguments.get("region")
- if not isinstance(region, (list, tuple)) or len(region) != 4:
- return MCPToolResult(
- content=[TextContent(type="text", text="region must be [x0, y0, x1, y1]")],
- isError=True,
- )
+ region_value = cast("list[Any] | tuple[Any, ...]", region)
+ if not isinstance(region, (list, tuple)) or len(region_value) != 4:
+ return computer_error_result("region must be [x0, y0, x1, y1]")
- screenshot = await self._call_env(caller, {"action": "screenshot"})
+ screenshot = await super().execute(call_tool, {"action": "screenshot"})
if screenshot.isError:
return screenshot
- image_data = _first_image(screenshot)
+ image_data = first_image_data(screenshot)
if image_data is None:
- return MCPToolResult(
- content=[TextContent(type="text", text="screenshot returned no image")],
- isError=True,
- )
+ return computer_error_result("screenshot returned no image")
try:
- x0, y0, x1, y1 = (int(v) for v in region)
+ x0, y0, x1, y1 = (int(v) for v in region_value)
image = ImageContent(
type="image",
mimeType="image/png",
@@ -381,7 +334,7 @@ async def _zoom(
return MCPToolResult(content=[image], isError=False)
except Exception as exc:
logger.warning("Claude computer zoom failed: %s", exc)
- return MCPToolResult(content=[TextContent(type="text", text=str(exc))], isError=True)
+ return computer_error_result(str(exc))
@staticmethod
def _keys(text: str | None) -> list[str]:
@@ -419,13 +372,6 @@ def _map_key(key: str) -> str:
return ANTHROPIC_TO_CLA_KEYS.get(key, ANTHROPIC_TO_CLA_KEYS.get(key.capitalize(), key.lower()))
-def _first_image(result: MCPToolResult) -> str | None:
- for block in result.content or []:
- if isinstance(block, ImageContent):
- return block.data
- return None
-
-
def _crop_png(image_data: str, region: tuple[int, int, int, int]) -> str:
from PIL import Image # type: ignore[import-not-found]
@@ -434,6 +380,3 @@ def _crop_png(image_data: str, region: tuple[int, int, int, int]) -> str:
buffer = BytesIO()
crop.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("ascii")
-
-
-__all__ = ["CLAUDE_COMPUTER_SPECS", "ClaudeComputerTool"]
diff --git a/hud/agents/claude/tools/hosted.py b/hud/agents/claude/tools/hosted.py
index f9b19a593..050afedaa 100644
--- a/hud/agents/claude/tools/hosted.py
+++ b/hud/agents/claude/tools/hosted.py
@@ -112,11 +112,3 @@ def _validate_domain_filters(
) -> None:
if allowed_domains and blocked_domains:
raise ValueError("Use either allowed_domains or blocked_domains, not both.")
-
-
-__all__ = [
- "ClaudeHostedTool",
- "ClaudeToolSearchTool",
- "ClaudeWebFetchTool",
- "ClaudeWebSearchTool",
-]
diff --git a/hud/agents/claude/tools/memory.py b/hud/agents/claude/tools/memory.py
index 53d8c42d5..373c4f3c7 100644
--- a/hud/agents/claude/tools/memory.py
+++ b/hud/agents/claude/tools/memory.py
@@ -2,15 +2,13 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, cast
-from .base import CallTool, ClaudeTool, ClaudeToolSpec, call_tool
+from .base import ClaudeTool, ClaudeToolSpec
if TYPE_CHECKING:
from anthropic.types.beta import BetaToolUnionParam
- from hud.types import MCPToolResult
-
CLAUDE_MEMORY_SPEC = ClaudeToolSpec(
api_type="memory_20250818",
@@ -48,13 +46,3 @@ def to_params(self) -> BetaToolUnionParam:
"name": self.name,
},
)
-
- async def execute(
- self,
- caller: CallTool,
- arguments: dict[str, Any],
- ) -> MCPToolResult:
- return await call_tool(caller, self.env_tool_name, arguments)
-
-
-__all__ = ["CLAUDE_MEMORY_SPEC", "ClaudeMemoryTool"]
diff --git a/hud/agents/claude/tools/settings.py b/hud/agents/claude/tools/settings.py
index 041c436c4..9a301d006 100644
--- a/hud/agents/claude/tools/settings.py
+++ b/hud/agents/claude/tools/settings.py
@@ -34,5 +34,3 @@ class ClaudeToolSettings(BaseSettings):
claude_tool_settings = ClaudeToolSettings()
-
-__all__ = ["ClaudeToolSettings", "claude_tool_settings"]
diff --git a/hud/agents/gateway.py b/hud/agents/gateway.py
index 4d0973f8f..c78db083b 100644
--- a/hud/agents/gateway.py
+++ b/hud/agents/gateway.py
@@ -2,10 +2,44 @@
from __future__ import annotations
-from typing import Any
+from typing import TYPE_CHECKING, Any
+import httpx
+from openai import AsyncOpenAI
+from pydantic import BaseModel, Field
-def build_gateway_client(provider: str) -> Any:
+from hud.settings import settings
+from hud.types import AgentType
+
+if TYPE_CHECKING:
+ from typing import TypeAlias
+
+ from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
+ from google.genai import Client as GenaiClient
+
+ from hud.agents.base import MCPAgent
+
+ GatewayClient: TypeAlias = AsyncAnthropic | AsyncAnthropicBedrock | GenaiClient | AsyncOpenAI
+
+
+class GatewayProviderInfo(BaseModel):
+ name: str | None = None
+ default_sdk_agent_type: str | None = None
+
+
+class GatewayModelInfo(BaseModel):
+ id: str | None = None
+ name: str | None = None
+ model_name: str | None = None
+ sdk_agent_type: str | None = None
+ provider: GatewayProviderInfo = Field(default_factory=GatewayProviderInfo)
+
+
+class GatewayModelsResponse(BaseModel):
+ models: list[GatewayModelInfo]
+
+
+def build_gateway_client(provider: str) -> GatewayClient:
"""Build a client configured for HUD gateway routing.
Args:
@@ -14,10 +48,10 @@ def build_gateway_client(provider: str) -> Any:
Returns:
Configured async client for the provider.
"""
- from hud.settings import settings
-
provider = provider.lower()
+ # Anthropic and Gemini SDKs are optional extras; keep those imports on the
+ # provider branch so importing gateway utilities does not require both.
if provider == "anthropic":
from anthropic import AsyncAnthropic
@@ -37,6 +71,74 @@ def build_gateway_client(provider: str) -> Any:
)
# OpenAI-compatible (openai, azure, together, groq, fireworks, etc.)
- from openai import AsyncOpenAI
-
return AsyncOpenAI(api_key=settings.api_key, base_url=settings.hud_gateway_url)
+
+
+def _fetch_gateway_models() -> list[GatewayModelInfo]:
+ """Fetch available models from HUD API."""
+ if not settings.api_key:
+ return []
+
+ try:
+ resp = httpx.get(
+ f"{settings.hud_api_url}/models/",
+ headers={"Authorization": f"Bearer {settings.api_key}"},
+ timeout=10.0,
+ )
+ resp.raise_for_status()
+ payload: object = resp.json()
+ if not isinstance(payload, dict) or "models" not in payload:
+ return []
+ return GatewayModelsResponse.model_validate(payload).models
+ except Exception:
+ return []
+
+
+def create_agent(model: str, **kwargs: Any) -> MCPAgent[Any]:
+ """Create an agent routed through the HUD gateway.
+
+ For direct API access with provider API keys, instantiate the agent classes directly.
+ """
+ agent_type = next((candidate for candidate in AgentType if candidate.value == model), None)
+ if agent_type is not None:
+ model_id = model
+ provider_name = agent_type.gateway_provider
+ else:
+ for gateway_model in _fetch_gateway_models():
+ if model in (
+ gateway_model.id,
+ gateway_model.name,
+ gateway_model.model_name,
+ ):
+ agent_str = (
+ gateway_model.sdk_agent_type or gateway_model.provider.default_sdk_agent_type
+ )
+ if agent_str == "operator":
+ raise ValueError(
+ "Operator agent is no longer supported; use openai with a supported "
+ "OpenAI computer model."
+ )
+ if agent_str == "gemini_cua":
+ raise ValueError(
+ "Gemini CUA agent is no longer supported; use gemini with a supported "
+ "Gemini computer-use model."
+ )
+ if not isinstance(agent_str, str):
+ raise ValueError(f"Model '{model}' has invalid agent type metadata")
+
+ agent_type = AgentType(agent_str)
+ model_id = gateway_model.model_name or model
+ provider_name = gateway_model.provider.name or "openai"
+ break
+ else:
+ raise ValueError(f"Model '{model}' not found")
+
+ client = build_gateway_client(provider_name)
+ kwargs.setdefault("model", model_id)
+ if agent_type == AgentType.OPENAI_COMPATIBLE:
+ kwargs.setdefault("openai_client", client)
+ else:
+ kwargs.setdefault("model_client", client)
+ kwargs.setdefault("validate_api_key", False)
+
+ return agent_type.cls.create(**kwargs)
diff --git a/hud/agents/gemini/agent.py b/hud/agents/gemini/agent.py
index 9ae2d3c2a..d4f83480f 100644
--- a/hud/agents/gemini/agent.py
+++ b/hud/agents/gemini/agent.py
@@ -2,40 +2,30 @@
from __future__ import annotations
+import base64
import logging
-from typing import Any, ClassVar, cast
+from functools import cached_property
+from typing import Any, cast
import mcp.types as types
from google import genai
from google.genai import types as genai_types
+from hud.agents import gateway
from hud.agents.base import MCPAgent
-from hud.agents.tools import (
- EnvironmentCapability,
- call_agent_tools,
- capabilities_metadata_from_context,
- discover_environment_capabilities,
- select_hosted_tools,
-)
-from hud.agents.types import GeminiConfig, GeminiCreateParams
+from hud.agents.types import GeminiConfig
from hud.settings import settings
-from hud.tools.computer import computer_settings
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult
-from hud.utils.hud_console import HUDConsole
+from hud.tools.types import Citation
+from hud.types import AgentResponse
from hud.utils.types import with_signature
-from .tools import (
- GeminiComputerTool,
- GeminiHostedTool,
- GeminiTool,
- gemini_tools,
- normalize_gemini_computer_use_args,
-)
+from .settings import gemini_agent_settings
+from .tools import GeminiAgentTools
logger = logging.getLogger(__name__)
-class GeminiAgent(MCPAgent):
+class GeminiAgent(MCPAgent[genai_types.Content]):
"""
Gemini agent that uses MCP servers for tool execution.
@@ -43,38 +33,25 @@ class GeminiAgent(MCPAgent):
tools through MCP servers instead of direct implementation.
"""
- metadata: ClassVar[dict[str, Any] | None] = None
- config_cls: ClassVar[type[BaseAgentConfig]] = GeminiConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for Gemini."""
- return AgentType.GEMINI
-
- @with_signature(GeminiCreateParams)
+ @with_signature(GeminiConfig)
@classmethod
- def create(cls, **kwargs: Any) -> GeminiAgent: # pyright: ignore[reportIncompatibleMethodOverride]
- return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value]
+ def create(cls, **kwargs: object) -> GeminiAgent: # pyright: ignore[reportIncompatibleMethodOverride]
+ return cls(GeminiConfig.model_validate(kwargs))
- def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> None:
- super().__init__(params, **kwargs)
+ def __init__(self, config: GeminiConfig | None = None) -> None:
+ config = config or GeminiConfig()
+ super().__init__(config)
self.config: GeminiConfig
model_client = self.config.model_client
if model_client is None:
if settings.api_key:
- from hud.agents.gateway import build_gateway_client
-
- model_client = build_gateway_client("gemini")
+ model_client = gateway.build_gateway_client("gemini")
elif settings.gemini_api_key:
model_client = genai.Client(api_key=settings.gemini_api_key)
if self.config.validate_api_key:
try:
- list(
- model_client.models.list(
- config=genai_types.ListModelsConfig(page_size=1)
- )
- )
+ next(iter(model_client.models.list()), None)
except Exception as e:
raise ValueError(f"Gemini API key is invalid: {e}") from e
else:
@@ -87,90 +64,76 @@ def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> N
" access"
)
- self.gemini_client: genai.Client = model_client
+ self.gemini_client: genai.Client = cast("genai.Client", model_client)
self.temperature = self.config.temperature
self.top_p = self.config.top_p
self.top_k = self.config.top_k
self.max_output_tokens = self.config.max_output_tokens
self.thinking_level = self.config.thinking_level
self.include_thoughts = self.config.include_thoughts
- self.hud_console = HUDConsole(logger=logger)
- # Track mapping from Gemini tool names to MCP tool names
- self._gemini_to_mcp_tool_map: dict[str, str] = {}
- self._computer_tool_name: str | None = None
- self._gemini_native_tools: dict[str, GeminiTool] = {}
- self._environment_capabilities: dict[str, EnvironmentCapability] = {}
self.excluded_predefined_functions = list(self.config.excluded_predefined_functions)
self.max_recent_turn_with_screenshots = (
- computer_settings.GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS
- )
- self.gemini_tools: genai_types.ToolListUnion = []
-
- def _on_tools_ready(self) -> None:
- """Build Gemini-specific tool mappings after tools are discovered."""
- self._convert_tools_for_gemini()
-
- def _discover_environment_capabilities(
- self, tools: list[types.Tool]
- ) -> dict[str, EnvironmentCapability]:
- return discover_environment_capabilities(
- tools,
- env_metadata=capabilities_metadata_from_context(self.ctx),
- name_fallbacks=gemini_tools.name_fallbacks,
+ gemini_agent_settings.MAX_RECENT_TURN_WITH_SCREENSHOTS
)
- async def get_system_messages(self) -> list[genai_types.Content]:
- """No system messages for Gemini because applied in get_response"""
- return []
-
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_types.Content]:
- """Format messages for Gemini."""
- # Convert MCP content types to Gemini content types
- gemini_parts: list[genai_types.Part] = []
-
- for block in blocks:
- if isinstance(block, types.TextContent):
- gemini_parts.append(genai_types.Part(text=block.text))
- elif isinstance(block, types.ImageContent):
- # Convert MCP ImageContent to Gemini format
- # Need to decode base64 string to bytes
- import base64
-
- image_bytes = base64.b64decode(block.data)
- gemini_parts.append(
- genai_types.Part.from_bytes(data=image_bytes, mime_type=block.mimeType)
- )
- else:
- # For other types, try to handle but log a warning
- self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning")
+ @cached_property
+ def tools(self) -> GeminiAgentTools:
+ return GeminiAgentTools(
+ excluded_predefined_functions=self.excluded_predefined_functions,
+ )
- return [genai_types.Content(role="user", parts=gemini_parts)]
+ async def format_messages(
+ self, messages: list[types.PromptMessage]
+ ) -> list[genai_types.Content]:
+ """Format MCP prompt messages for Gemini."""
+ return [
+ genai_types.Content(
+ role="model" if str(message.role) == "assistant" else str(message.role),
+ parts=[_format_content(message.content)],
+ )
+ for message in messages
+ ]
- async def get_response(self, messages: list[genai_types.Content]) -> InferenceResult:
+ async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse:
"""Get response from Gemini including any tool calls."""
- self._remove_old_screenshots(messages)
- tools = self.gemini_tools
+ # Drop screenshots from older computer tool responses to keep context small.
+ screenshot_turns: list[list[genai_types.FunctionResponse]] = []
+ for content in reversed(messages):
+ if content.role != "user":
+ continue
- citations_enabled = bool(
- getattr(self.ctx, "scenario_enable_citations", False) if self.ctx else False
- )
- if citations_enabled and not self._has_google_search_tool():
+ turn_responses: list[genai_types.FunctionResponse] = []
+ for part in content.parts or []:
+ function_response = part.function_response
+ if (
+ function_response is not None
+ and function_response.parts
+ and function_response.name in self.tools.predefined_computer_functions
+ ):
+ turn_responses.append(function_response)
+
+ if turn_responses:
+ screenshot_turns.append(turn_responses)
+
+ for old_turn in screenshot_turns[self.max_recent_turn_with_screenshots :]:
+ for function_response in old_turn:
+ function_response.parts = None
+
+ # Configure Gemini generation options.
+ tools = cast("genai_types.ToolListUnion", self.tools.params)
+ if self.enable_citations and not any(tool.google_search for tool in self.tools.params):
tools = [*list(tools), genai_types.Tool(google_search=genai_types.GoogleSearch())]
thinking_config = None
if self.thinking_level is not None or self.include_thoughts:
- thinking_level = (
- genai_types.ThinkingLevel(self.thinking_level.upper())
- if self.thinking_level is not None
- else None
- )
thinking_config = genai_types.ThinkingConfig(
- thinking_level=thinking_level,
+ thinking_level=genai_types.ThinkingLevel(self.thinking_level.upper())
+ if self.thinking_level is not None
+ else None,
include_thoughts=self.include_thoughts,
)
- # Build generate content config
generate_config = genai_types.GenerateContentConfig(
temperature=self.temperature,
top_p=self.top_p,
@@ -181,396 +144,120 @@ async def get_response(self, messages: list[genai_types.Content]) -> InferenceRe
thinking_config=thinking_config,
)
- # Use async API to avoid blocking the event loop
- response = await self.gemini_client.aio.models.generate_content(
+ api_response = await self.gemini_client.aio.models.generate_content(
model=self.config.model,
contents=cast("Any", messages),
config=generate_config,
)
-
- # Append assistant response (including any function_call) so that
- # subsequent FunctionResponse messages correspond to a prior FunctionCall
- if response.candidates and len(response.candidates) > 0 and response.candidates[0].content:
- messages.append(response.candidates[0].content)
-
- # Process response
- result = InferenceResult(content="", tool_calls=[], done=True)
- collected_tool_calls: list[MCPToolCall] = []
-
- if not response.candidates:
- detail_parts = []
- for attr in ("prompt_feedback", "usage_metadata"):
- value = getattr(response, attr, None)
- if value is None:
- continue
- if hasattr(value, "model_dump_json"):
- value_repr = value.model_dump_json()
- elif hasattr(value, "model_dump"):
- value_repr = repr(value.model_dump())
- else:
- value_repr = repr(value)
- detail_parts.append(f"{attr}={value_repr}")
+ if not api_response.candidates:
+ detail_parts: list[str] = []
+ if api_response.prompt_feedback is not None:
+ detail_parts.append(
+ f"prompt_feedback={api_response.prompt_feedback.model_dump_json()}"
+ )
+ if api_response.usage_metadata is not None:
+ detail_parts.append(
+ f"usage_metadata={api_response.usage_metadata.model_dump_json()}"
+ )
details = "; ".join(detail_parts) if detail_parts else "no response metadata"
raise RuntimeError(
f"Gemini response returned no candidates for model {self.config.model}. {details}"
)
- candidate = response.candidates[0]
-
- # Extract text content and function calls
- text_content = ""
- thinking_content = ""
-
- if candidate.content and candidate.content.parts:
- for part in candidate.content.parts:
- if part.function_call:
- tool_call = self._extract_tool_call(part)
- if tool_call is not None:
- collected_tool_calls.append(tool_call)
- elif part.thought is True and part.text:
- if thinking_content:
- thinking_content += "\n"
- thinking_content += part.text
- elif part.text:
- text_content += part.text
-
- # Assign collected tool calls and mark done status
- if collected_tool_calls:
- result.tool_calls = collected_tool_calls
- result.done = False
-
- result.content = text_content
- if thinking_content:
- result.reasoning = thinking_content
-
- # Extract grounding citations from groundingMetadata
- grounding_meta = getattr(candidate, "grounding_metadata", None)
- if grounding_meta:
- citations: list[dict[str, Any]] = []
-
- # Build a lookup from chunk index → source info
- chunks = getattr(grounding_meta, "grounding_chunks", None) or []
- chunk_sources: list[dict[str, Any]] = []
- for chunk in chunks:
- web = getattr(chunk, "web", None)
- if web:
- chunk_sources.append(
- {
- "source": getattr(web, "uri", "") or "",
- "title": getattr(web, "title", None),
- }
- )
- else:
- chunk_sources.append({"source": "", "title": None})
-
- # Walk groundingSupports for text-segment anchoring
- supports = getattr(grounding_meta, "grounding_supports", None) or []
- seen_chunk_indices: set[int] = set()
- for support in supports:
- segment = getattr(support, "segment", None)
- support_chunk_indices = getattr(support, "grounding_chunk_indices", None) or []
- segment_text = getattr(segment, "text", "") or "" if segment else ""
- start_idx = getattr(segment, "start_index", None) if segment else None
- end_idx = getattr(segment, "end_index", None) if segment else None
-
- for idx in support_chunk_indices:
- seen_chunk_indices.add(idx)
- source_info = chunk_sources[idx] if idx < len(chunk_sources) else {}
- citations.append(
- {
- "type": "grounding",
- "text": segment_text,
- "source": source_info.get("source", ""),
- "title": source_info.get("title"),
- "start_index": start_idx,
- "end_index": end_idx,
- }
- )
-
- # Include any chunks not referenced by a support entry
- for idx, src in enumerate(chunk_sources):
- if idx not in seen_chunk_indices and src.get("source"):
- citations.append(
- {
- "type": "grounding",
- "text": "",
- "source": src["source"],
- "title": src.get("title"),
- }
- )
-
- result.citations = citations
+ candidate = api_response.candidates[0]
- return result
-
- def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None:
- """Extract an MCPToolCall from a function call part.
+ # Append assistant response (including any function_call) so that
+ # subsequent FunctionResponse messages correspond to a prior FunctionCall
+ content = candidate.content
+ if content is not None:
+ messages.append(content)
+
+ # Normalize text, thoughts, tool calls, and citations.
+ result = AgentResponse(content="", tool_calls=[], done=True)
+ text_parts: list[str] = []
+ thought_parts: list[str] = []
+
+ parts = []
+ if content is not None:
+ parts = content.parts or []
+ for part in parts:
+ function_call = part.function_call
+ if function_call is not None:
+ result.tool_calls.append(self.tools.tool_call(function_call))
+ result.done = False
+ continue
- Subclasses can override to customize tool call extraction (e.g., normalizing
- computer use calls to a different schema).
- """
- if not part.function_call:
- return None
+ if not part.text:
+ continue
- func_name = part.function_call.name or ""
- raw_args = dict(part.function_call.args) if part.function_call.args else {}
- mcp_tool_name = self._gemini_to_mcp_tool_map.get(func_name)
+ if part.thought is True:
+ thought_parts.append(part.text)
+ else:
+ text_parts.append(part.text)
- if mcp_tool_name:
- return MCPToolCall(
- name=mcp_tool_name,
- arguments=raw_args,
- )
+ result.content = "".join(text_parts)
+ if thought_parts:
+ result.reasoning = "\n".join(thought_parts)
- if self._computer_tool_name and func_name in gemini_tools.predefined_computer_functions:
- return MCPToolCall(
- name=self._computer_tool_name,
- arguments=normalize_gemini_computer_use_args(func_name, raw_args),
- gemini_name=func_name, # type: ignore[arg-type]
- )
+ grounding_meta = candidate.grounding_metadata
+ if grounding_meta is not None:
+ # TODO: Also normalize candidate.citation_metadata for URL-context citation spans.
+ result.citations = [
+ citation.model_dump(exclude={"provider_data"})
+ for citation in _grounding_citations(grounding_meta)
+ ]
- if func_name in self._gemini_native_tools:
- return MCPToolCall(
- name=func_name,
- arguments=raw_args,
- )
+ return result
- return MCPToolCall(
- name=func_name,
- arguments=raw_args,
- )
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[genai_types.Content]:
- """Format tool results into Gemini messages."""
- # Process each tool result
- function_responses = []
-
- for tool_call, result in zip(tool_calls, tool_results, strict=True):
- # Get the Gemini function name from metadata
- gemini_name = getattr(tool_call, "gemini_name", tool_call.name)
-
- # Convert MCP tool results to Gemini format
- response_dict: dict[str, Any] = {}
- is_computer_call = (
- self._computer_tool_name is not None and tool_call.name == self._computer_tool_name
+def _format_content(
+ content: types.ContentBlock,
+) -> genai_types.Part:
+ match content:
+ case types.TextContent(text=text):
+ return genai_types.Part(text=text)
+ case types.ImageContent(data=data, mimeType=mime_type):
+ return genai_types.Part.from_bytes(
+ data=base64.b64decode(data),
+ mime_type=mime_type or "image/png",
)
-
- if result.isError:
- # Extract error message from content
- error_msg = "Tool execution failed"
- for content in result.content:
- if isinstance(content, types.TextContent):
- if content.text.startswith("__URL__:"):
- continue
- error_msg = content.text
- break
- response_dict["error"] = error_msg
- if is_computer_call:
- response_dict["url"] = self._extract_url(result) or "about:blank"
- else:
- # Process success content
- response_dict["success"] = True
-
- screenshot_parts: list[genai_types.FunctionResponsePart] = []
- if is_computer_call:
- url = self._extract_url(result)
- for content in result.content:
- if isinstance(content, types.ImageContent):
- import base64
-
- image_bytes = base64.b64decode(content.data)
- screenshot_parts.append(
- genai_types.FunctionResponsePart(
- inline_data=genai_types.FunctionResponseBlob(
- mime_type=content.mimeType or "image/png",
- data=image_bytes,
- )
- )
- )
- elif isinstance(content, types.TextContent) and content.text.startswith(
- "__GEMINI_SAFETY_BLOCKED__:"
- ):
- response_dict.pop("success", None)
- response_dict["blocked"] = True
- response_dict["reason"] = content.text.replace(
- "__GEMINI_SAFETY_BLOCKED__:", "", 1
- )
-
- response_dict["url"] = url or "about:blank"
- safety_decision = (
- tool_call.arguments.get("safety_decision") if tool_call.arguments else None
+ case _:
+ raise ValueError(f"Unknown content block type: {type(content)}")
+
+
+def _grounding_citations(
+ grounding_meta: genai_types.GroundingMetadata,
+) -> list[Citation]:
+ citations: list[Citation] = []
+ chunk_sources: list[tuple[str, str | None]] = []
+ for chunk in grounding_meta.grounding_chunks or []:
+ if chunk.web is None:
+ chunk_sources.append(("", None))
+ else:
+ chunk_sources.append((chunk.web.uri or "", chunk.web.title))
+
+ seen_chunk_indices: set[int] = set()
+ for support in grounding_meta.grounding_supports or []:
+ segment = support.segment
+ segment_text = segment.text or "" if segment else ""
+ start_idx = segment.start_index if segment else None
+ end_idx = segment.end_index if segment else None
+
+ for idx in support.grounding_chunk_indices or []:
+ seen_chunk_indices.add(idx)
+ source, title = chunk_sources[idx] if 0 <= idx < len(chunk_sources) else ("", None)
+ citations.append(
+ Citation(
+ type="grounding",
+ text=segment_text,
+ source=source,
+ title=title,
+ start_index=start_idx,
+ end_index=end_idx,
)
- if safety_decision and not result.isError and not response_dict.get("blocked"):
- response_dict["safety_acknowledgement"] = True
- else:
- # Add text content to response
- for content in result.content:
- if isinstance(content, types.TextContent):
- response_dict["output"] = content.text
- break
-
- # Create function response
- function_response = genai_types.FunctionResponse(
- name=gemini_name,
- response=response_dict,
- parts=screenshot_parts if screenshot_parts else None,
- )
- function_responses.append(function_response)
-
- # Return as a user message containing all function responses
- return [
- genai_types.Content(
- role="user",
- parts=[genai_types.Part(function_response=fr) for fr in function_responses],
- )
- ]
-
- @staticmethod
- def _extract_url(result: MCPToolResult) -> str | None:
- for content in result.content:
- if isinstance(content, types.TextContent) and content.text.startswith("__URL__:"):
- return content.text.replace("__URL__:", "", 1)
- return None
-
- async def call_tools(
- self, tool_call: MCPToolCall | list[MCPToolCall] | None = None
- ) -> list[MCPToolResult]:
- """Route Gemini-owned native tool calls through provider translators."""
- return await call_agent_tools(self, self._gemini_native_tools, tool_call)
-
- def _map_role(self, role: str) -> str:
- """Gemini uses 'model' instead of 'assistant' for non-user turns."""
- if role == "assistant":
- return "model"
- return role
-
- async def create_user_message(self, text: str) -> genai_types.Content:
- """Create a user message in Gemini's format."""
- return genai_types.Content(role="user", parts=[genai_types.Part(text=text)])
-
- def _has_google_search_tool(self) -> bool:
- """Check if google_search is already in the tool list."""
- return any(getattr(tool, "google_search", None) is not None for tool in self.gemini_tools)
-
- def _convert_tools_for_gemini(self) -> None:
- """Convert MCP tools to Gemini tool format."""
- self._gemini_to_mcp_tool_map = {}
- self._computer_tool_name = None
- self._gemini_native_tools = {}
- self.gemini_tools = []
-
- categorized = self._categorized_tools
-
- capabilities = self._discover_environment_capabilities(self.get_available_tools())
- self._environment_capabilities = capabilities
- provider_backing_tools: set[str] = set()
-
- for capability in capabilities.values():
- if capability.name not in gemini_tools.capabilities:
- continue
- for gemini_tool in gemini_tools.tools_for_capability(capability, self.model):
- provider_backing_tools.add(gemini_tool.env_tool_name)
- if isinstance(gemini_tool, GeminiComputerTool):
- self._computer_tool_name = gemini_tool.name
- self._gemini_native_tools[gemini_tool.name] = gemini_tool
- gemini_tool.excluded_predefined_functions = (
- self._computer_use_excluded_function_names(gemini_tool.env_tool_name)
- )
- self.gemini_tools.append(gemini_tool.to_params())
- continue
-
- self._gemini_native_tools[gemini_tool.name] = gemini_tool
- self.gemini_tools.append(gemini_tool.to_params())
-
- configured_hosted = select_hosted_tools(
- self.config.hosted_tools,
- tool_type=GeminiHostedTool,
- model=self.model,
- )
- self.gemini_tools.extend(tool.to_params() for tool in configured_hosted)
-
- # Process generic function tools
- for tool in categorized.generic:
- if tool.name in provider_backing_tools:
- continue
- gemini_tool = self._to_gemini_tool(tool)
- if gemini_tool:
- self._gemini_to_mcp_tool_map[tool.name] = tool.name
- self.gemini_tools.append(gemini_tool)
-
- # Log actual tools being used
- tool_names = sorted(
- {
- *self._gemini_to_mcp_tool_map.keys(),
- *self._gemini_native_tools.keys(),
- }
- )
- self.console.info(
- f"Agent initialized with {len(tool_names)} tools: {', '.join(tool_names)}"
- )
-
- def _computer_use_excluded_function_names(self, computer_tool_name: str) -> list[str]:
- excluded = [
- *self.excluded_predefined_functions,
- *self._colliding_predefined_function_names(computer_tool_name),
- ]
- return sorted(set(excluded))
-
- def _colliding_predefined_function_names(self, computer_tool_name: str) -> list[str]:
- """Exclude predefined computer actions shadowed by generic MCP tools."""
- generic_names = {
- tool.name for tool in self._categorized_tools.generic if tool.name != computer_tool_name
- }
- return sorted(set(gemini_tools.predefined_computer_functions) & generic_names)
-
- def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None:
- """Drop older Gemini Computer Use screenshots to keep context growth bounded."""
- if self._computer_tool_name is None:
- return
-
- turn_with_screenshots_found = 0
- for content in reversed(messages):
- if content.role != "user" or not content.parts:
- continue
-
- has_screenshot = any(
- part.function_response
- and part.function_response.parts
- and part.function_response.name in gemini_tools.predefined_computer_functions
- for part in content.parts
)
- if not has_screenshot:
- continue
- turn_with_screenshots_found += 1
- if turn_with_screenshots_found <= self.max_recent_turn_with_screenshots:
- continue
-
- for part in content.parts:
- if (
- part.function_response
- and part.function_response.parts
- and part.function_response.name in gemini_tools.predefined_computer_functions
- ):
- part.function_response.parts = None
-
- def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool | None:
- """Convert a single MCP tool to Gemini function tool format.
-
- Args:
- tool: MCP tool to convert
-
- Returns:
- Gemini Tool with function declaration
- """
- if tool.description is None or tool.inputSchema is None:
- raise ValueError(f"MCP tool {tool.name} requires both a description and inputSchema.")
-
- function_decl = genai_types.FunctionDeclaration(
- name=tool.name,
- description=tool.description,
- parameters_json_schema=tool.inputSchema,
- )
- return genai_types.Tool(function_declarations=[function_decl])
+ for idx, (source, title) in enumerate(chunk_sources):
+ if idx not in seen_chunk_indices and source:
+ citations.append(Citation(type="grounding", text="", source=source, title=title))
+ return citations
diff --git a/hud/agents/gemini/settings.py b/hud/agents/gemini/settings.py
new file mode 100644
index 000000000..2a7c89b6e
--- /dev/null
+++ b/hud/agents/gemini/settings.py
@@ -0,0 +1,21 @@
+"""Gemini agent settings."""
+
+from __future__ import annotations
+
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class GeminiAgentSettings(BaseSettings):
+ """Gemini provider defaults owned by the agent."""
+
+ model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="allow")
+
+ MAX_RECENT_TURN_WITH_SCREENSHOTS: int = Field(
+ default=3,
+ description="Maximum number of recent turns to keep screenshots for in Gemini agent",
+ validation_alias="GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS",
+ )
+
+
+gemini_agent_settings = GeminiAgentSettings()
diff --git a/hud/agents/gemini/tools/__init__.py b/hud/agents/gemini/tools/__init__.py
index 33c31d9ea..ba9583915 100644
--- a/hud/agents/gemini/tools/__init__.py
+++ b/hud/agents/gemini/tools/__init__.py
@@ -2,30 +2,20 @@
from __future__ import annotations
-from dataclasses import dataclass, field
-
-from hud.agents.tools import AgentToolRegistry
-
-from .base import GeminiTool
-from .coding import (
- GEMINI_EDIT_SPEC,
- GEMINI_SHELL_SPEC,
- GEMINI_WRITE_SPEC,
- GeminiEditTool,
- GeminiShellTool,
- GeminiWriteTool,
-)
+from typing import TYPE_CHECKING, ClassVar
+
+from google.genai import types as genai_types
+
+from hud.agents.tools import AgentTool, AgentTools
+from hud.types import MCPToolCall
+
+from .base import GeminiFunctionTool
+from .coding import GeminiEditTool, GeminiShellTool, GeminiWriteTool
from .computer import (
- GEMINI_COMPUTER_SPEC,
PREDEFINED_COMPUTER_USE_FUNCTIONS,
GeminiComputerTool,
- normalize_gemini_computer_use_args,
)
from .filesystem import (
- GEMINI_GLOB_SPEC,
- GEMINI_LIST_SPEC,
- GEMINI_READ_SPEC,
- GEMINI_SEARCH_SPEC,
GeminiGlobTool,
GeminiListTool,
GeminiReadTool,
@@ -37,14 +27,20 @@
GeminiHostedTool,
GeminiUrlContextTool,
)
-from .memory import GEMINI_MEMORY_SPEC, GeminiMemoryTool
+from .memory import GeminiMemoryTool
+
+if TYPE_CHECKING:
+ from collections.abc import Mapping
+
+ import mcp.types as types
+
+ from hud.agents.tools import ToolMetadata
-@dataclass(frozen=True)
-class GeminiToolRegistry(AgentToolRegistry[GeminiTool]):
- """Registry for Gemini harness tools."""
+class GeminiAgentTools(AgentTools[AgentTool[genai_types.Tool], genai_types.Tool]):
+ """Prepared Gemini tool state for a run."""
- tool_classes: tuple[type[GeminiTool], ...] = (
+ native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = (
GeminiComputerTool,
GeminiShellTool,
GeminiEditTool,
@@ -55,52 +51,79 @@ class GeminiToolRegistry(AgentToolRegistry[GeminiTool]):
GeminiListTool,
GeminiMemoryTool,
)
- name_fallbacks: dict[str, tuple[str, ...]] = field(
- default_factory=lambda: {
- "computer": ("computer", "gemini_computer", "computer_gemini"),
- "shell": ("bash",),
- "editor": ("edit",),
- "filesystem": ("read", "grep", "glob", "list"),
- "memory": ("memory",),
- }
- )
+ function_tool_class = GeminiFunctionTool
+ name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {
+ "computer": ("computer", "gemini_computer", "computer_gemini"),
+ "shell": ("bash",),
+ "editor": ("edit",),
+ "filesystem": ("read", "grep", "glob", "list"),
+ "memory": ("memory",),
+ }
+
+ def __init__(self, *, excluded_predefined_functions: list[str] | None = None) -> None:
+ super().__init__()
+ self.excluded_predefined_functions = list(excluded_predefined_functions or [])
@property
- def api_types(self) -> frozenset[str]:
- return frozenset(cls.name for cls in self.tool_classes)
+ def computer_tool_name(self) -> str | None:
+ return "computer_use" if "computer_use" in self else None
@property
def predefined_computer_functions(self) -> frozenset[str]:
return frozenset(PREDEFINED_COMPUTER_USE_FUNCTIONS)
+ def tool_call(self, function_call: genai_types.FunctionCall) -> MCPToolCall:
+ name = function_call.name or ""
+ arguments = dict(function_call.args) if function_call.args else {}
+
+ if mcp_tool_name := self.name_map.get(name):
+ return MCPToolCall(name=mcp_tool_name, arguments=arguments)
+
+ if self.computer_tool_name and name in self.predefined_computer_functions:
+ computer_tool = self.get(self.computer_tool_name)
+ if isinstance(computer_tool, GeminiComputerTool):
+ return computer_tool.tool_call(name, arguments)
+
+ return MCPToolCall(name=name, arguments=arguments)
+
+ def select_tools(
+ self,
+ tools: list[types.Tool],
+ model: str,
+ *,
+ tool_metadata: ToolMetadata | None = None,
+ excluded_predefined_functions: list[str] | None = None,
+ ) -> tuple[list[AgentTool[genai_types.Tool]], list[types.Tool]]:
+ provider_tools, user_tools = super().select_tools(
+ tools,
+ model,
+ tool_metadata=tool_metadata,
+ )
+ user_tool_names = {tool.name for tool in user_tools}
+ configured_exclusions = (
+ excluded_predefined_functions
+ if excluded_predefined_functions is not None
+ else self.excluded_predefined_functions
+ )
+ colliding_exclusions = sorted(self.predefined_computer_functions & user_tool_names)
+ exclusions = sorted({*configured_exclusions, *colliding_exclusions})
+ if not exclusions:
+ return provider_tools, user_tools
+ return (
+ [
+ tool.with_excluded_predefined_functions(exclusions)
+ if isinstance(tool, GeminiComputerTool)
+ else tool
+ for tool in provider_tools
+ ],
+ user_tools,
+ )
-gemini_tools = GeminiToolRegistry()
__all__ = [
- "GEMINI_COMPUTER_SPEC",
- "GEMINI_EDIT_SPEC",
- "GEMINI_GLOB_SPEC",
- "GEMINI_LIST_SPEC",
- "GEMINI_MEMORY_SPEC",
- "GEMINI_READ_SPEC",
- "GEMINI_SEARCH_SPEC",
- "GEMINI_SHELL_SPEC",
- "GEMINI_WRITE_SPEC",
+ "GeminiAgentTools",
"GeminiCodeExecutionTool",
- "GeminiComputerTool",
- "GeminiEditTool",
- "GeminiGlobTool",
"GeminiGoogleSearchTool",
"GeminiHostedTool",
- "GeminiListTool",
- "GeminiMemoryTool",
- "GeminiReadTool",
- "GeminiSearchTool",
- "GeminiShellTool",
- "GeminiTool",
- "GeminiToolRegistry",
"GeminiUrlContextTool",
- "GeminiWriteTool",
- "gemini_tools",
- "normalize_gemini_computer_use_args",
]
diff --git a/hud/agents/gemini/tools/base.py b/hud/agents/gemini/tools/base.py
index 6d8612ca8..a52081d4a 100644
--- a/hud/agents/gemini/tools/base.py
+++ b/hud/agents/gemini/tools/base.py
@@ -2,20 +2,20 @@
from __future__ import annotations
-from typing import Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
+import mcp.types as types
from google.genai import types as genai_types
-from hud.agents.tools import AgentTool, AgentToolSpec, CallTool, call_tool
-
-GeminiToolSpec = AgentToolSpec
+from hud.agents.tools import AgentTool, AgentToolSpec
+if TYPE_CHECKING:
+ from hud.types import MCPToolCall, MCPToolResult
-class GeminiTool(AgentTool[Any]):
- """Gemini provider tool backed by an environment tool."""
+GeminiToolSpec = AgentToolSpec
-class GeminiFunctionTool(GeminiTool):
+class GeminiTool(AgentTool[genai_types.Tool]):
"""Gemini function declaration backed by an environment tool."""
description: ClassVar[str]
@@ -25,12 +25,77 @@ def to_params(self) -> genai_types.Tool:
return genai_types.Tool(
function_declarations=[
genai_types.FunctionDeclaration(
- name=self.name,
+ name=self.provider_name,
description=self.description,
parameters_json_schema=self.parameters,
)
]
)
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> genai_types.Content:
+ text = next(
+ (content.text for content in result.content if isinstance(content, types.TextContent)),
+ None,
+ )
+ response: dict[str, Any] = (
+ {"error": text or "Tool execution failed"} if result.isError else {"success": True}
+ )
+ if text is not None and not result.isError:
+ response["output"] = text
+ return genai_types.Content(
+ role="user",
+ parts=[
+ genai_types.Part(
+ function_response=genai_types.FunctionResponse(
+ name=call.provider_name or call.name,
+ response=response,
+ )
+ )
+ ],
+ )
+
+
+class GeminiFunctionTool(GeminiTool):
+ """Regular environment tool exposed as a Gemini function declaration."""
+
+ name = "function"
+ capability = "function"
+
+ def __init__(
+ self,
+ *,
+ env_tool_name: str,
+ description: str,
+ parameters: dict[str, Any],
+ ) -> None:
+ super().__init__(
+ env_tool_name=env_tool_name,
+ spec=GeminiToolSpec(api_type="function", api_name=env_tool_name),
+ )
+ self._description = description
+ self._parameters = parameters
+
+ @classmethod
+ def from_tool(cls, tool: types.Tool) -> GeminiFunctionTool:
+ if tool.description is None:
+ raise ValueError(f"MCP tool {tool.name} requires a description.")
+ return cls(
+ env_tool_name=tool.name,
+ description=tool.description,
+ parameters=tool.inputSchema,
+ )
+
+ @property
+ def provider_name(self) -> str:
+ return self.env_tool_name
-__all__ = ["CallTool", "GeminiFunctionTool", "GeminiTool", "GeminiToolSpec", "call_tool"]
+ def to_params(self) -> genai_types.Tool:
+ return genai_types.Tool(
+ function_declarations=[
+ genai_types.FunctionDeclaration(
+ name=self.provider_name,
+ description=self._description,
+ parameters_json_schema=self._parameters,
+ )
+ ]
+ )
diff --git a/hud/agents/gemini/tools/coding.py b/hud/agents/gemini/tools/coding.py
index 6817764f3..50a2eec1b 100644
--- a/hud/agents/gemini/tools/coding.py
+++ b/hud/agents/gemini/tools/coding.py
@@ -6,16 +6,17 @@
from typing import TYPE_CHECKING, Any, ClassVar
if TYPE_CHECKING:
+ from hud.agents.tools.base import CallTool
from hud.types import MCPToolResult
-from .base import CallTool, GeminiFunctionTool, GeminiToolSpec, call_tool
+from .base import GeminiTool, GeminiToolSpec
GEMINI_SHELL_SPEC = GeminiToolSpec(api_type="run_shell_command", api_name="run_shell_command")
GEMINI_EDIT_SPEC = GeminiToolSpec(api_type="replace", api_name="replace")
GEMINI_WRITE_SPEC = GeminiToolSpec(api_type="write_file", api_name="write_file")
-class GeminiShellTool(GeminiFunctionTool):
+class GeminiShellTool(GeminiTool):
"""Translate Gemini CLI shell calls into the generic bash env primitive."""
name = "run_shell_command"
@@ -39,17 +40,17 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_SHELL_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
command = arguments.get("command")
if not isinstance(command, str) or not command:
raise ValueError("command is required")
dir_path = arguments.get("dir_path")
if isinstance(dir_path, str) and dir_path:
command = f"cd {shlex.quote(dir_path)} && {command}"
- return await call_tool(caller, self.env_tool_name, {"command": command})
+ return await super().execute(call_tool, {"command": command})
-class GeminiEditTool(GeminiFunctionTool):
+class GeminiEditTool(GeminiTool):
"""Translate Gemini CLI replace calls into the generic edit env primitive."""
name = "replace"
@@ -74,19 +75,21 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_EDIT_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
file_path = _required_str(arguments, "file_path")
old_string = arguments.get("old_string")
new_string = arguments.get("new_string")
if old_string == "":
- return await call_tool(
- caller,
- self.env_tool_name,
- {"command": "create", "path": file_path, "file_text": new_string or ""},
+ return await super().execute(
+ call_tool,
+ {
+ "command": "create",
+ "path": file_path,
+ "file_text": new_string or "",
+ },
)
- return await call_tool(
- caller,
- self.env_tool_name,
+ return await super().execute(
+ call_tool,
{
"command": "replace",
"path": file_path,
@@ -96,7 +99,7 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR
)
-class GeminiWriteTool(GeminiFunctionTool):
+class GeminiWriteTool(GeminiTool):
"""Translate Gemini CLI write_file calls into the generic edit env primitive."""
name = "write_file"
@@ -116,10 +119,9 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_WRITE_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- return await call_tool(
- caller,
- self.env_tool_name,
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ return await super().execute(
+ call_tool,
{
"command": "write",
"path": _required_str(arguments, "file_path"),
@@ -133,13 +135,3 @@ def _required_str(arguments: dict[str, Any], key: str) -> str:
if not isinstance(value, str) or not value:
raise ValueError(f"{key} is required")
return value
-
-
-__all__ = [
- "GEMINI_EDIT_SPEC",
- "GEMINI_SHELL_SPEC",
- "GEMINI_WRITE_SPEC",
- "GeminiEditTool",
- "GeminiShellTool",
- "GeminiWriteTool",
-]
diff --git a/hud/agents/gemini/tools/computer.py b/hud/agents/gemini/tools/computer.py
index b52680e49..cf8684c68 100644
--- a/hud/agents/gemini/tools/computer.py
+++ b/hud/agents/gemini/tools/computer.py
@@ -2,18 +2,21 @@
from __future__ import annotations
+import base64
import platform
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
from google.genai import types as genai_types
from mcp.types import ImageContent, TextContent
-from hud.types import MCPToolResult
+from hud.agents.tools import AgentTool
+from hud.agents.tools.computer import computer_error_result, execute_computer_calls
+from hud.types import MCPToolCall, MCPToolResult
-from .base import CallTool, GeminiTool, GeminiToolSpec, call_tool
+from .base import GeminiToolSpec
if TYPE_CHECKING:
- from hud.agents.tools import EnvironmentCapability
+ from hud.agents.tools.base import CallTool
SUPPORTED_GEMINI_COMPUTER_USE_MODELS = (
"gemini-2.5-computer-use-preview-10-2025",
@@ -22,6 +25,7 @@
GEMINI_COORDINATE_SPACE = 1000
GEMINI_DRAG_INSET = 25
+IS_MAC = platform.system().lower() == "darwin"
PREDEFINED_COMPUTER_USE_FUNCTIONS = (
"open_web_browser",
@@ -38,6 +42,8 @@
"key_combination",
"drag_and_drop",
)
+GEMINI_URL_PREFIX = "__URL__:"
+GEMINI_SAFETY_BLOCKED_PREFIX = "__GEMINI_SAFETY_BLOCKED__:"
GEMINI_COMPUTER_SPEC = GeminiToolSpec(
api_type="computer_use",
@@ -46,51 +52,7 @@
)
-def normalize_gemini_computer_use_args(action: str, raw_args: dict[str, Any]) -> dict[str, Any]:
- """Normalize Gemini Computer Use function-call args to agent-tool args."""
- normalized_args: dict[str, Any] = {"action": action}
-
- coord = raw_args.get("coordinate") or raw_args.get("coordinates")
- if isinstance(coord, list | tuple) and len(coord) >= 2:
- try:
- normalized_args["x"] = int(coord[0])
- normalized_args["y"] = int(coord[1])
- except (TypeError, ValueError):
- pass
-
- dest = (
- raw_args.get("destination")
- or raw_args.get("destination_coordinate")
- or raw_args.get("destinationCoordinate")
- )
- if isinstance(dest, list | tuple) and len(dest) >= 2:
- try:
- normalized_args["destination_x"] = int(dest[0])
- normalized_args["destination_y"] = int(dest[1])
- except (TypeError, ValueError):
- pass
-
- for key in (
- "text",
- "press_enter",
- "clear_before_typing",
- "safety_decision",
- "direction",
- "magnitude",
- "url",
- "keys",
- "x",
- "y",
- "destination_x",
- "destination_y",
- ):
- if key in raw_args:
- normalized_args[key] = raw_args[key]
-
- return normalized_args
-
-
-class GeminiComputerTool(GeminiTool):
+class GeminiComputerTool(AgentTool[genai_types.Tool]):
"""Translate Gemini Computer Use calls into generic environment computer calls."""
name = "computer_use"
@@ -102,19 +64,24 @@ def default_spec(cls, model: str) -> GeminiToolSpec | None:
return GEMINI_COMPUTER_SPEC
return None
- @classmethod
- def from_capability(
- cls,
- capability: EnvironmentCapability,
+ def __init__(
+ self,
+ *,
+ env_tool_name: str,
spec: GeminiToolSpec,
- model: str,
- ) -> GeminiComputerTool:
- del model
- return cls(env_tool_name=capability.tool_name, spec=spec)
-
- def __init__(self, *, env_tool_name: str, spec: GeminiToolSpec) -> None:
+ excluded_predefined_functions: list[str] | None = None,
+ ) -> None:
super().__init__(env_tool_name=env_tool_name, spec=spec)
- self.excluded_predefined_functions: list[str] = []
+ self.excluded_predefined_functions = excluded_predefined_functions or []
+
+ def with_excluded_predefined_functions(
+ self, excluded_predefined_functions: list[str]
+ ) -> GeminiComputerTool:
+ return GeminiComputerTool(
+ env_tool_name=self.env_tool_name,
+ spec=self.spec,
+ excluded_predefined_functions=excluded_predefined_functions,
+ )
def to_params(self) -> genai_types.Tool:
return genai_types.Tool(
@@ -124,31 +91,100 @@ def to_params(self) -> genai_types.Tool:
)
)
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ def tool_call(self, function_name: str, raw_args: dict[str, Any]) -> MCPToolCall:
+ return MCPToolCall(
+ name=self.name,
+ arguments={"action": function_name, **raw_args},
+ provider_name=function_name,
+ )
+
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> genai_types.Content:
+ text = next(
+ (
+ content.text
+ for content in result.content
+ if isinstance(content, TextContent)
+ and not content.text.startswith(GEMINI_URL_PREFIX)
+ ),
+ None,
+ )
+ response: dict[str, Any] = (
+ {"error": text or "Tool execution failed"} if result.isError else {"success": True}
+ )
+ if text is not None and not result.isError:
+ response["output"] = text
+
+ url = None
+ parts: list[genai_types.FunctionResponsePart] = []
+ for content in result.content:
+ match content:
+ case ImageContent(data=data, mimeType=mime_type):
+ parts.append(
+ genai_types.FunctionResponsePart(
+ inline_data=genai_types.FunctionResponseBlob(
+ mime_type=mime_type or "image/png",
+ data=base64.b64decode(data),
+ )
+ )
+ )
+ case TextContent(text=text) if text.startswith(GEMINI_URL_PREFIX):
+ url = text.removeprefix(GEMINI_URL_PREFIX)
+ case TextContent(text=text) if text.startswith(GEMINI_SAFETY_BLOCKED_PREFIX):
+ response.pop("success", None)
+ response["blocked"] = True
+ response["reason"] = text.removeprefix(GEMINI_SAFETY_BLOCKED_PREFIX)
+ case _:
+ continue
+
+ response["url"] = url or "about:blank"
+ safety_decision = call.arguments.get("safety_decision") if call.arguments else None
+ if safety_decision and not result.isError and not response.get("blocked"):
+ response["safety_acknowledgement"] = True
+
+ return genai_types.Content(
+ role="user",
+ parts=[
+ genai_types.Part(
+ function_response=genai_types.FunctionResponse(
+ name=call.provider_name or call.name,
+ response=response,
+ parts=parts or None,
+ )
+ )
+ ],
+ )
+
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
action = arguments.get("action")
if not isinstance(action, str):
- return _error_result("action is required")
- if _requires_confirmation(arguments.get("safety_decision")):
- return _blocked_result(
- "Gemini Computer Use action requires user confirmation before execution."
+ return computer_error_result("action is required")
+ safety_decision = arguments.get("safety_decision")
+ if (
+ isinstance(safety_decision, dict)
+ and cast("dict[str, Any]", safety_decision).get("decision") == "require_confirmation"
+ ):
+ return MCPToolResult(
+ content=[
+ TextContent(
+ type="text",
+ text=(
+ f"{GEMINI_SAFETY_BLOCKED_PREFIX}"
+ "Gemini Computer Use action requires user confirmation before "
+ "execution."
+ ),
+ )
+ ],
+ isError=False,
)
- result = MCPToolResult(content=[], isError=False)
- for call in self._env_calls(action, arguments):
- result = await call_tool(caller, self.env_tool_name, call)
- if result.isError:
- return result
-
- if action != "open_web_browser" and not _has_image(result):
- screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"})
- if not screenshot.isError and screenshot.content:
- result = MCPToolResult(
- content=[*result.content, *screenshot.content],
- isError=result.isError,
- )
- return result
+ return await execute_computer_calls(
+ call_tool,
+ env_tool_name=self.env_tool_name,
+ calls=self._computer_actions(action, arguments),
+ ensure_screenshot=action != "open_web_browser",
+ )
- def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
+ def _computer_actions(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
if action == "open_web_browser":
return [{"action": "screenshot"}]
if action == "click_at":
@@ -165,7 +201,12 @@ def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, A
]
)
if arguments.get("clear_before_typing", True):
- calls.extend(_clear_text_calls())
+ calls.extend(
+ [
+ {"action": "press", "keys": ["cmd", "a"] if IS_MAC else ["ctrl", "a"]},
+ {"action": "press", "keys": ["backspace" if IS_MAC else "delete"]},
+ ]
+ )
calls.append(
{
"action": "write",
@@ -175,133 +216,77 @@ def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, A
)
return calls
if action in ("scroll_document", "scroll_at"):
- call = _scroll_call(arguments)
+ direction = arguments.get("direction")
+ magnitude = arguments.get("magnitude") or 800
+ if direction == "down":
+ call = {"action": "scroll", "scroll_x": None, "scroll_y": magnitude}
+ elif direction == "up":
+ call = {"action": "scroll", "scroll_x": None, "scroll_y": -magnitude}
+ elif direction == "right":
+ call = {"action": "scroll", "scroll_x": magnitude, "scroll_y": None}
+ elif direction == "left":
+ call = {"action": "scroll", "scroll_x": -magnitude, "scroll_y": None}
+ else:
+ raise ValueError("direction must be one of up, down, left, right")
if action == "scroll_at":
call.update({"x": arguments.get("x"), "y": arguments.get("y")})
return [call]
if action == "wait_5_seconds":
return [{"action": "wait", "time": 5000}]
if action == "go_back":
- return [{"action": "press", "keys": ["cmd", "["] if _is_mac() else ["alt", "left"]}]
+ return [{"action": "press", "keys": ["cmd", "["] if IS_MAC else ["alt", "left"]}]
if action == "go_forward":
- return [{"action": "press", "keys": ["cmd", "]"] if _is_mac() else ["alt", "right"]}]
+ return [{"action": "press", "keys": ["cmd", "]"] if IS_MAC else ["alt", "right"]}]
if action == "search":
target = arguments.get("url") or "https://www.google.com"
- return [*_address_bar_calls(), {"action": "write", "text": target, "enter_after": True}]
+ return [
+ {"action": "press", "keys": ["cmd", "l"] if IS_MAC else ["ctrl", "l"]},
+ {"action": "write", "text": target, "enter_after": True},
+ ]
if action == "navigate":
return [
- *_address_bar_calls(),
+ {"action": "press", "keys": ["cmd", "l"] if IS_MAC else ["ctrl", "l"]},
{"action": "write", "text": arguments.get("url"), "enter_after": True},
]
if action == "key_combination":
- return [{"action": "press", "keys": _normalize_key_combination(arguments.get("keys"))}]
+ keys = arguments.get("keys")
+ if not isinstance(keys, str):
+ raise ValueError("keys must be a '+'-separated string")
+ aliases = {
+ "control": "ctrl",
+ "cmd": "cmd",
+ "command": "cmd",
+ "meta": "cmd" if IS_MAC else "ctrl",
+ "return": "enter",
+ }
+ normalized_keys = [
+ aliases.get(key, key) for part in keys.split("+") if (key := part.strip().lower())
+ ]
+ return [{"action": "press", "keys": normalized_keys}]
if action == "drag_and_drop":
+ max_drag_coordinate = max(
+ GEMINI_COORDINATE_SPACE - GEMINI_DRAG_INSET,
+ GEMINI_DRAG_INSET,
+ )
+
+ def drag_coordinate(value: Any) -> Any:
+ if not isinstance(value, int | float) or not 0 <= value <= GEMINI_COORDINATE_SPACE:
+ return value
+ return min(max(int(value), GEMINI_DRAG_INSET), max_drag_coordinate)
+
return [
{
"action": "drag",
"path": [
{
- "x": _inset_drag_coordinate(arguments.get("x")),
- "y": _inset_drag_coordinate(arguments.get("y")),
+ "x": drag_coordinate(arguments.get("x")),
+ "y": drag_coordinate(arguments.get("y")),
},
{
- "x": _inset_drag_coordinate(arguments.get("destination_x")),
- "y": _inset_drag_coordinate(arguments.get("destination_y")),
+ "x": drag_coordinate(arguments.get("destination_x")),
+ "y": drag_coordinate(arguments.get("destination_y")),
},
],
}
]
raise ValueError(f"Unknown Gemini computer action: {action}")
-
-
-def _scroll_call(arguments: dict[str, Any]) -> dict[str, Any]:
- direction = arguments.get("direction")
- magnitude = arguments.get("magnitude") or 800
- if direction == "down":
- return {"action": "scroll", "scroll_x": None, "scroll_y": magnitude}
- if direction == "up":
- return {"action": "scroll", "scroll_x": None, "scroll_y": -magnitude}
- if direction == "right":
- return {"action": "scroll", "scroll_x": magnitude, "scroll_y": None}
- if direction == "left":
- return {"action": "scroll", "scroll_x": -magnitude, "scroll_y": None}
- raise ValueError("direction must be one of up, down, left, right")
-
-
-def _inset_drag_coordinate(value: Any) -> Any:
- """Keep Gemini normalized drag endpoints away from display edges."""
- if not isinstance(value, int | float) or not 0 <= value <= GEMINI_COORDINATE_SPACE:
- return value
- max_value = max(GEMINI_COORDINATE_SPACE - GEMINI_DRAG_INSET, GEMINI_DRAG_INSET)
- return min(max(int(value), GEMINI_DRAG_INSET), max_value)
-
-
-def _clear_text_calls() -> list[dict[str, Any]]:
- is_mac = _is_mac()
- return [
- {"action": "press", "keys": ["cmd", "a"] if is_mac else ["ctrl", "a"]},
- {"action": "press", "keys": ["backspace" if is_mac else "delete"]},
- ]
-
-
-def _normalize_key_combination(keys: Any) -> list[str] | Any:
- if isinstance(keys, str):
- return [_normalize_key(key) for key in keys.split("+") if key.strip()]
- if isinstance(keys, list):
- return [_normalize_key(key) if isinstance(key, str) else key for key in keys]
- return keys
-
-
-def _normalize_key(key: str) -> str:
- normalized = key.strip().lower()
- aliases = {
- "control": "ctrl",
- "cmd": "cmd",
- "command": "cmd",
- "meta": "cmd" if _is_mac() else "ctrl",
- "return": "enter",
- }
- return aliases.get(normalized, normalized)
-
-
-def _requires_confirmation(safety_decision: Any) -> bool:
- if not isinstance(safety_decision, dict):
- return False
- return safety_decision.get("decision") == "require_confirmation"
-
-
-def _address_bar_calls() -> list[dict[str, Any]]:
- return [{"action": "press", "keys": ["cmd", "l"] if _is_mac() else ["ctrl", "l"]}]
-
-
-def _is_mac() -> bool:
- return platform.system().lower() == "darwin"
-
-
-def _has_image(result: MCPToolResult) -> bool:
- return any(isinstance(block, ImageContent) for block in result.content)
-
-
-def _error_result(message: str) -> MCPToolResult:
- return MCPToolResult(
- content=[TextContent(type="text", text=message)],
- isError=True,
- )
-
-
-def _blocked_result(message: str) -> MCPToolResult:
- return MCPToolResult(
- content=[TextContent(type="text", text=f"__GEMINI_SAFETY_BLOCKED__:{message}")],
- isError=False,
- )
-
-
-__all__ = [
- "GEMINI_COMPUTER_SPEC",
- "GEMINI_COORDINATE_SPACE",
- "GEMINI_DRAG_INSET",
- "PREDEFINED_COMPUTER_USE_FUNCTIONS",
- "SUPPORTED_GEMINI_COMPUTER_USE_MODELS",
- "GeminiComputerTool",
- "normalize_gemini_computer_use_args",
-]
diff --git a/hud/agents/gemini/tools/filesystem.py b/hud/agents/gemini/tools/filesystem.py
index dc4750ee8..8ba89bd39 100644
--- a/hud/agents/gemini/tools/filesystem.py
+++ b/hud/agents/gemini/tools/filesystem.py
@@ -5,11 +5,12 @@
from typing import TYPE_CHECKING, Any, ClassVar
if TYPE_CHECKING:
+ from hud.agents.tools.base import CallTool
from hud.types import MCPToolResult
from hud.agents.tools import GroupedCapabilityMixin
-from .base import CallTool, GeminiFunctionTool, GeminiToolSpec, call_tool
+from .base import GeminiTool, GeminiToolSpec
GEMINI_READ_SPEC = GeminiToolSpec(api_type="read_file", api_name="read_file")
GEMINI_SEARCH_SPEC = GeminiToolSpec(api_type="grep_search", api_name="grep_search")
@@ -17,7 +18,7 @@
GEMINI_LIST_SPEC = GeminiToolSpec(api_type="list_directory", api_name="list_directory")
-class GeminiFilesystemTool(GroupedCapabilityMixin, GeminiFunctionTool):
+class GeminiFilesystemTool(GroupedCapabilityMixin, GeminiTool):
"""Gemini function tool backed by one filesystem environment primitive."""
capability = "filesystem"
@@ -45,16 +46,15 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_READ_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
start = arguments.get("start_line")
end = arguments.get("end_line")
offset = int(start) - 1 if isinstance(start, int) and start > 0 else None
limit = None
if offset is not None and isinstance(start, int) and isinstance(end, int) and end >= start:
limit = end - start + 1
- return await call_tool(
- caller,
- self.env_tool_name,
+ return await super().execute(
+ call_tool,
{
"filePath": _required_str(arguments, "file_path"),
"offset": offset,
@@ -84,10 +84,9 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_SEARCH_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- return await call_tool(
- caller,
- self.env_tool_name,
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ return await super().execute(
+ call_tool,
{
"pattern": _required_str(arguments, "pattern"),
"path": arguments.get("dir_path"),
@@ -120,10 +119,9 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_GLOB_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- return await call_tool(
- caller,
- self.env_tool_name,
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ return await super().execute(
+ call_tool,
{
"pattern": _required_str(arguments, "pattern"),
"path": arguments.get("dir_path"),
@@ -156,11 +154,13 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_LIST_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- return await call_tool(
- caller,
- self.env_tool_name,
- {"path": _required_str(arguments, "dir_path"), "ignore": arguments.get("ignore")},
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ return await super().execute(
+ call_tool,
+ {
+ "path": _required_str(arguments, "dir_path"),
+ "ignore": arguments.get("ignore"),
+ },
)
@@ -169,16 +169,3 @@ def _required_str(arguments: dict[str, Any], key: str) -> str:
if not isinstance(value, str) or not value:
raise ValueError(f"{key} is required")
return value
-
-
-__all__ = [
- "GEMINI_GLOB_SPEC",
- "GEMINI_LIST_SPEC",
- "GEMINI_READ_SPEC",
- "GEMINI_SEARCH_SPEC",
- "GeminiFilesystemTool",
- "GeminiGlobTool",
- "GeminiListTool",
- "GeminiReadTool",
- "GeminiSearchTool",
-]
diff --git a/hud/agents/gemini/tools/hosted.py b/hud/agents/gemini/tools/hosted.py
index 25a993a7d..138e1d4de 100644
--- a/hud/agents/gemini/tools/hosted.py
+++ b/hud/agents/gemini/tools/hosted.py
@@ -40,11 +40,3 @@ class GeminiCodeExecutionTool(GeminiHostedTool):
def to_params(self) -> genai_types.Tool:
return genai_types.Tool(code_execution=genai_types.ToolCodeExecution())
-
-
-__all__ = [
- "GeminiCodeExecutionTool",
- "GeminiGoogleSearchTool",
- "GeminiHostedTool",
- "GeminiUrlContextTool",
-]
diff --git a/hud/agents/gemini/tools/memory.py b/hud/agents/gemini/tools/memory.py
index 8aeb14e50..8d91dc2fb 100644
--- a/hud/agents/gemini/tools/memory.py
+++ b/hud/agents/gemini/tools/memory.py
@@ -6,14 +6,15 @@
from typing import TYPE_CHECKING, Any, ClassVar
if TYPE_CHECKING:
+ from hud.agents.tools.base import CallTool
from hud.types import MCPToolResult
-from .base import CallTool, GeminiFunctionTool, GeminiToolSpec, call_tool
+from .base import GeminiTool, GeminiToolSpec
GEMINI_MEMORY_SPEC = GeminiToolSpec(api_type="save_memory", api_name="save_memory")
-class GeminiMemoryTool(GeminiFunctionTool):
+class GeminiMemoryTool(GeminiTool):
"""Translate Gemini save_memory calls into the file-backed memory env primitive."""
name = "save_memory"
@@ -32,21 +33,17 @@ def default_spec(cls, model: str) -> GeminiToolSpec:
del model
return GEMINI_MEMORY_SPEC
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
fact = arguments.get("fact")
if not isinstance(fact, str) or not fact.strip():
raise ValueError("fact is required")
text = fact.strip()
digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
- return await call_tool(
- caller,
- self.env_tool_name,
+ return await super().execute(
+ call_tool,
{
"command": "create",
"path": f"/memories/gemini-{digest}.md",
"file_text": f"{text}\n",
},
)
-
-
-__all__ = ["GEMINI_MEMORY_SPEC", "GeminiMemoryTool"]
diff --git a/hud/agents/misc/__init__.py b/hud/agents/misc/__init__.py
index 522faac53..8a048c64d 100644
--- a/hud/agents/misc/__init__.py
+++ b/hud/agents/misc/__init__.py
@@ -2,6 +2,6 @@
from __future__ import annotations
-from .response_agent import ResponseAgent
+from .response_automation import auto_respond
-__all__ = ["ResponseAgent"]
+__all__ = ["auto_respond"]
diff --git a/hud/agents/misc/response_agent.py b/hud/agents/misc/response_agent.py
deleted file mode 100644
index 52f9bfde4..000000000
--- a/hud/agents/misc/response_agent.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from __future__ import annotations
-
-import logging
-from typing import Literal
-
-from openai import AsyncOpenAI
-from openai.types.responses import ResponseOutputText
-
-from hud.settings import settings
-from hud.telemetry import instrument
-
-logger = logging.getLogger(__name__)
-
-ResponseType = Literal["STOP", "CONTINUE"]
-
-DEFAULT_SYSTEM_PROMPT = """\
-You are an assistant that helps determine the appropriate response to an agent's message.
-
-You will receive messages from an agent that is performing tasks for a user.
-Your job is to analyze these messages and respond with one of the following:
-
-- STOP: If the agent indicates it has successfully completed a task or is stuck,
- struggling or says it cannot complete the task, even if phrased as a question
- like "I have entered the right values into this form. Would you like me to do
- anything else?" or "Here is the website. Is there any other information you
- need?" or if the agent has strongly determined it wants to stop the task like
- "The task is infeasible. Can I help you with something else?"
-
-- CONTINUE: If the agent is asking for clarification before proceeding with a task
- like "I'm about to clear cookies from this website. Would you like me to proceed?"
- or "I've entered the right values into this form. Would you like me to continue
- with the rest of the task?"
-
-Respond ONLY with one of these two options."""
-
-
-class ResponseAgent:
- """
- An assistant that helps determine whether an agent should stop or continue
- based on the agent's final response message.
- """
-
- def __init__(
- self,
- model: str = "gpt-5",
- system_prompt: str | None = None,
- ) -> None:
- """
- Initialize the ResponseAgent.
-
- Args:
- model: The model to use via HUD inference gateway (default: "gpt-5").
- Supports any model available through inference.hud.ai.
- system_prompt: Optional custom system prompt for determining responses.
- """
- api_key = settings.api_key
- if not api_key:
- raise ValueError(
- "HUD API key is required for auto_respond. Set HUD_API_KEY environment variable."
- )
-
- self.client: AsyncOpenAI = AsyncOpenAI(
- base_url=settings.hud_gateway_url,
- api_key=api_key,
- )
- self.model = model
- self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-
- @instrument(
- category="agent",
- name="response_agent",
- internal_type="user-message",
- )
- async def determine_response(self, agent_message: str) -> ResponseType:
- """
- Determine whether the agent should stop or continue based on its message.
-
- Args:
- agent_message: The message from the agent
-
- Returns:
- ResponseType: Either "STOP" or "CONTINUE"
- """
- try:
- response = await self.client.responses.create(
- model=self.model,
- instructions=self.system_prompt,
- input=[
- {
- "role": "user",
- "content": (
- f"Agent message: {agent_message}\n\nWhat is the appropriate response?"
- ),
- },
- ],
- reasoning={"effort": "low"},
- max_output_tokens=256,
- extra_headers={"Trace-Id": ""},
- )
-
- text_parts: list[str] = []
- for item in response.output:
- if item.type == "message":
- text_parts.extend(
- content.text
- for content in item.content
- if isinstance(content, ResponseOutputText)
- )
-
- response_text = "".join(text_parts)
- if not response_text:
- return "CONTINUE"
-
- response_text = response_text.strip().upper()
-
- if "STOP" in response_text:
- return "STOP"
- else:
- return "CONTINUE"
-
- except Exception as e:
- logger.warning("Auto-respond failed: %s", e)
- return "CONTINUE"
diff --git a/hud/agents/misc/response_automation.py b/hud/agents/misc/response_automation.py
new file mode 100644
index 000000000..91621843f
--- /dev/null
+++ b/hud/agents/misc/response_automation.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import logging
+from functools import cache
+from typing import Literal
+
+import mcp.types as types
+from openai import AsyncOpenAI
+from openai.types.responses import ResponseOutputText
+
+from hud.settings import settings
+from hud.telemetry import instrument
+
+logger = logging.getLogger(__name__)
+
+ResponseType = Literal["STOP", "CONTINUE"]
+
+DEFAULT_SYSTEM_PROMPT = """\
+You are an assistant that helps determine the appropriate response to an agent's message.
+
+You will receive messages from an agent that is performing tasks for a user.
+Your job is to analyze these messages and respond with one of the following:
+
+- STOP: If the agent indicates it has successfully completed a task or is stuck,
+ struggling or says it cannot complete the task, even if phrased as a question
+ like "I have entered the right values into this form. Would you like me to do
+ anything else?" or "Here is the website. Is there any other information you
+ need?" or if the agent has strongly determined it wants to stop the task like
+ "The task is infeasible. Can I help you with something else?"
+
+- CONTINUE: If the agent is asking for clarification before proceeding with a task
+ like "I'm about to clear cookies from this website. Would you like me to proceed?"
+ or "I've entered the right values into this form. Would you like me to continue
+ with the rest of the task?"
+
+Respond ONLY with one of these two options."""
+
+
+async def auto_respond(
+ content: str | None,
+ *,
+ enabled: bool,
+) -> types.PromptMessage | None:
+ if not enabled or not content:
+ return None
+
+ try:
+ decision = await _determine_response(content)
+ except Exception as exc:
+ logger.warning("Auto-respond failed: %s", exc)
+ return None
+
+ if decision == "STOP":
+ return None
+
+ return types.PromptMessage(
+ role="user",
+ content=types.TextContent(text=decision, type="text"),
+ )
+
+
+@cache
+def _client() -> AsyncOpenAI:
+ api_key = settings.api_key
+ if not api_key:
+ raise ValueError(
+ "HUD API key is required for auto_respond. Set HUD_API_KEY environment variable."
+ )
+
+ return AsyncOpenAI(
+ base_url=settings.hud_gateway_url,
+ api_key=api_key,
+ )
+
+
+@instrument(
+ category="agent",
+ name="response_automation",
+ internal_type="user-message",
+)
+async def _determine_response(
+ agent_message: str,
+ *,
+ model: str = "gpt-5",
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+) -> ResponseType:
+ response = await _client().responses.create(
+ model=model,
+ instructions=system_prompt,
+ input=[
+ {
+ "role": "user",
+ "content": f"Agent message: {agent_message}\n\nWhat is the appropriate response?",
+ },
+ ],
+ reasoning={"effort": "low"},
+ max_output_tokens=256,
+ extra_headers={"Trace-Id": ""},
+ )
+
+ text_parts: list[str] = []
+ for item in response.output:
+ if item.type == "message":
+ text_parts.extend(
+ content.text for content in item.content if isinstance(content, ResponseOutputText)
+ )
+
+ response_text = "".join(text_parts)
+ if not response_text:
+ return "CONTINUE"
+
+ response_text = response_text.strip().upper()
+ return "STOP" if "STOP" in response_text else "CONTINUE"
diff --git a/hud/agents/openai/agent.py b/hud/agents/openai/agent.py
index 89bdfa50a..34ab08c27 100644
--- a/hud/agents/openai/agent.py
+++ b/hud/agents/openai/agent.py
@@ -2,86 +2,59 @@
from __future__ import annotations
-import copy
import json
import logging
-from inspect import cleandoc
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
+from functools import cached_property
+from typing import Any, Literal, cast
import mcp.types as types
from openai import AsyncOpenAI, Omit, OpenAI
from openai.types.responses import (
- FunctionToolParam,
- ResponseComputerToolCallOutputScreenshotParam,
- ResponseFunctionCallOutputItemListParam,
ResponseIncludable,
- ResponseInputFileContentParam,
- ResponseInputImageContentParam,
ResponseInputImageParam,
ResponseInputMessageContentListParam,
ResponseInputParam,
- ResponseInputTextContentParam,
ResponseInputTextParam,
ResponseOutputText,
ToolParam,
)
+from openai.types.responses.easy_input_message_param import EasyInputMessageParam
from openai.types.responses.response_create_params import ToolChoice # noqa: TC002
from openai.types.responses.response_input_param import (
- ComputerCallOutput,
- ComputerCallOutputAcknowledgedSafetyCheck,
- FunctionCallOutput,
Message,
+ ResponseInputItemParam,
)
from openai.types.shared_params.reasoning import Reasoning # noqa: TC002
+from hud.agents import gateway
from hud.agents.base import MCPAgent
-from hud.agents.tools import (
- EnvironmentCapability,
- call_agent_tools,
- capabilities_metadata_from_context,
- discover_environment_capabilities,
- select_hosted_tools,
-)
-from hud.agents.types import OpenAIConfig, OpenAICreateParams
+from hud.agents.types import OpenAIConfig
from hud.settings import settings
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult, Trace
-from hud.utils.strict_schema import ensure_strict_json_schema
+from hud.types import AgentResponse, MCPToolCall
from hud.utils.types import with_signature
-from .tools import OpenAIHostedTool, OpenAIToolSearchTool, openai_tools
-
-if TYPE_CHECKING:
- from .tools import OpenAITool
+from .tools import OpenAIAgentTools
logger = logging.getLogger(__name__)
-class OpenAIAgent(MCPAgent):
+class OpenAIAgent(MCPAgent[ResponseInputItemParam]):
"""Generic OpenAI agent that can execute MCP tools through the Responses API."""
- metadata: ClassVar[dict[str, Any] | None] = None
- config_cls: ClassVar[type[BaseAgentConfig]] = OpenAIConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for OpenAI."""
- return AgentType.OPENAI
-
- @with_signature(OpenAICreateParams)
+ @with_signature(OpenAIConfig)
@classmethod
- def create(cls, **kwargs: Any) -> OpenAIAgent: # pyright: ignore[reportIncompatibleMethodOverride]
- return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value]
+ def create(cls, **kwargs: object) -> OpenAIAgent: # pyright: ignore[reportIncompatibleMethodOverride]
+ return cls(OpenAIConfig.model_validate(kwargs))
- def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> None:
- super().__init__(params, **kwargs)
+ def __init__(self, config: OpenAIConfig | None = None) -> None:
+ config = config or OpenAIConfig()
+ super().__init__(config)
self.config: OpenAIConfig
model_client = self.config.model_client
if model_client is None:
if settings.api_key:
- from hud.agents.gateway import build_gateway_client
-
- model_client = build_gateway_client("openai")
+ model_client = gateway.build_gateway_client("openai")
elif settings.openai_api_key:
model_client = AsyncOpenAI(api_key=settings.openai_api_key)
if self.config.validate_api_key:
@@ -99,7 +72,7 @@ def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> N
" access"
)
- self.openai_client: AsyncOpenAI = model_client
+ self.openai_client: AsyncOpenAI = cast("AsyncOpenAI", model_client)
self._model = self.config.model
self.max_output_tokens = self.config.max_output_tokens
self.temperature = self.config.temperature
@@ -109,178 +82,40 @@ def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> N
self.text = self.config.text
self.truncation: Literal["auto", "disabled"] | None = self.config.truncation
- self._openai_tools: list[ToolParam] = []
- self._openai_native_tools: dict[str, OpenAITool] = {}
- self._tool_name_map: dict[str, str] = {}
- self._environment_capabilities: dict[str, EnvironmentCapability] = {}
- self._tool_search_threshold: int | None = None
-
self.last_response_id: str | None = None
self._message_cursor = 0
- self.pending_call_id: str | None = None
- self.pending_safety_checks: list[Any] = []
-
- def _on_tools_ready(self) -> None:
- """Build OpenAI-specific tool mappings after tools are discovered."""
- self._convert_tools_for_openai()
-
- def _discover_environment_capabilities(
- self, tools: list[types.Tool]
- ) -> dict[str, EnvironmentCapability]:
- return discover_environment_capabilities(
- tools,
- env_metadata=capabilities_metadata_from_context(self.ctx),
- name_fallbacks=openai_tools.name_fallbacks,
- )
-
- def _to_function_tool(self, tool: types.Tool) -> FunctionToolParam | None:
- """Convert an MCP tool to OpenAI function tool format.
-
- Args:
- tool: MCP tool to convert
-
- Returns:
- OpenAI function tool parameter
- """
- if tool.description is None or tool.inputSchema is None:
- raise ValueError(
- cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema.
- Add these by:
- 1. Adding a docstring to your @mcp.tool decorated function for the description
- 2. Using pydantic Field() annotations on function parameters for the schema
- """)
- )
-
- try:
- strict_schema = ensure_strict_json_schema(copy.deepcopy(tool.inputSchema))
- except Exception as e:
- self.console.warning_log(f"Failed to convert tool '{tool.name}' schema to strict: {e}")
- return None
-
- return FunctionToolParam(
- type="function",
- name=tool.name,
- description=tool.description,
- parameters=strict_schema,
- strict=True,
- )
-
- def _convert_tools_for_openai(self) -> None:
- """Convert MCP tools into OpenAI Responses tool definitions."""
- self._openai_tools = []
- self._openai_native_tools = {}
- self._tool_name_map = {}
- self._tool_search_threshold = None
-
- categorized = self._categorized_tools
- capabilities = self._discover_environment_capabilities(self.get_available_tools())
- self._environment_capabilities = capabilities
- provider_backing_tools: set[str] = set()
-
- for capability in capabilities.values():
- if capability.name not in openai_tools.capabilities:
- continue
- openai_tool = openai_tools.tool_for_capability(capability, self.model)
- if openai_tool is None:
- continue
- provider_backing_tools.add(capability.tool_name)
- self._openai_native_tools[openai_tool.name] = openai_tool
- self._tool_name_map[openai_tool.name] = openai_tool.name
- self._openai_tools.append(openai_tool.to_params())
-
- configured_hosted = select_hosted_tools(
- self.config.hosted_tools,
- tool_type=OpenAIHostedTool,
- model=self.model,
- )
- for hosted_tool in configured_hosted:
- self._openai_tools.append(hosted_tool.to_params())
- if isinstance(hosted_tool, OpenAIToolSearchTool):
- self._tool_search_threshold = hosted_tool.threshold
-
- # Process generic tools (function tools)
- for tool in categorized.generic:
- if tool.name in provider_backing_tools:
- continue
- openai_tool = self._to_function_tool(tool)
- if openai_tool:
- self._tool_name_map[tool.name] = tool.name
- self._openai_tools.append(openai_tool)
-
- # Log actual tools being used
- tool_names = sorted(self._tool_name_map.keys())
- self.console.info(
- f"Agent initialized with {len(tool_names)} tools: {', '.join(tool_names)}"
- )
-
- def _extract_tool_call(self, item: Any) -> MCPToolCall | None:
- """Extract an MCPToolCall from a response output item.
-
- Subclasses can override to customize tool call extraction (e.g., routing
- computer_call to a different tool name).
- """
- if item.type == "function_call":
- tool_name = item.name or ""
- target_name = self._tool_name_map.get(tool_name, tool_name)
- arguments = json.loads(item.arguments)
- return MCPToolCall(name=target_name, arguments=arguments, id=item.call_id)
- elif item.type == "computer_call":
- self.pending_safety_checks = item.pending_safety_checks or []
- target_name = self._tool_name_map.get("computer", "computer")
- if hasattr(item, "actions") and item.actions:
- arguments = {"actions": [a.to_dict() for a in item.actions]}
- else:
- arguments = item.action.to_dict()
- return MCPToolCall(name=target_name, arguments=arguments, id=item.call_id)
- elif item.type == "shell_call":
- target_name = "shell"
- return MCPToolCall(name=target_name, arguments=item.action.to_dict(), id=item.call_id)
- return None
-
- async def call_tools(
- self, tool_call: MCPToolCall | list[MCPToolCall] | None = None
- ) -> list[MCPToolResult]:
- """Route OpenAI provider tools through their agent-owned adapters."""
- return await call_agent_tools(self, self._openai_native_tools, tool_call)
-
- async def _run_context(
- self, context: list[types.ContentBlock], *, max_steps: int = 10
- ) -> Trace:
- """Reset internal state before delegating to the base loop."""
- self._reset_response_state()
- return await super()._run_context(context, max_steps=max_steps)
-
- def _reset_response_state(self) -> None:
- self.last_response_id = None
- self._message_cursor = 0
- self.pending_call_id = None
- self.pending_safety_checks = []
-
- async def get_system_messages(self) -> list[types.ContentBlock]:
- """System messages are provided via the `instructions` field."""
- return []
+ @cached_property
+ def tools(self) -> OpenAIAgentTools:
+ return OpenAIAgentTools()
+
+ async def format_messages(
+ self, messages: list[types.PromptMessage]
+ ) -> list[ResponseInputItemParam]:
+ """Convert MCP prompt messages into OpenAI Responses input items."""
+ formatted_messages: list[ResponseInputItemParam] = []
+ for message in messages:
+ match message.content:
+ case types.TextContent() as block:
+ content: ResponseInputMessageContentListParam = [
+ ResponseInputTextParam(type="input_text", text=block.text)
+ ]
+ case types.ImageContent() as block:
+ mime_type = getattr(block, "mimeType", "image/png")
+ content = [
+ ResponseInputImageParam(
+ type="input_image",
+ image_url=f"data:{mime_type};base64,{block.data}",
+ detail="auto",
+ )
+ ]
+ case _:
+ content = [ResponseInputTextParam(type="input_text", text="")]
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Message]:
- """Convert MCP content blocks into OpenAI user messages."""
- content: ResponseInputMessageContentListParam = []
- for block in blocks:
- if isinstance(block, types.TextContent):
- content.append(ResponseInputTextParam(type="input_text", text=block.text))
- elif isinstance(block, types.ImageContent):
- mime_type = getattr(block, "mimeType", "image/png")
- content.append(
- ResponseInputImageParam(
- type="input_image",
- image_url=f"data:{mime_type};base64,{block.data}",
- detail="auto",
- )
- )
- if not content:
- content.append(ResponseInputTextParam(type="input_text", text=""))
- return [Message(role="user", content=content)]
+ formatted_messages.append(EasyInputMessageParam(role=message.role, content=content))
+ return formatted_messages
- async def get_response(self, messages: ResponseInputParam) -> InferenceResult:
+ async def get_response(self, messages: list[ResponseInputItemParam]) -> AgentResponse:
"""Send the latest input items to OpenAI's Responses API."""
new_items: ResponseInputParam = messages[self._message_cursor :]
if not new_items:
@@ -291,33 +126,29 @@ async def get_response(self, messages: ResponseInputParam) -> InferenceResult:
)
]
else:
- self.console.debug("No new messages to send to OpenAI.")
- return InferenceResult(content="", tool_calls=[], done=True)
+ logger.debug("No new messages to send to OpenAI.")
+ return AgentResponse(content="", tool_calls=[], done=True)
- scenario_enable_citations = bool(
- getattr(self.ctx, "scenario_enable_citations", False) if self.ctx is not None else False
- )
include_param: list[ResponseIncludable] | Omit = Omit()
- if scenario_enable_citations:
+ if self.enable_citations:
include_param = ["web_search_call.action.sources"]
- effective_tools: list[ToolParam] = list(self._openai_tools)
- if self._tool_search_threshold is not None:
- fn_count = sum(
- 1 for t in effective_tools if isinstance(t, dict) and t.get("type") == "function"
- )
- if fn_count > self._tool_search_threshold:
+ effective_tools: list[ToolParam] = list(self.tools.params)
+ if self.tools.tool_search_threshold is not None:
+ fn_count = sum(1 for t in effective_tools if t.get("type") == "function")
+ if fn_count > self.tools.tool_search_threshold:
logger.debug(
"tool_search: %d function tools > threshold %d, applying defer_loading",
fn_count,
- self._tool_search_threshold,
+ self.tools.tool_search_threshold,
+ )
+ effective_tools = cast(
+ "list[ToolParam]",
+ [
+ {**t, "defer_loading": True} if t.get("type") == "function" else t
+ for t in effective_tools
+ ],
)
- effective_tools = [ # type: ignore[assignment]
- {**t, "defer_loading": True}
- if isinstance(t, dict) and t.get("type") == "function"
- else t
- for t in effective_tools
- ]
response = await self.openai_client.responses.create(
model=self._model,
@@ -340,227 +171,89 @@ async def get_response(self, messages: ResponseInputParam) -> InferenceResult:
self.last_response_id = response.id
self._message_cursor = len(messages)
- agent_response = InferenceResult(content="", tool_calls=[], done=True)
text_chunks: list[str] = []
reasoning_chunks: list[str] = []
-
- citations: list[dict[str, Any]] = []
+ citations: list[dict[str, object]] = []
+ tool_calls: list[MCPToolCall] = []
for item in response.output:
- if item.type == "message":
- for content_block in item.content:
- if isinstance(content_block, ResponseOutputText):
+ match item.type:
+ case "message":
+ for content_block in item.content:
+ if not isinstance(content_block, ResponseOutputText):
+ continue
if content_block.text:
text_chunks.append(content_block.text)
- # Extract citations from annotations
- if content_block.annotations:
- for ann in content_block.annotations:
- ann_type = getattr(ann, "type", "")
- if ann_type == "url_citation":
- cit_obj = getattr(ann, "url_citation", ann)
+ for ann in content_block.annotations or []:
+ match ann.type:
+ case "url_citation":
+ citation = ann
citations.append(
{
"type": "url_citation",
- "text": getattr(cit_obj, "title", "") or "",
- "source": getattr(cit_obj, "url", "") or "",
- "title": getattr(cit_obj, "title", None),
- "start_index": getattr(ann, "start_index", None),
- "end_index": getattr(ann, "end_index", None),
+ "text": citation.title,
+ "source": citation.url,
+ "title": citation.title,
+ "start_index": citation.start_index,
+ "end_index": citation.end_index,
}
)
- elif ann_type == "file_citation":
- cit_obj = getattr(ann, "file_citation", ann)
+ case "file_citation":
+ citation = ann
citations.append(
{
"type": "file_citation",
- "text": getattr(cit_obj, "filename", "") or "",
- "source": getattr(cit_obj, "file_id", "") or "",
- "title": getattr(cit_obj, "filename", None),
- "start_index": getattr(ann, "start_index", None),
- "end_index": getattr(ann, "end_index", None),
+ "text": citation.filename,
+ "source": citation.file_id,
+ "title": citation.filename,
}
)
- elif item.type == "reasoning":
- reasoning_chunks.append("".join(summary.text for summary in item.summary))
- else:
- tool_call = self._extract_tool_call(item)
- if tool_call is not None:
- agent_response.tool_calls.append(tool_call)
-
- if agent_response.tool_calls:
- agent_response.done = False
-
- agent_response.content = "".join(text_chunks)
- agent_response.citations = citations
- if reasoning_chunks:
- agent_response.reasoning = "\n".join(reasoning_chunks)
- return agent_response
-
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[Any]:
- """Convert MCP tool outputs into Responses input items.
-
- Detects computer tool results and formats them as ComputerCallOutput
- with screenshots. Non-computer calls are formatted as FunctionCallOutput.
- """
- computer_tool_name = self._tool_name_map.get("computer")
- has_computer_call = bool(computer_tool_name) and any(
- c.name == computer_tool_name for c in tool_calls
- )
- has_native_call = any(c.name in self._openai_native_tools for c in tool_calls)
- if not has_computer_call and not has_native_call:
- return list(await self._format_function_results(tool_calls, tool_results))
-
- remaining_calls: list[MCPToolCall] = []
- remaining_results: list[MCPToolResult] = []
- computer_outputs: list[ComputerCallOutput] = []
- native_outputs: list[dict[str, Any]] = []
- ordering: list[tuple[str, int]] = []
-
- for call, result in zip(tool_calls, tool_results, strict=False):
- if call.name == computer_tool_name:
- screenshot = self._extract_latest_screenshot(result)
- if not screenshot:
- raise ValueError(
- "Computer tool result missing screenshot. "
- "The tool must always return a screenshot for computer_call_output."
- )
- call_id = call.id or self.pending_call_id
- if not call_id:
- self.console.warning_log("Computer tool call missing ID; skipping output.")
- continue
- acknowledged_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] = []
- for check in self.pending_safety_checks:
- if hasattr(check, "model_dump"):
- acknowledged_checks.append(check.model_dump()) # type: ignore[arg-type]
- elif isinstance(check, dict):
- acknowledged_checks.append(check) # type: ignore[arg-type]
- output_payload = ComputerCallOutput(
- type="computer_call_output",
- call_id=call_id,
- output=cast(
- "ResponseComputerToolCallOutputScreenshotParam",
- {
- "type": "computer_screenshot",
- "image_url": f"data:image/png;base64,{screenshot}",
- "detail": "original",
- },
- ),
- )
- if acknowledged_checks:
- output_payload["acknowledged_safety_checks"] = acknowledged_checks
- computer_outputs.append(output_payload)
- self.pending_call_id = None
- self.pending_safety_checks = []
- ordering.append(("computer", len(computer_outputs) - 1))
- elif call.name in self._openai_native_tools:
- native_outputs.append(
- self._openai_native_tools[call.name].format_result(call, result)
- )
- ordering.append(("native", len(native_outputs) - 1))
- else:
- remaining_calls.append(call)
- remaining_results.append(result)
- ordering.append(("function", len(remaining_calls) - 1))
-
- formatted: list[Any] = []
- function_outputs: list[FunctionCallOutput] = []
- if remaining_calls:
- function_outputs = await self._format_function_results(
- remaining_calls, remaining_results
- )
-
- for kind, idx in ordering:
- if kind == "computer" and idx < len(computer_outputs):
- formatted.append(computer_outputs[idx])
- elif kind == "native" and idx < len(native_outputs):
- formatted.append(native_outputs[idx])
- elif kind == "function" and idx < len(function_outputs):
- formatted.append(function_outputs[idx])
- return formatted
-
- def _extract_latest_screenshot(self, result: MCPToolResult) -> str | None:
- """Extract the latest screenshot from a tool result."""
- if not result.content:
- return None
- for content in reversed(result.content):
- if isinstance(content, types.ImageContent):
- return content.data
- if isinstance(content, types.TextContent) and result.isError:
- self.console.error_log(f"Computer tool error: {content.text}")
- return None
-
- async def _format_function_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[FunctionCallOutput]:
- """Convert MCP tool outputs into function call output items."""
- formatted: list[FunctionCallOutput] = []
- for call, result in zip(tool_calls, tool_results, strict=False):
- if not call.id:
- self.console.warning_log(f"Tool '{call.name}' missing call_id; skipping output.")
- continue
-
- output_items: ResponseFunctionCallOutputItemListParam = []
- if result.isError:
- output_items.append(
- ResponseInputTextParam(type="input_text", text="[tool_error] true")
- )
-
- if result.structuredContent is not None:
- output_items.append(
- ResponseInputTextParam(
- type="input_text", text=json.dumps(result.structuredContent, default=str)
- )
- )
-
- for block in result.content:
- match block:
- case types.TextContent():
- output_items.append(
- ResponseInputTextContentParam(type="input_text", text=block.text)
+ case _:
+ continue
+ case "reasoning":
+ reasoning_chunks.append("".join(summary.text for summary in item.summary))
+ case "function_call":
+ tool_name = item.name or ""
+ tool_calls.append(
+ MCPToolCall(
+ name=self.tools.name_map.get(tool_name, tool_name),
+ arguments=json.loads(item.arguments),
+ id=item.call_id,
)
- case types.ImageContent():
- mime_type = getattr(block, "mimeType", "image/png")
- output_items.append(
- ResponseInputImageContentParam(
- type="input_image",
- image_url=f"data:{mime_type};base64,{block.data}",
- )
- )
- case types.ResourceLink():
- output_items.append(
- ResponseInputFileContentParam(
- type="input_file", file_url=str(block.uri)
- )
+ )
+ case "computer_call":
+ if item.actions:
+ arguments = {"actions": [action.to_dict() for action in item.actions]}
+ elif item.action is not None:
+ arguments = item.action.to_dict()
+ else:
+ raise ValueError("OpenAI computer_call missing action")
+ call: dict[str, Any] = {
+ "name": self.tools.name_map.get("computer", "computer"),
+ "arguments": arguments,
+ "id": item.call_id,
+ }
+ if item.pending_safety_checks:
+ call["pending_safety_checks"] = [
+ check.model_dump() if hasattr(check, "model_dump") else check
+ for check in item.pending_safety_checks
+ ]
+ tool_calls.append(MCPToolCall.model_validate(call))
+ case "shell_call":
+ tool_calls.append(
+ MCPToolCall(
+ name="shell",
+ arguments=item.action.to_dict(),
+ id=item.call_id,
)
- case types.EmbeddedResource():
- match block.resource:
- case types.TextResourceContents():
- output_items.append(
- ResponseInputTextContentParam(
- type="input_text", text=block.resource.text
- )
- )
- case types.BlobResourceContents():
- output_items.append(
- ResponseInputFileContentParam(
- type="input_file", file_data=block.resource.blob
- )
- )
- case _:
- self.console.warning_log(
- f"Unknown resource type: {type(block.resource)}"
- )
- case _:
- self.console.warning_log(f"Unknown content block type: {type(block)}")
-
- if not output_items:
- output_items.append(ResponseInputTextParam(type="input_text", text=""))
+ )
+ case _:
+ continue
- formatted.append(
- FunctionCallOutput(
- type="function_call_output", call_id=call.id, output=output_items
- ),
- )
- return formatted
+ return AgentResponse(
+ content="".join(text_chunks),
+ reasoning="\n".join(reasoning_chunks) if reasoning_chunks else None,
+ citations=citations,
+ tool_calls=tool_calls,
+ done=not tool_calls,
+ )
diff --git a/hud/agents/openai/tools/__init__.py b/hud/agents/openai/tools/__init__.py
index 1c1ffe271..b2e5222d7 100644
--- a/hud/agents/openai/tools/__init__.py
+++ b/hud/agents/openai/tools/__init__.py
@@ -2,55 +2,46 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, ClassVar
-from hud.agents.tools import AgentToolRegistry
+from openai.types.responses import ToolParam
-from .base import OpenAITool
-from .coding import (
- OPENAI_SHELL_SPEC,
- OpenAIShellTool,
-)
-from .computer import OPENAI_COMPUTER_SPEC, OpenAIComputerTool
+from hud.agents.tools import AgentTool, AgentTools
+
+from .base import OpenAIFunctionTool, OpenAITool
+from .coding import OpenAIShellTool
+from .computer import OpenAIComputerTool
from .hosted import OpenAICodeInterpreterTool, OpenAIHostedTool, OpenAIToolSearchTool
+if TYPE_CHECKING:
+ from collections.abc import Mapping
+
-@dataclass(frozen=True)
-class OpenAIToolRegistry(AgentToolRegistry[OpenAITool]):
- """Registry for OpenAI harness tools."""
+class OpenAIAgentTools(AgentTools[OpenAITool, ToolParam]):
+ """Prepared OpenAI Responses tool state for a run."""
- tool_classes: tuple[type[OpenAITool], ...] = (
+ native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = (
OpenAIComputerTool,
OpenAIShellTool,
)
- name_fallbacks: dict[str, tuple[str, ...]] = field(
- default_factory=lambda: {
- "computer": ("computer", "openai_computer"),
- "shell": ("bash",),
- "editor": ("edit",),
- }
- )
+ function_tool_class = OpenAIFunctionTool
+ name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {
+ "computer": ("computer", "openai_computer"),
+ "shell": ("bash",),
+ "editor": ("edit",),
+ }
@property
- def api_types(self) -> frozenset[str]:
- return frozenset(cls.name for cls in self.tool_classes)
-
- @property
- def roles(self) -> frozenset[str]:
- return self.capabilities
-
+ def tool_search_threshold(self) -> int | None:
+ for hosted_tool in self.hosted_tools:
+ if isinstance(hosted_tool, OpenAIToolSearchTool):
+ return hosted_tool.threshold
+ return None
-openai_tools = OpenAIToolRegistry()
__all__ = [
- "OPENAI_COMPUTER_SPEC",
- "OPENAI_SHELL_SPEC",
+ "OpenAIAgentTools",
"OpenAICodeInterpreterTool",
- "OpenAIComputerTool",
"OpenAIHostedTool",
- "OpenAIShellTool",
- "OpenAITool",
- "OpenAIToolRegistry",
"OpenAIToolSearchTool",
- "openai_tools",
]
diff --git a/hud/agents/openai/tools/apply_patch.py b/hud/agents/openai/tools/apply_patch.py
index 90913df5e..03fffa654 100644
--- a/hud/agents/openai/tools/apply_patch.py
+++ b/hud/agents/openai/tools/apply_patch.py
@@ -1,10 +1,10 @@
+# pyright: reportUnusedFunction=false
"""OpenAI apply_patch parser helpers."""
from __future__ import annotations
from dataclasses import dataclass, field
-from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from collections.abc import Callable
@@ -14,45 +14,24 @@ class DiffError(ValueError):
"""Exception raised when diff parsing or application fails."""
-class ActionType(str, Enum):
- ADD = "add"
- DELETE = "delete"
- UPDATE = "update"
-
-
-@dataclass
-class FileChange:
- type: ActionType
- old_content: str | None = None
- new_content: str | None = None
- move_path: str | None = None
-
-
-@dataclass
-class Commit:
- changes: dict[str, FileChange] = field(default_factory=dict)
+ActionType = Literal["add", "delete", "update"]
@dataclass
class Chunk:
orig_index: int = -1 # line index of the first line in the original file
- del_lines: list[str] = field(default_factory=list)
- ins_lines: list[str] = field(default_factory=list)
+ del_lines: list[str] = field(default_factory=list[str])
+ ins_lines: list[str] = field(default_factory=list[str])
@dataclass
class PatchAction:
type: ActionType
new_file: str | None = None
- chunks: list[Chunk] = field(default_factory=list)
+ chunks: list[Chunk] = field(default_factory=list[Chunk])
move_path: str | None = None
-@dataclass
-class Patch:
- actions: dict[str, PatchAction] = field(default_factory=dict)
-
-
class Parser:
"""Parser for V4A diff format."""
@@ -60,7 +39,7 @@ def __init__(self, current_files: dict[str, str], lines: list[str], index: int =
self.current_files = current_files
self.lines = lines
self.index = index
- self.patch = Patch()
+ self.actions: dict[str, PatchAction] = {}
self.fuzz = 0
def is_done(self, prefixes: tuple[str, ...] | None = None) -> bool:
@@ -68,19 +47,11 @@ def is_done(self, prefixes: tuple[str, ...] | None = None) -> bool:
return True
return prefixes is not None and self.lines[self.index].startswith(prefixes)
- def startswith(self, prefix: str | tuple[str, ...]) -> bool:
- if self.index >= len(self.lines):
- raise DiffError(f"Unexpected end of patch at index {self.index}")
- return self.lines[self.index].startswith(prefix)
-
- def read_str(self, prefix: str = "", return_everything: bool = False) -> str:
+ def read_str(self, prefix: str = "") -> str:
if self.index >= len(self.lines):
return "" # At EOF, no match possible
if self.lines[self.index].startswith(prefix):
- if return_everything:
- text = self.lines[self.index]
- else:
- text = self.lines[self.index][len(prefix) :]
+ text = self.lines[self.index][len(prefix) :]
self.index += 1
return text
return ""
@@ -89,7 +60,7 @@ def parse(self) -> None:
while not self.is_done(("*** End Patch",)):
path = self.read_str("*** Update File: ")
if path:
- if path in self.patch.actions:
+ if path in self.actions:
raise DiffError(f"Update File Error: Duplicate Path: {path}")
move_to = self.read_str("*** Move to: ")
if path not in self.current_files:
@@ -97,33 +68,33 @@ def parse(self) -> None:
text = self.current_files[path]
action = self.parse_update_file(text)
action.move_path = move_to if move_to else None
- self.patch.actions[path] = action
+ self.actions[path] = action
continue
path = self.read_str("*** Delete File: ")
if path:
- if path in self.patch.actions:
+ if path in self.actions:
raise DiffError(f"Delete File Error: Duplicate Path: {path}")
if path not in self.current_files:
raise DiffError(f"Delete File Error: Missing File: {path}")
- self.patch.actions[path] = PatchAction(type=ActionType.DELETE)
+ self.actions[path] = PatchAction(type="delete")
continue
path = self.read_str("*** Add File: ")
if path:
- if path in self.patch.actions:
+ if path in self.actions:
raise DiffError(f"Add File Error: Duplicate Path: {path}")
- self.patch.actions[path] = self.parse_add_file()
+ self.actions[path] = self.parse_add_file()
continue
raise DiffError(f"Unknown Line: {self.lines[self.index]}")
- if self.index >= len(self.lines) or not self.startswith("*** End Patch"):
+ if self.index >= len(self.lines) or not self.lines[self.index].startswith("*** End Patch"):
raise DiffError("Missing End Patch")
self.index += 1
def parse_update_file(self, text: str) -> PatchAction:
- action = PatchAction(type=ActionType.UPDATE)
+ action = PatchAction(type="update")
lines = text.split("\n")
index = 0
@@ -136,27 +107,28 @@ def parse_update_file(self, text: str) -> PatchAction:
"*** End of File",
)
):
- def_str = self.read_str("@@ ")
- section_str = ""
- if not def_str and self.lines[self.index] == "@@":
- section_str = self.lines[self.index]
+ section_anchor = self.read_str("@@ ")
+ has_section_marker = False
+ if not section_anchor and self.lines[self.index] == "@@":
+ has_section_marker = True
self.index += 1
- if not (def_str or section_str or index == 0):
+ if not (section_anchor or has_section_marker or index == 0):
raise DiffError(f"Invalid Line:\n{self.lines[self.index]}")
- if def_str.strip():
+ if section_anchor.strip():
found = False
- if not [s for s in lines[:index] if s == def_str]:
- for i, s in enumerate(lines[index:], index):
- if s == def_str:
+ if not any(line == section_anchor for line in lines[:index]):
+ for i, line in enumerate(lines[index:], index):
+ if line == section_anchor:
index = i + 1
found = True
break
- if not found and not [s for s in lines[:index] if s.strip() == def_str.strip()]:
- for i, s in enumerate(lines[index:], index):
- if s.strip() == def_str.strip():
+ stripped_anchor = section_anchor.strip()
+ if not found and not any(line.strip() == stripped_anchor for line in lines[:index]):
+ for i, line in enumerate(lines[index:], index):
+ if line.strip() == stripped_anchor:
index = i + 1
self.fuzz += 1
found = True
@@ -174,9 +146,9 @@ def parse_update_file(self, text: str) -> PatchAction:
self.fuzz += fuzz
- for ch in chunks:
- ch.orig_index += new_index
- action.chunks.append(ch)
+ for chunk in chunks:
+ chunk.orig_index += new_index
+ action.chunks.append(chunk)
index = new_index + len(next_chunk_context)
self.index = end_patch_index
@@ -184,16 +156,15 @@ def parse_update_file(self, text: str) -> PatchAction:
return action
def parse_add_file(self) -> PatchAction:
- lines = []
+ lines: list[str] = []
while not self.is_done(
("*** End Patch", "*** Update File:", "*** Delete File:", "*** Add File:")
):
- s = self.read_str()
- if not s.startswith("+"):
- raise DiffError(f"Invalid Add File Line: {s}")
- s = s[1:]
- lines.append(s)
- return PatchAction(type=ActionType.ADD, new_file="\n".join(lines))
+ line = self.read_str()
+ if not line.startswith("+"):
+ raise DiffError(f"Invalid Add File Line: {line}")
+ lines.append(line[1:])
+ return PatchAction(type="add", new_file="\n".join(lines))
def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]:
old: list[str] = []
@@ -204,9 +175,23 @@ def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]:
orig_index = self.index
index = self.index
+ def flush_chunk() -> None:
+ nonlocal del_lines, ins_lines
+ if not (ins_lines or del_lines):
+ return
+ chunks.append(
+ Chunk(
+ orig_index=len(old) - len(del_lines),
+ del_lines=del_lines,
+ ins_lines=ins_lines,
+ )
+ )
+ del_lines = []
+ ins_lines = []
+
while index < len(self.lines):
- s = self.lines[index]
- if s.startswith(
+ line = self.lines[index]
+ if line.startswith(
(
"@@",
"*** End Patch",
@@ -217,56 +202,40 @@ def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]:
)
):
break
- if s == "***":
+ if line == "***":
break
- elif s.startswith("***"):
- raise DiffError(f"Invalid Line: {s}")
+ elif line.startswith("***"):
+ raise DiffError(f"Invalid Line: {line}")
index += 1
last_mode = mode
- if s == "":
- s = " "
+ if line == "":
+ line = " "
- if s[0] == "+":
+ if line[0] == "+":
mode = "add"
- elif s[0] == "-":
+ elif line[0] == "-":
mode = "delete"
- elif s[0] == " ":
+ elif line[0] == " ":
mode = "keep"
else:
- raise DiffError(f"Invalid Line: {s}")
+ raise DiffError(f"Invalid Line: {line}")
- s = s[1:]
+ line = line[1:]
if mode == "keep" and last_mode != mode:
- if ins_lines or del_lines:
- chunks.append(
- Chunk(
- orig_index=len(old) - len(del_lines),
- del_lines=del_lines,
- ins_lines=ins_lines,
- )
- )
- del_lines = []
- ins_lines = []
+ flush_chunk()
if mode == "delete":
- del_lines.append(s)
- old.append(s)
+ del_lines.append(line)
+ old.append(line)
elif mode == "add":
- ins_lines.append(s)
+ ins_lines.append(line)
elif mode == "keep":
- old.append(s)
+ old.append(line)
- if ins_lines or del_lines:
- chunks.append(
- Chunk(
- orig_index=len(old) - len(del_lines),
- del_lines=del_lines,
- ins_lines=ins_lines,
- )
- )
+ flush_chunk()
if index < len(self.lines) and self.lines[index] == "*** End of File":
index += 1
@@ -278,116 +247,82 @@ def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]:
return old, chunks, index, False
-def _find_context_core(lines: list[str], context: list[str], start: int) -> tuple[int, int]:
+def _find_context(lines: list[str], context: list[str], start: int, eof: bool) -> tuple[int, int]:
if not context:
return start, 0
- # Prefer identical
- for i in range(start, len(lines)):
- if lines[i : i + len(context)] == context:
- return i, 0
-
- # RStrip is ok
- for i in range(start, len(lines)):
- if [s.rstrip() for s in lines[i : i + len(context)]] == [s.rstrip() for s in context]:
- return i, 1
-
- # Fine, Strip is ok too
- for i in range(start, len(lines)):
- if [s.strip() for s in lines[i : i + len(context)]] == [s.strip() for s in context]:
- return i, 100
-
- return -1, 0
-
+ search_starts = [len(lines) - len(context), start] if eof else [start]
+ rstripped_context = [line.rstrip() for line in context]
+ stripped_context = [line.strip() for line in context]
-def _find_context(lines: list[str], context: list[str], start: int, eof: bool) -> tuple[int, int]:
- if eof:
- new_index, fuzz = _find_context_core(lines, context, len(lines) - len(context))
- if new_index != -1:
- return new_index, fuzz
- new_index, fuzz = _find_context_core(lines, context, start)
- return new_index, fuzz + 10000
- return _find_context_core(lines, context, start)
-
-
-def _get_updated_file(text: str, action: PatchAction, path: str) -> str:
- assert action.type == ActionType.UPDATE
- orig_lines = text.split("\n")
- dest_lines = []
- orig_index = 0
-
- for chunk in action.chunks:
- if chunk.orig_index > len(orig_lines):
- raise DiffError(
- f"_get_updated_file: {path}: chunk.orig_index {chunk.orig_index} "
- f"> len(lines) {len(orig_lines)}"
- )
- if orig_index > chunk.orig_index:
- raise DiffError(
- f"_get_updated_file: {path}: orig_index {orig_index} "
- f"> chunk.orig_index {chunk.orig_index}"
- )
+ for attempt, search_start in enumerate(search_starts):
+ fuzz_offset = 10000 if eof and attempt > 0 else 0
- dest_lines.extend(orig_lines[orig_index : chunk.orig_index])
- orig_index = chunk.orig_index
+ for i in range(search_start, len(lines)):
+ candidate = lines[i : i + len(context)]
+ if candidate == context:
+ return i, fuzz_offset
- if chunk.ins_lines:
- dest_lines.extend(chunk.ins_lines)
+ for i in range(search_start, len(lines)):
+ candidate = lines[i : i + len(context)]
+ if [line.rstrip() for line in candidate] == rstripped_context:
+ return i, fuzz_offset + 1
- orig_index += len(chunk.del_lines)
+ for i in range(search_start, len(lines)):
+ candidate = lines[i : i + len(context)]
+ if [line.strip() for line in candidate] == stripped_context:
+ return i, fuzz_offset + 100
- dest_lines.extend(orig_lines[orig_index:])
- return "\n".join(dest_lines)
+ return -1, 0
-def _text_to_patch(text: str, orig: dict[str, str]) -> tuple[Patch, int]:
+def _text_to_patch(text: str, orig: dict[str, str]) -> tuple[dict[str, PatchAction], int]:
lines = text.strip().split("\n")
if len(lines) < 2 or not lines[0].startswith("*** Begin Patch") or lines[-1] != "*** End Patch":
raise DiffError("Invalid patch text")
parser = Parser(current_files=orig, lines=lines, index=1)
parser.parse()
- return parser.patch, parser.fuzz
+ return parser.actions, parser.fuzz
+
+
+def _apply_patch(
+ patch: dict[str, PatchAction],
+ orig: dict[str, str],
+ write_fn: Callable[[str, str | None], None],
+ remove_fn: Callable[[str], None],
+) -> None:
+ for path, action in patch.items():
+ match action.type:
+ case "delete":
+ remove_fn(path)
+ case "add":
+ write_fn(path, action.new_file)
+ case "update":
+ orig_lines = orig[path].split("\n")
+ dest_lines: list[str] = []
+ orig_index = 0
+
+ for chunk in action.chunks:
+ if chunk.orig_index > len(orig_lines):
+ raise DiffError(
+ f"_apply_patch: {path}: chunk.orig_index {chunk.orig_index} "
+ f"> len(lines) {len(orig_lines)}"
+ )
+ if orig_index > chunk.orig_index:
+ raise DiffError(
+ f"_apply_patch: {path}: orig_index {orig_index} "
+ f"> chunk.orig_index {chunk.orig_index}"
+ )
+ dest_lines.extend(orig_lines[orig_index : chunk.orig_index])
+ dest_lines.extend(chunk.ins_lines)
+ orig_index = chunk.orig_index + len(chunk.del_lines)
-def _identify_files_needed(text: str) -> list[str]:
- lines = text.strip().split("\n")
- result = set()
- for line in lines:
- if line.startswith("*** Update File: "):
- result.add(line[len("*** Update File: ") :])
- if line.startswith("*** Delete File: "):
- result.add(line[len("*** Delete File: ") :])
- return list(result)
-
-
-def _patch_to_commit(patch: Patch, orig: dict[str, str]) -> Commit:
- commit = Commit()
- for path, action in patch.actions.items():
- if action.type == ActionType.DELETE:
- commit.changes[path] = FileChange(type=ActionType.DELETE, old_content=orig[path])
- elif action.type == ActionType.ADD:
- commit.changes[path] = FileChange(type=ActionType.ADD, new_content=action.new_file)
- elif action.type == ActionType.UPDATE:
- new_content = _get_updated_file(text=orig[path], action=action, path=path)
- commit.changes[path] = FileChange(
- type=ActionType.UPDATE,
- old_content=orig[path],
- new_content=new_content,
- move_path=action.move_path,
- )
- return commit
-
-
-def _apply_commit(commit: Commit, write_fn: Callable, remove_fn: Callable) -> None:
- for path, change in commit.changes.items():
- if change.type == ActionType.DELETE:
- remove_fn(path)
- elif change.type == ActionType.ADD:
- write_fn(path, change.new_content)
- elif change.type == ActionType.UPDATE:
- if change.move_path:
- write_fn(change.move_path, change.new_content)
- remove_fn(path)
- else:
- write_fn(path, change.new_content)
+ dest_lines.extend(orig_lines[orig_index:])
+ new_content = "\n".join(dest_lines)
+ if action.move_path:
+ write_fn(action.move_path, new_content)
+ remove_fn(path)
+ else:
+ write_fn(path, new_content)
diff --git a/hud/agents/openai/tools/base.py b/hud/agents/openai/tools/base.py
index f5074bb4c..523a5087e 100644
--- a/hud/agents/openai/tools/base.py
+++ b/hud/agents/openai/tools/base.py
@@ -2,41 +2,155 @@
from __future__ import annotations
+import copy
+import json
+import logging
from abc import ABC
-from typing import TYPE_CHECKING, Any
+from inspect import cleandoc
+from typing import TYPE_CHECKING, Any, cast
-from mcp.types import TextContent
+from mcp import types
+from openai.types.responses import (
+ FunctionToolParam,
+ ResponseFunctionCallOutputItemListParam,
+ ResponseInputFileContentParam,
+ ResponseInputImageContentParam,
+ ResponseInputTextContentParam,
+ ResponseInputTextParam,
+ ToolParam,
+)
+from openai.types.responses.response_input_param import FunctionCallOutput
-from hud.agents import tools as _agent_tools
-from hud.agents.tools import AgentTool, AgentToolSpec, CallTool
+from hud.agents.tools import AgentTool, AgentToolSpec
+from hud.utils.strict_schema import ensure_strict_json_schema
if TYPE_CHECKING:
- from openai.types.responses import ToolParam
+ from openai.types.responses import ResponseInputItemParam
from hud.types import MCPToolCall, MCPToolResult
-else:
- ToolParam = Any
+
+logger = logging.getLogger(__name__)
OpenAIToolSpec = AgentToolSpec
-call_tool = _agent_tools.call_tool
-class OpenAITool(AgentTool["ToolParam"], ABC):
+class OpenAITool(AgentTool[ToolParam], ABC):
"""Agent-side OpenAI provider tool backed by an environment tool."""
- def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, Any]:
+ def format_result(
+ self, call: MCPToolCall, result: MCPToolResult
+ ) -> ResponseInputItemParam | None:
"""Format a generic provider tool result for the OpenAI Responses API."""
- return {
- "type": "function_call_output",
- "call_id": call.id,
- "output": result_text(result),
- }
+ if not call.id:
+ logger.warning("Tool '%s' missing call_id; skipping output.", call.name)
+ return None
+
+ output_items: ResponseFunctionCallOutputItemListParam = []
+ if result.isError:
+ output_items.append(
+ ResponseInputTextContentParam(type="input_text", text="[tool_error] true")
+ )
+
+ if result.structuredContent is not None:
+ output_items.append(
+ ResponseInputTextContentParam(
+ type="input_text",
+ text=json.dumps(result.structuredContent, default=str),
+ )
+ )
+
+ for block in result.content:
+ match block:
+ case types.TextContent():
+ output_items.append(
+ ResponseInputTextContentParam(type="input_text", text=block.text)
+ )
+ case types.ImageContent():
+ mime_type = getattr(block, "mimeType", "image/png")
+ output_items.append(
+ ResponseInputImageContentParam(
+ type="input_image",
+ image_url=f"data:{mime_type};base64,{block.data}",
+ )
+ )
+ case types.ResourceLink():
+ output_items.append(
+ ResponseInputFileContentParam(type="input_file", file_url=str(block.uri))
+ )
+ case types.EmbeddedResource(resource=types.TextResourceContents() as resource):
+ output_items.append(
+ ResponseInputTextContentParam(type="input_text", text=resource.text)
+ )
+ case types.EmbeddedResource(resource=types.BlobResourceContents() as resource):
+ output_items.append(
+ ResponseInputFileContentParam(type="input_file", file_data=resource.blob)
+ )
+ case types.EmbeddedResource():
+ logger.warning("Unknown resource type: %s", type(block.resource))
+ case _:
+ logger.warning("Unknown content block type: %s", type(block))
+
+ if not output_items:
+ output_items.append(ResponseInputTextParam(type="input_text", text=""))
+
+ return FunctionCallOutput(type="function_call_output", call_id=call.id, output=output_items)
+
+
+class OpenAIFunctionTool(OpenAITool):
+ """Generic OpenAI function tool backed by an MCP tool."""
+
+ name = "function"
+ capability = "function"
+
+ def __init__(
+ self,
+ *,
+ env_tool_name: str,
+ description: str,
+ parameters: dict[str, Any],
+ ) -> None:
+ super().__init__(
+ env_tool_name=env_tool_name,
+ spec=OpenAIToolSpec(api_type="function", api_name=env_tool_name),
+ )
+ self.description = description
+ self.parameters = parameters
+
+ @classmethod
+ def from_tool(cls, tool: types.Tool) -> OpenAIFunctionTool | None:
+ if tool.description is None:
+ raise ValueError(
+ cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema.
+ Add these by:
+ 1. Adding a docstring to your @mcp.tool decorated function for the description
+ 2. Using pydantic Field() annotations on function parameters for the schema
+ """)
+ )
+ try:
+ parameters = ensure_strict_json_schema(copy.deepcopy(tool.inputSchema))
+ except Exception as e:
+ logger.warning("Failed to convert tool '%s' schema to strict: %s", tool.name, e)
+ return None
-def result_text(result: MCPToolResult) -> str:
- """Return text content from an MCP tool result."""
- parts = [block.text for block in result.content if isinstance(block, TextContent)]
- return "\n".join(part for part in parts if part)
+ return cls(
+ env_tool_name=tool.name,
+ description=tool.description,
+ parameters=parameters,
+ )
+ @property
+ def provider_name(self) -> str:
+ return self.env_tool_name
-__all__ = ["CallTool", "OpenAITool", "OpenAIToolSpec", "call_tool", "result_text"]
+ def to_params(self) -> ToolParam:
+ return cast(
+ "ToolParam",
+ FunctionToolParam(
+ type="function",
+ name=self.provider_name,
+ description=self.description,
+ parameters=self.parameters,
+ strict=True,
+ ),
+ )
diff --git a/hud/agents/openai/tools/coding.py b/hud/agents/openai/tools/coding.py
index 0fa2f6176..6bb6efa4d 100644
--- a/hud/agents/openai/tools/coding.py
+++ b/hud/agents/openai/tools/coding.py
@@ -2,14 +2,17 @@
from __future__ import annotations
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
from mcp.types import TextContent
-from openai.types.responses import FunctionShellToolParam, ToolParam
+from openai.types.responses import FunctionShellToolParam, ResponseInputItemParam, ToolParam
from hud.types import MCPToolCall, MCPToolResult
-from .base import CallTool, OpenAITool, OpenAIToolSpec, call_tool, result_text
+from .base import OpenAITool, OpenAIToolSpec
+
+if TYPE_CHECKING:
+ from hud.agents.tools.base import CallTool
OPENAI_SHELL_SPEC = OpenAIToolSpec(
api_type="shell",
@@ -45,21 +48,28 @@ def to_params(self) -> ToolParam:
FunctionShellToolParam(type="shell", environment={"type": "local"}),
)
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- commands = arguments.get("commands")
- if isinstance(commands, str):
- commands = [commands]
- if not isinstance(commands, list) or not all(isinstance(cmd, str) for cmd in commands):
- return _provider_result(
- "shell",
- "commands must be a list of strings",
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ def invalid_commands_result() -> MCPToolResult:
+ text = "commands must be a list of strings"
+ return _shell_result(
+ text,
is_error=True,
structured={
- "output": [_shell_output("", "commands must be a list of strings", 1)],
+ "output": [_shell_output("", text, 1)],
"max_output_length": arguments.get("max_output_length"),
},
)
+ commands = arguments.get("commands")
+ if isinstance(commands, str):
+ commands = [commands]
+ if not isinstance(commands, list):
+ return invalid_commands_result()
+ raw_commands = cast("list[Any]", commands)
+ if not all(isinstance(cmd, str) for cmd in raw_commands):
+ return invalid_commands_result()
+ command_list = cast("list[str]", raw_commands)
+
outputs: list[dict[str, Any]] = []
text_parts: list[str] = []
is_error = False
@@ -67,13 +77,12 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR
timeout_ms = arguments.get("timeout_ms")
if isinstance(timeout_ms, int):
env_arguments["timeout_seconds"] = timeout_ms / 1000.0
- for command in commands:
- result = await call_tool(
- caller,
- self.env_tool_name,
+ for command in command_list:
+ result = await super().execute(
+ call_tool,
{"command": command, **env_arguments},
)
- text = result_text(result)
+ text = _result_text(result)
if result.isError:
outputs.append(_shell_output("", text, 1))
is_error = True
@@ -82,8 +91,7 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR
if text:
text_parts.append(text)
- return _provider_result(
- "shell",
+ return _shell_result(
"\n".join(text_parts),
is_error=is_error,
structured={
@@ -92,11 +100,11 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR
},
)
- def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, Any]:
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> ResponseInputItemParam:
structured = result.structuredContent if isinstance(result.structuredContent, dict) else {}
output = structured.get("output")
if not isinstance(output, list):
- output = [_shell_output("", result_text(result), 1 if result.isError else 0)]
+ output = [_shell_output("", _result_text(result), 1 if result.isError else 0)]
response: dict[str, Any] = {
"type": "shell_call_output",
@@ -107,17 +115,16 @@ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, A
max_output_length = structured.get("max_output_length")
if isinstance(max_output_length, int):
response["max_output_length"] = max_output_length
- return response
+ return cast("ResponseInputItemParam", response)
-def _provider_result(
- provider_tool: str,
+def _shell_result(
text: str,
*,
is_error: bool = False,
structured: dict[str, Any] | None = None,
) -> MCPToolResult:
- payload = {"provider_tool": provider_tool, **(structured or {})}
+ payload = {"provider_tool": "shell", **(structured or {})}
return MCPToolResult(
content=[TextContent(type="text", text=text)] if text else [],
isError=is_error,
@@ -125,15 +132,14 @@ def _provider_result(
)
+def _result_text(result: MCPToolResult) -> str:
+ parts = [block.text for block in result.content if isinstance(block, TextContent)]
+ return "\n".join(part for part in parts if part)
+
+
def _shell_output(stdout: str, stderr: str, exit_code: int) -> dict[str, Any]:
return {
"stdout": stdout,
"stderr": stderr,
"outcome": {"type": "exit", "exit_code": exit_code},
}
-
-
-__all__ = [
- "OPENAI_SHELL_SPEC",
- "OpenAIShellTool",
-]
diff --git a/hud/agents/openai/tools/computer.py b/hud/agents/openai/tools/computer.py
index acfb39cbe..748a31601 100644
--- a/hud/agents/openai/tools/computer.py
+++ b/hud/agents/openai/tools/computer.py
@@ -4,14 +4,29 @@
from typing import TYPE_CHECKING, Any, cast
-from mcp.types import ImageContent, TextContent
+from mcp.types import TextContent
+from openai.types.responses.response_input_param import ComputerCallOutput
-from hud.types import MCPToolResult
+from hud.agents.tools.computer import (
+ computer_error_result,
+ execute_computer_calls,
+ last_image_data,
+)
+from hud.types import MCPToolCall, MCPToolResult
-from .base import CallTool, OpenAITool, OpenAIToolSpec, call_tool
+from .base import OpenAITool, OpenAIToolSpec
if TYPE_CHECKING:
- from openai.types.responses import ComputerToolParam
+ from openai.types.responses import (
+ ComputerToolParam,
+ ResponseComputerToolCallOutputScreenshotParam,
+ ResponseInputItemParam,
+ )
+ from openai.types.responses.response_input_param import (
+ ComputerCallOutputAcknowledgedSafetyCheck,
+ )
+
+ from hud.agents.tools.base import CallTool
else:
ComputerToolParam = Any
@@ -88,59 +103,92 @@ def __init__(
def to_params(self) -> ComputerToolParam:
return cast("ComputerToolParam", {"type": "computer"})
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> ResponseInputItemParam:
+ screenshot = last_image_data(result)
+ if not screenshot:
+ raise ValueError(
+ "Computer tool result missing screenshot. "
+ "The tool must always return a screenshot for computer_call_output."
+ )
+
+ output = ComputerCallOutput(
+ type="computer_call_output",
+ call_id=call.id,
+ output=cast(
+ "ResponseComputerToolCallOutputScreenshotParam",
+ {
+ "type": "computer_screenshot",
+ "image_url": f"data:image/png;base64,{screenshot}",
+ "detail": "original",
+ },
+ ),
+ )
+
+ checks = (call.model_extra or {}).get("pending_safety_checks")
+ if isinstance(checks, list):
+ acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] = []
+ for raw_check in cast("list[Any]", checks):
+ check: Any = raw_check
+ if hasattr(check, "model_dump"):
+ acknowledged.append(
+ cast("ComputerCallOutputAcknowledgedSafetyCheck", check.model_dump())
+ )
+ elif isinstance(check, dict):
+ acknowledged.append(cast("ComputerCallOutputAcknowledgedSafetyCheck", check))
+ if acknowledged:
+ output["acknowledged_safety_checks"] = acknowledged
+ return cast("ResponseInputItemParam", output)
+
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
actions = arguments.get("actions")
if isinstance(actions, list):
- if not actions:
- return _error_result("actions list is empty")
+ action_list = cast("list[Any]", actions)
+ if not action_list:
+ return computer_error_result("actions list is empty")
result = MCPToolResult(content=[], isError=False)
- for index, action in enumerate(actions):
- if not isinstance(action, dict):
- return _error_result("actions must be objects")
+ for index, raw_action in enumerate(action_list):
+ action = cast("dict[str, Any]", raw_action)
+ if not isinstance(raw_action, dict):
+ return computer_error_result("actions must be objects")
result = await self._execute_one(
- caller,
+ call_tool,
action,
- ensure_screenshot=index == len(actions) - 1,
+ ensure_screenshot=index == len(action_list) - 1,
)
if result.isError:
return result
return result
- return await self._execute_one(caller, arguments, ensure_screenshot=True)
+ return await self._execute_one(call_tool, arguments, ensure_screenshot=True)
async def _execute_one(
self,
- caller: CallTool,
+ call_tool: CallTool,
arguments: dict[str, Any],
*,
ensure_screenshot: bool,
) -> MCPToolResult:
action_type = arguments.get("type")
if not isinstance(action_type, str):
- return _error_result("type is required")
+ return computer_error_result("type is required")
if action_type == "response":
text = arguments.get("text")
if not isinstance(text, str):
- return _error_result("text is required for response")
+ return computer_error_result("text is required for response")
return MCPToolResult(content=[TextContent(type="text", text=text)], isError=False)
env_arguments = self._env_arguments(arguments)
- result = await call_tool(caller, self.env_tool_name, env_arguments)
- if (
- ensure_screenshot
- and action_type in _SCREENSHOT_ACTIONS
- and action_type != "screenshot"
- and not _has_image(result)
- and not result.isError
- ):
- screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"})
- if not screenshot.isError and screenshot.content:
- result = MCPToolResult(
- content=[*result.content, *screenshot.content],
- isError=result.isError,
- )
- return result
+ return await execute_computer_calls(
+ call_tool,
+ env_tool_name=self.env_tool_name,
+ calls=[env_arguments],
+ ensure_screenshot=(
+ ensure_screenshot
+ and action_type in _SCREENSHOT_ACTIONS
+ and action_type != "screenshot"
+ ),
+ )
def _env_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]:
action_type = arguments.get("type")
@@ -148,11 +196,18 @@ def _env_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]:
if action_type == "screenshot":
return {"action": "screenshot"}
if action_type == "click":
+ button = arguments.get("button")
+ if button == "wheel":
+ button_name = "middle"
+ elif isinstance(button, str):
+ button_name = button
+ else:
+ button_name = "left"
return {
"action": "click",
"x": arguments.get("x"),
"y": arguments.get("y"),
- "button": _map_button(arguments.get("button")),
+ "button": button_name,
"hold_keys": _hold_keys(arguments.get("keys")),
}
if action_type == "double_click":
@@ -187,7 +242,10 @@ def _env_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]:
keys = arguments.get("keys")
if not isinstance(keys, list):
keys = []
- return {"action": "press", "keys": [_map_key(str(key)) for key in keys]}
+ return {
+ "action": "press",
+ "keys": [_map_key(str(key)) for key in cast("list[Any]", keys)],
+ }
if action_type == "drag":
return {
"action": "drag",
@@ -207,24 +265,4 @@ def _map_key(key: str) -> str:
def _hold_keys(keys: Any) -> list[str] | None:
if not isinstance(keys, list):
return None
- return [_map_key(str(key)) for key in keys]
-
-
-def _map_button(button: Any) -> str:
- if button == "wheel":
- return "middle"
- return button if isinstance(button, str) else "left"
-
-
-def _has_image(result: MCPToolResult) -> bool:
- return any(isinstance(block, ImageContent) for block in result.content)
-
-
-def _error_result(message: str) -> MCPToolResult:
- return MCPToolResult(
- content=[TextContent(type="text", text=message)],
- isError=True,
- )
-
-
-__all__ = ["OPENAI_COMPUTER_SPEC", "OpenAIComputerTool"]
+ return [_map_key(str(key)) for key in cast("list[Any]", keys)]
diff --git a/hud/agents/openai/tools/hosted.py b/hud/agents/openai/tools/hosted.py
index 0f13be9ba..b182bd93d 100644
--- a/hud/agents/openai/tools/hosted.py
+++ b/hud/agents/openai/tools/hosted.py
@@ -45,10 +45,3 @@ class OpenAIToolSearchTool(OpenAIHostedTool):
def to_params(self) -> ToolParam:
return cast("ToolParam", {"type": "tool_search"})
-
-
-__all__ = [
- "OpenAICodeInterpreterTool",
- "OpenAIHostedTool",
- "OpenAIToolSearchTool",
-]
diff --git a/hud/agents/openai_compatible/__init__.py b/hud/agents/openai_compatible/__init__.py
index 3cecd79d2..fc9746f1c 100644
--- a/hud/agents/openai_compatible/__init__.py
+++ b/hud/agents/openai_compatible/__init__.py
@@ -1,6 +1,5 @@
"""OpenAI-compatible agent harness support."""
from .agent import OpenAIChatAgent
-from .tools import openai_compatible_tools
-__all__ = ["OpenAIChatAgent", "openai_compatible_tools"]
+__all__ = ["OpenAIChatAgent"]
diff --git a/hud/agents/openai_compatible/agent.py b/hud/agents/openai_compatible/agent.py
index 74a464459..5c2351e50 100644
--- a/hud/agents/openai_compatible/agent.py
+++ b/hud/agents/openai_compatible/agent.py
@@ -18,52 +18,37 @@
import json
import logging
-from typing import TYPE_CHECKING, Any, ClassVar, cast
+from functools import cached_property
+from typing import Any, cast
import mcp.types as types
from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from hud.agents.base import MCPAgent
-from hud.agents.tools import (
- AgentTool,
- EnvironmentCapability,
- call_agent_tools,
- capabilities_metadata_from_context,
- discover_environment_capabilities,
-)
-from hud.agents.types import OpenAIChatConfig, OpenAIChatCreateParams
+from hud.agents.types import OpenAIChatConfig
from hud.settings import settings
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult
-from hud.utils.hud_console import HUDConsole
+from hud.types import AgentResponse, MCPToolCall
from hud.utils.types import with_signature
-from .tools import OpenAICompatibleToolParam, openai_compatible_tools
-
-if TYPE_CHECKING:
- from openai.types.chat import ChatCompletionToolParam
-
+from .tools import (
+ OpenAICompatibleAgentTools,
+)
logger = logging.getLogger(__name__)
-class OpenAIChatAgent(MCPAgent):
+class OpenAIChatAgent(MCPAgent[ChatCompletionMessageParam]):
"""MCP-enabled agent that speaks the OpenAI *chat.completions* protocol."""
- metadata: ClassVar[dict[str, Any] | None] = None
- config_cls: ClassVar[type[BaseAgentConfig]] = OpenAIChatConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for OpenAI-compatible agents."""
- return AgentType.OPENAI_COMPATIBLE
-
- @with_signature(OpenAIChatCreateParams)
+ @with_signature(OpenAIChatConfig)
@classmethod
def create(cls, **kwargs: Any) -> OpenAIChatAgent: # pyright: ignore[reportIncompatibleMethodOverride]
- return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value]
+ return cls(OpenAIChatConfig(**kwargs))
- def __init__(self, params: OpenAIChatCreateParams | None = None, **kwargs: Any) -> None:
- super().__init__(params, **kwargs)
+ def __init__(self, config: OpenAIChatConfig | None = None) -> None:
+ config = config or OpenAIChatConfig()
+ super().__init__(config)
self.config: OpenAIChatConfig
if (
@@ -100,76 +85,25 @@ def __init__(self, params: OpenAIChatCreateParams | None = None, **kwargs: Any)
# If a specific checkpoint is requested, inject it into extra_body
# so the HUD gateway routes to the exact checkpoint for inference.
if self.config.checkpoint:
- extra_body = self.completion_kwargs.get("extra_body") or {}
+ extra_body: dict[str, Any] = dict(self.completion_kwargs.get("extra_body") or {})
extra_body["checkpoint"] = self.config.checkpoint
self.completion_kwargs["extra_body"] = extra_body
- self.mcp_schemas: list[ChatCompletionToolParam] = []
- self.hud_console = HUDConsole(logger=logger)
- self._openai_compatible_tool_params: list[OpenAICompatibleToolParam] = []
- self._openai_compatible_native_tools: dict[
- str,
- AgentTool[OpenAICompatibleToolParam],
- ] = {}
- self._environment_capabilities: dict[str, EnvironmentCapability] = {}
- self._openai_compatible_backing_tools: set[str] = set()
-
self._continuation_token_ids: list[int] | None = None
self._continuation_message_count: int | None = None
- def _on_tools_ready(self) -> None:
- self._convert_tools_for_openai_compatible()
-
- def _discover_environment_capabilities(
- self, tools: list[types.Tool]
- ) -> dict[str, EnvironmentCapability]:
- return discover_environment_capabilities(
- tools,
- env_metadata=capabilities_metadata_from_context(self.ctx),
- name_fallbacks=openai_compatible_tools.name_fallbacks,
- )
-
- def _convert_tools_for_openai_compatible(self) -> None:
- """Build OpenAI-compatible native tool mappings from environment capabilities."""
- self._openai_compatible_tool_params = []
- self._openai_compatible_native_tools = {}
- self._openai_compatible_backing_tools = set()
-
- capabilities = self._discover_environment_capabilities(self.get_available_tools())
- self._environment_capabilities = capabilities
-
- for capability in capabilities.values():
- if capability.name not in openai_compatible_tools.capabilities:
- continue
- for tool in openai_compatible_tools.tools_for_capability(capability, self.model):
- self._openai_compatible_backing_tools.add(tool.env_tool_name)
- self._openai_compatible_native_tools[tool.name] = tool
- self._openai_compatible_tool_params.append(tool.to_params())
-
- def _oai_to_mcp(self, tool_call: Any) -> MCPToolCall: # type: ignore[valid-type]
- """Convert an OpenAI ``tool_call`` to :class:`MCPToolCall`."""
- args = json.loads(tool_call.function.arguments or "{}")
- if isinstance(args, list):
- args = args[0]
- if not isinstance(args, dict):
- args = {}
- return MCPToolCall(
- id=tool_call.id,
- name=tool_call.function.name,
- arguments=args,
- )
-
- async def get_system_messages(self) -> list[dict[str, Any]]:
- """Get system messages for OpenAI."""
- if self.system_prompt is not None:
- return [{"role": "system", "content": self.system_prompt}]
- else:
- return []
-
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]:
- """Format blocks for OpenAI."""
- content = []
- for block in blocks:
+ @cached_property
+ def tools(self) -> OpenAICompatibleAgentTools:
+ return OpenAICompatibleAgentTools()
+
+ async def format_messages(
+ self, messages: list[types.PromptMessage]
+ ) -> list[ChatCompletionMessageParam]:
+ """Format MCP prompt messages for OpenAI-compatible chat."""
+ formatted_messages: list[ChatCompletionMessageParam] = []
+ for message in messages:
+ content: list[dict[str, Any]] = []
+ block = message.content
if isinstance(block, types.TextContent):
content.append({"type": "text", "text": block.text})
elif isinstance(block, types.ImageContent):
@@ -180,146 +114,54 @@ async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str
}
)
- return [{"role": "user", "content": content}]
-
- def _sanitize_schema_for_openai(self, schema: dict) -> dict:
- """Convert MCP JSON Schema to OpenAI-compatible format.
-
- Handles unsupported features like anyOf and prefixItems.
- """
- if not isinstance(schema, dict):
- return schema
-
- sanitized = {}
-
- for key, value in schema.items():
- if key == "anyOf" and isinstance(value, list):
- # Handle anyOf patterns (usually for nullable fields)
- non_null_types = [
- v for v in value if not (isinstance(v, dict) and v.get("type") == "null")
- ]
- if non_null_types:
- # Use the first non-null type
- sanitized.update(self._sanitize_schema_for_openai(non_null_types[0]))
- else:
- sanitized["type"] = "string" # Fallback
-
- elif key == "prefixItems":
- # Convert prefixItems to simple items
- sanitized["type"] = "array"
- if isinstance(value, list) and value:
- # Use the type from the first item as the items schema
- first_item = value[0]
- if isinstance(first_item, dict):
- sanitized["items"] = {"type": first_item.get("type", "string")}
- else:
- sanitized["items"] = {"type": "string"}
-
- elif key == "properties" and isinstance(value, dict):
- # Recursively sanitize property schemas
- sanitized[key] = {
- prop_name: self._sanitize_schema_for_openai(prop_schema)
- for prop_name, prop_schema in value.items()
- }
+ formatted_messages.append(
+ cast(
+ "ChatCompletionMessageParam",
+ {"role": message.role, "content": content},
+ )
+ )
+ return formatted_messages
- elif key == "items" and isinstance(value, dict):
- # Recursively sanitize items schema
- sanitized[key] = self._sanitize_schema_for_openai(value)
-
- elif key in (
- "type",
- "description",
- "enum",
- "required",
- "default",
- "minimum",
- "maximum",
- "minItems",
- "maxItems",
- ):
- # These are supported by OpenAI
- sanitized[key] = value
-
- return sanitized or {"type": "object"}
-
- def get_tool_schemas(self) -> list[OpenAICompatibleToolParam]:
- tool_schemas = [
- schema
- for schema in super().get_tool_schemas()
- if schema["name"] not in self._openai_compatible_backing_tools
- ]
- openai_tools = list(self._openai_compatible_tool_params)
- for schema in tool_schemas:
- parameters = schema.get("parameters", {})
-
- if parameters:
- sanitized_params = self._sanitize_schema_for_openai(parameters)
- else:
- sanitized_params = {"type": "object", "properties": {}}
-
- openai_tool: ChatCompletionToolParam = {
- "type": "function",
- "function": {
- "name": schema["name"],
- "description": schema.get("description", ""),
- "parameters": sanitized_params,
- },
- }
- openai_tools.append(openai_tool)
- return openai_tools
-
- async def call_tools(
- self, tool_call: MCPToolCall | list[MCPToolCall] | None = None
- ) -> list[MCPToolResult]:
- """Route OpenAI-compatible provider tools through agent-owned translators."""
- return await call_agent_tools(self, self._openai_compatible_native_tools, tool_call)
-
- async def _invoke_chat_completion(
- self,
- *,
- messages: list[Any],
- tools: list[dict] | None,
- extra: dict[str, Any],
- ) -> Any:
- if self.oai is None:
- raise ValueError("openai_client is required for OpenAIChatAgent")
- # default transport = OpenAI SDK
- return await self.oai.chat.completions.create(
- model=self.config.model,
- messages=messages,
- tools=tools, # type: ignore ready ChatCompletionToolParam-shaped
- **extra,
- ) # type: ignore
-
- async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
+ async def get_response(self, messages: list[ChatCompletionMessageParam]) -> AgentResponse:
"""Send chat request to OpenAI and convert the response."""
- # Convert MCP tool schemas to OpenAI format
- tools = cast("list[ChatCompletionToolParam]", self.get_tool_schemas())
+ reserved_kwargs = {"model", "messages", "stream", "tools"}
+ request_kwargs = {
+ key: value
+ for key, value in self.completion_kwargs.items()
+ if key not in reserved_kwargs
+ }
+ provider_body: dict[str, Any] = dict(request_kwargs.pop("extra_body", None) or {})
+ return_token_ids = bool(provider_body.get("return_token_ids"))
- protected_keys = {"model", "messages", "tools"}
- extra = {k: v for k, v in (self.completion_kwargs or {}).items() if k not in protected_keys}
- extra_body = extra.get("extra_body") or {}
- return_token_ids = extra_body.get("return_token_ids")
+ if self.tools.params:
+ provider_body["tools"] = self.tools.params
if return_token_ids and self._continuation_token_ids and self._continuation_message_count:
- extra_body["prompt_token_ids"] = self._continuation_token_ids
- extra_body["continuation_from"] = self._continuation_message_count
- extra["extra_body"] = extra_body
+ provider_body["prompt_token_ids"] = self._continuation_token_ids
+ provider_body["continuation_from"] = self._continuation_message_count
+
+ if provider_body:
+ request_kwargs["extra_body"] = provider_body
try:
- response = await self._invoke_chat_completion(
- messages=messages,
- tools=tools, # type: ignore
- extra=extra,
+ response: ChatCompletion = await self.oai.chat.completions.create(
+ model=self.config.model,
+ messages=(
+ [{"role": "system", "content": self.system_prompt}, *messages]
+ if self.system_prompt is not None
+ else messages
+ ),
+ stream=False,
+ **request_kwargs,
)
except Exception as e:
error_content = f"Error getting response {e}"
if "Invalid JSON" in str(e):
error_content = "Invalid JSON, response was truncated"
- self.hud_console.warning_log(error_content)
+ logger.warning(error_content)
- return InferenceResult(
+ return AgentResponse(
content=error_content,
tool_calls=[],
done=True,
@@ -328,24 +170,33 @@ async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
)
choice = response.choices[0]
- msg = choice.message
- assistant_msg: dict[str, Any] = {"role": "assistant"}
-
- if msg.content:
- assistant_msg["content"] = msg.content
+ message = choice.message
+ function_calls = [
+ tool_call for tool_call in message.tool_calls or [] if tool_call.type == "function"
+ ]
- if msg.tool_calls:
- serialized_tool_calls = []
- for tc in msg.tool_calls:
- serialized_tc = {
- "id": tc.id,
+ assistant_message = message.model_dump(exclude_none=True)
+ reasoning_content = getattr(message, "reasoning_content", None)
+ reasoning = reasoning_content if isinstance(reasoning_content, str) else None
+ if not reasoning:
+ raw_reasoning = getattr(message, "reasoning", None)
+ reasoning = raw_reasoning if isinstance(raw_reasoning, str) else None
+ for field in ("reasoning_content", "reasoning", "reasoning_details"):
+ if value := getattr(message, field, None):
+ assistant_message[field] = value
+ if function_calls:
+ assistant_message["tool_calls"] = [
+ {
+ "id": tool_call.id,
"type": "function",
- "function": {"name": tc.function.name, "arguments": tc.function.arguments},
+ "function": {
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
}
- serialized_tool_calls.append(serialized_tc)
- assistant_msg["tool_calls"] = serialized_tool_calls
-
- messages.append(assistant_msg)
+ for tool_call in function_calls
+ ]
+ messages.append(cast("ChatCompletionMessageParam", assistant_message))
if return_token_ids:
prompt_token_ids = getattr(choice, "prompt_token_ids", None)
@@ -354,91 +205,23 @@ async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
self._continuation_token_ids = list(prompt_token_ids) + list(token_ids)
self._continuation_message_count = len(messages)
- tool_calls = []
- if msg.tool_calls:
- for tc in msg.tool_calls:
- if tc.function.name is not None: # type: ignore
- # _oai_to_mcp returns a single MCPToolCall; append it
- tool_calls.append(self._oai_to_mcp(tc)) # noqa: PERF401
-
- # Only stop on length (token limit), never on "stop"
- done = choice.finish_reason == "length"
- if done:
- self.hud_console.info_log(f"Done decision: finish_reason={choice.finish_reason}")
-
- return InferenceResult(
- content=msg.content or "",
- reasoning=getattr(msg, "reasoning_content", None),
+ tool_calls: list[MCPToolCall] = []
+ for tool_call in function_calls:
+ raw_args = json.loads(tool_call.function.arguments or "{}")
+ arguments = cast("dict[str, Any]", raw_args) if isinstance(raw_args, dict) else {}
+ tool_calls.append(
+ MCPToolCall(
+ id=tool_call.id,
+ name=tool_call.function.name,
+ arguments=arguments,
+ )
+ )
+
+ return AgentResponse(
+ content=message.content or "",
+ reasoning=reasoning,
+ info={"finish_reason": choice.finish_reason},
tool_calls=tool_calls,
- done=done,
+ done=not tool_calls,
raw=response,
)
-
- async def format_tool_results(
- self,
- tool_calls: list[MCPToolCall],
- tool_results: list[MCPToolResult],
- ) -> list[dict[str, Any]]:
- """Render MCP tool results as OpenAI messages.
-
- Note: OpenAI tool messages only support string content.
- When images are present, we return both a tool message and a user message.
- """
- rendered: list[dict[str, Any]] = []
-
- # Separate text and image content
- image_parts = []
- for call, res in zip(tool_calls, tool_results, strict=False):
- # Use structuredContent.result if available, otherwise use content
- text_parts = []
- items = res.content
- if not res.content and res.structuredContent:
- items = [res.structuredContent.get("result", res.content)]
-
- for item in items:
- if isinstance(item, dict):
- if item.get("type") == "text":
- text_parts.append(item.get("text", ""))
- elif item.get("type") == "image":
- mime_type = item.get("mimeType", "image/png")
- data = item.get("data", "")
- image_parts.append(
- {
- "type": "image_url",
- "image_url": {"url": f"data:{mime_type};base64,{data}"},
- }
- )
- elif isinstance(item, types.TextContent):
- text_parts.append(item.text)
- elif isinstance(item, types.ImageContent):
- image_parts.append(
- {
- "type": "image_url",
- "image_url": {"url": f"data:{item.mimeType};base64,{item.data}"},
- }
- )
-
- text_content = "".join(text_parts) if text_parts else "Tool executed successfully"
- rendered.append(
- {
- "role": "tool",
- "tool_call_id": call.id,
- "content": text_content,
- }
- )
-
- # If there are images, add them as a separate user message
- if image_parts:
- # Add a user message with the images
- content_with_images = [
- {"type": "text", "text": "Tool returned the following:"},
- image_parts[-1],
- ]
- rendered.append(
- {
- "role": "user",
- "content": content_with_images,
- }
- )
-
- return rendered
diff --git a/hud/agents/openai_compatible/tools/__init__.py b/hud/agents/openai_compatible/tools/__init__.py
index 94f800b76..1c408f184 100644
--- a/hud/agents/openai_compatible/tools/__init__.py
+++ b/hud/agents/openai_compatible/tools/__init__.py
@@ -2,31 +2,33 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, ClassVar
-from hud.agents.tools import AgentTool, AgentToolRegistry
+from hud.agents.tools import AgentTool, AgentTools
-from .computer import (
- GLM_COMPUTER_SPEC,
- QWEN_COMPUTER_SPEC,
- GLMComputerTool,
- QwenComputerTool,
+from .base import (
+ OpenAICompatibleFunctionTool,
+ OpenAICompatibleToolParam,
)
from .filesystem import (
- FilesystemTool,
GlobTool,
GrepTool,
ListTool,
ReadTool,
)
-from .types import OpenAICompatibleToolParam
+from .glm_computer import GLMComputerTool
+from .qwen_computer import QwenComputerTool
+if TYPE_CHECKING:
+ from collections.abc import Mapping
-@dataclass(frozen=True)
-class OpenAICompatibleToolRegistry(AgentToolRegistry[AgentTool[OpenAICompatibleToolParam]]):
- """Registry for OpenAI-compatible harness tools."""
- tool_classes: tuple[type[AgentTool[OpenAICompatibleToolParam]], ...] = (
+class OpenAICompatibleAgentTools(
+ AgentTools[AgentTool[OpenAICompatibleToolParam], OpenAICompatibleToolParam]
+):
+ """Prepared OpenAI-compatible chat tool state for a run."""
+
+ native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = (
GLMComputerTool,
QwenComputerTool,
ReadTool,
@@ -34,43 +36,19 @@ class OpenAICompatibleToolRegistry(AgentToolRegistry[AgentTool[OpenAICompatibleT
GlobTool,
ListTool,
)
- name_fallbacks: dict[str, tuple[str, ...]] = field(
- default_factory=lambda: {
- "computer": (
- "computer",
- "hud_computer",
- "openai_computer",
- "glm_computer",
- "qwen_computer",
- ),
- "filesystem": ("read", "grep", "glob", "list"),
- }
- )
-
- @property
- def api_types(self) -> frozenset[str]:
- api_types: set[str] = set()
- for cls in self.tool_classes:
- spec = cls.default_spec("unknown")
- if spec is not None and spec.api_type != "function":
- api_types.add(spec.api_type)
- api_types.update(getattr(cls, "ignored_api_types", frozenset()))
- return frozenset(api_types)
-
+ function_tool_class = OpenAICompatibleFunctionTool
+ name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {
+ "computer": (
+ "computer",
+ "hud_computer",
+ "openai_computer",
+ "glm_computer",
+ "qwen_computer",
+ ),
+ "filesystem": ("read", "grep", "glob", "list"),
+ }
-openai_compatible_tools = OpenAICompatibleToolRegistry()
__all__ = [
- "GLM_COMPUTER_SPEC",
- "QWEN_COMPUTER_SPEC",
- "FilesystemTool",
- "GLMComputerTool",
- "GlobTool",
- "GrepTool",
- "ListTool",
- "OpenAICompatibleToolParam",
- "OpenAICompatibleToolRegistry",
- "QwenComputerTool",
- "ReadTool",
- "openai_compatible_tools",
+ "OpenAICompatibleAgentTools",
]
diff --git a/hud/agents/openai_compatible/tools/base.py b/hud/agents/openai_compatible/tools/base.py
new file mode 100644
index 000000000..2d11866be
--- /dev/null
+++ b/hud/agents/openai_compatible/tools/base.py
@@ -0,0 +1,180 @@
+"""OpenAI-compatible agent-owned tool setup."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, TypeAlias, cast
+
+import mcp.types as mcp_types
+
+from hud.agents.tools import AgentTool, AgentToolSpec
+
+if TYPE_CHECKING:
+ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
+
+ from hud.types import MCPToolCall, MCPToolResult
+
+ from .qwen_computer import QwenComputerUseToolParam
+
+OpenAICompatibleToolParam: TypeAlias = "ChatCompletionToolParam | QwenComputerUseToolParam"
+
+
+class OpenAICompatibleTool(AgentTool[OpenAICompatibleToolParam]):
+ """Agent-side OpenAI-compatible tool backed by an environment tool."""
+
+ def format_result(
+ self, call: MCPToolCall, result: MCPToolResult
+ ) -> ChatCompletionMessageParam | list[ChatCompletionMessageParam]:
+ text_parts: list[str] = []
+ image_parts: list[dict[str, Any]] = []
+ items: list[Any] = list(result.content)
+ if not result.content and result.structuredContent:
+ items = [result.structuredContent.get("result", result.content)]
+
+ for item in items:
+ if isinstance(item, dict):
+ item_dict = cast("dict[str, Any]", item)
+ if item_dict.get("type") == "text":
+ text_parts.append(str(item_dict.get("text", "")))
+ elif item_dict.get("type") == "image":
+ mime_type = str(item_dict.get("mimeType", "image/png"))
+ data = str(item_dict.get("data", ""))
+ image_parts.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:{mime_type};base64,{data}"},
+ }
+ )
+ elif isinstance(item, mcp_types.TextContent):
+ text_parts.append(item.text)
+ elif isinstance(item, mcp_types.ImageContent):
+ image_parts.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:{item.mimeType};base64,{item.data}"},
+ }
+ )
+
+ tool_message = cast(
+ "ChatCompletionMessageParam",
+ {
+ "role": "tool",
+ "tool_call_id": call.id,
+ "content": "".join(text_parts) if text_parts else "Tool executed successfully",
+ },
+ )
+ if not image_parts:
+ return tool_message
+ return [
+ tool_message,
+ cast(
+ "ChatCompletionMessageParam",
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Tool returned the following:"},
+ image_parts[-1],
+ ],
+ },
+ ),
+ ]
+
+
+class OpenAICompatibleFunctionTool(OpenAICompatibleTool):
+ """Regular environment tool exposed as an OpenAI-compatible function."""
+
+ name = "function"
+ capability = "function"
+
+ def __init__(self, *, env_tool_name: str, params: OpenAICompatibleToolParam) -> None:
+ super().__init__(
+ env_tool_name=env_tool_name,
+ spec=AgentToolSpec(api_type="function", api_name=env_tool_name),
+ )
+ self.params = params
+
+ @classmethod
+ def from_tool(cls, tool: mcp_types.Tool) -> OpenAICompatibleFunctionTool:
+ return cls(env_tool_name=tool.name, params=openai_compatible_tool_param(tool))
+
+ @property
+ def provider_name(self) -> str:
+ return self.env_tool_name
+
+ def to_params(self) -> OpenAICompatibleToolParam:
+ return self.params
+
+
+def openai_compatible_tool_param(tool: mcp_types.Tool) -> OpenAICompatibleToolParam:
+ parameters = tool.inputSchema
+ sanitized_params: dict[str, Any] = (
+ _sanitize_schema_for_openai(parameters)
+ if parameters
+ else {"type": "object", "properties": {}}
+ )
+
+ return cast(
+ "OpenAICompatibleToolParam",
+ {
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ "description": tool.description or f"Call {tool.name}",
+ "parameters": sanitized_params,
+ },
+ },
+ )
+
+
+def _sanitize_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]:
+ """Convert MCP JSON Schema to OpenAI-compatible format."""
+ sanitized: dict[str, Any] = {}
+
+ for key, value in schema.items():
+ if key == "anyOf" and isinstance(value, list):
+ any_of_items = cast("list[Any]", value)
+ non_null_types: list[dict[str, Any]] = [
+ cast("dict[str, Any]", item)
+ for item in any_of_items
+ if isinstance(item, dict) and cast("dict[str, Any]", item).get("type") != "null"
+ ]
+ if non_null_types:
+ sanitized.update(_sanitize_schema_for_openai(non_null_types[0]))
+ else:
+ sanitized["type"] = "string"
+
+ elif key == "prefixItems" and isinstance(value, list):
+ sanitized["type"] = "array"
+ prefix_items = cast("list[Any]", value)
+ if prefix_items:
+ first_item: Any = prefix_items[0]
+ if isinstance(first_item, dict):
+ first_schema = cast("dict[str, Any]", first_item)
+ sanitized["items"] = {"type": first_schema.get("type", "string")}
+ else:
+ sanitized["items"] = {"type": "string"}
+
+ elif key == "properties" and isinstance(value, dict):
+ properties = cast("dict[str, Any]", value)
+ sanitized[key] = {
+ prop_name: _sanitize_schema_for_openai(cast("dict[str, Any]", prop_schema))
+ for prop_name, prop_schema in properties.items()
+ if isinstance(prop_schema, dict)
+ }
+
+ elif key == "items" and isinstance(value, dict):
+ sanitized[key] = _sanitize_schema_for_openai(cast("dict[str, Any]", value))
+
+ elif key in (
+ "type",
+ "description",
+ "enum",
+ "required",
+ "default",
+ "minimum",
+ "maximum",
+ "minItems",
+ "maxItems",
+ ):
+ sanitized[key] = value
+
+ return sanitized or {"type": "object"}
diff --git a/hud/agents/openai_compatible/tools/computer.py b/hud/agents/openai_compatible/tools/computer.py
deleted file mode 100644
index d7e450c89..000000000
--- a/hud/agents/openai_compatible/tools/computer.py
+++ /dev/null
@@ -1,566 +0,0 @@
-"""Agent-side OpenAI-compatible computer tools."""
-
-from __future__ import annotations
-
-import logging
-import re
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args
-
-from mcp.types import ImageContent, TextContent
-
-from hud.agents.tools import AgentTool, AgentToolSpec, CallTool, call_tool
-from hud.tools.computer import computer_settings
-from hud.types import MCPToolResult
-
-from .types import OpenAICompatibleToolParam, QwenComputerUseToolParam
-
-if TYPE_CHECKING:
- from openai.types.chat import ChatCompletionToolParam
- from openai.types.shared_params.function_parameters import FunctionParameters
-
- from hud.agents.tools import EnvironmentCapability
-
-logger = logging.getLogger(__name__)
-
-GLM_COORDINATE_SPACE = 999
-
-GLMAction = Literal[
- "left_click",
- "click",
- "right_click",
- "middle_click",
- "hover",
- "left_double_click",
- "left_drag",
- "key",
- "type",
- "scroll",
- "screenshot",
- "WAIT",
-]
-
-VALID_GLM_ACTIONS: set[str] = set(get_args(GLMAction))
-
-GLM_COMPUTER_SPEC = AgentToolSpec(
- api_type="function",
- api_name="computer",
- supported_models=("glm-*",),
-)
-
-QWEN_COMPUTER_SPEC = AgentToolSpec(
- api_type="computer_use",
- api_name="computer_use",
- supported_models=("qwen*",),
-)
-
-GLM_SYSTEM_INSTRUCTIONS = (
- "You are a GUI Agent. Your task is to respond accurately to user requests by using "
- "tools or performing GUI operations until the task is fulfilled. Coordinates are in "
- "thousandths (0-999). Complete tasks autonomously without asking for confirmation. "
- "If a task cannot be completed, explain the failure in your final response."
-)
-
-GLM_COMPUTER_DESCRIPTION = """\
-Use this tool to interact with the computer via GLM's PC action space.
-* Coordinates use a 0-999 normalized scale (thousandths of screen dimensions).
-* Always use valid JSON for function arguments. Do NOT use XML tags.
- Correct: {"action": "left_click", "start_box": "[500, 300]"}
- Wrong: {"action": "left_clickstart_box..."}
-* Available actions:
- - left_click/right_click/middle_click(start_box='[x,y]')
- - hover(start_box='[x,y]'), left_double_click(start_box='[x,y]')
- - left_drag(start_box='[x,y]', end_box='[x,y]')
- - key(keys='ctrl+c'), type(content='text')
- - scroll(start_box='[x,y]', direction='up|down', step=5)
- - screenshot(), WAIT()
-* If a task cannot be completed, explain the failure in your final response.\
-""".strip()
-
-GLM_COMPUTER_PARAMETERS: FunctionParameters = {
- "type": "object",
- "properties": {
- "action": {
- "type": "string",
- "description": (
- "REQUIRED. Action to perform: left_click, right_click, middle_click, "
- "hover, left_double_click, left_drag, key, type, scroll, screenshot, "
- "WAIT"
- ),
- "enum": sorted(VALID_GLM_ACTIONS),
- },
- "start_box": {
- "description": (
- "Position as '[x,y]' string or [x,y] array, coordinates 0-999 normalized"
- ),
- },
- "end_box": {
- "description": "End position for drag as '[x,y]' string or [x,y] array",
- },
- "content": {"type": "string", "description": "Text content to type"},
- "keys": {"description": "Key(s) to press, e.g. 'enter', 'ctrl+c', 'alt+tab'"},
- "direction": {"type": "string", "description": "Scroll direction: 'up' or 'down'"},
- "step": {"type": "integer", "description": "Scroll steps", "default": 5},
- "element_info": {"type": "string", "description": "Optional UI element description"},
- },
- "required": ["action"],
-}
-
-
-class GLMComputerTool(AgentTool[OpenAICompatibleToolParam]):
- """Translate GLM native GUI calls into generic environment computer calls."""
-
- name = "computer"
- capability = "computer"
- ignored_api_types: ClassVar[frozenset[str]] = frozenset({"gui_agent_glm45v"})
-
- @classmethod
- def default_spec(cls, model: str) -> AgentToolSpec | None:
- if GLM_COMPUTER_SPEC.supports_model(model):
- return GLM_COMPUTER_SPEC
- return None
-
- def __init__(
- self,
- *,
- env_tool_name: str,
- spec: AgentToolSpec,
- display_width: int,
- display_height: int,
- ) -> None:
- super().__init__(env_tool_name=env_tool_name, spec=spec)
- self.display_width = display_width
- self.display_height = display_height
-
- @classmethod
- def from_capability(
- cls,
- capability: EnvironmentCapability,
- spec: AgentToolSpec,
- model: str,
- ) -> GLMComputerTool:
- del model
- width, height = _resolution_from_capability(
- capability,
- default_width=computer_settings.GLM_COMPUTER_WIDTH,
- default_height=computer_settings.GLM_COMPUTER_HEIGHT,
- )
- return cls(
- env_tool_name=capability.tool_name,
- spec=spec,
- display_width=width,
- display_height=height,
- )
-
- def to_params(self) -> ChatCompletionToolParam:
- return {
- "type": "function",
- "function": {
- "name": self.name,
- "description": (
- f"{GLM_COMPUTER_DESCRIPTION}\n* The screen's resolution is "
- f"{self.display_width}x{self.display_height}."
- ),
- "parameters": GLM_COMPUTER_PARAMETERS,
- },
- }
-
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- arguments = _fix_glm_xml_args(arguments)
- action = arguments.get("action")
- if not isinstance(action, str):
- return _error_result("'action' is required")
-
- result = MCPToolResult(content=[], isError=False)
- for call in self._env_calls(action, arguments):
- result = await call_tool(caller, self.env_tool_name, call)
- if result.isError:
- return result
-
- if action not in {"screenshot", "WAIT"} and not _has_image(result):
- screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"})
- if not screenshot.isError and screenshot.content:
- result = MCPToolResult(
- content=[*result.content, *screenshot.content],
- isError=result.isError,
- )
- return result
-
- def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
- start = _parse_glm_box(arguments.get("start_box"))
- end = _parse_glm_box(arguments.get("end_box"))
-
- if action == "screenshot":
- return [{"action": "screenshot"}]
- if action == "WAIT":
- return [{"action": "wait", "time": 5000}]
- if action in ("left_click", "click", "right_click", "middle_click"):
- x, y = self._point(start, f"start_box required for {action}")
- button = {
- "left_click": "left",
- "click": "left",
- "right_click": "right",
- "middle_click": "middle",
- }[action]
- return [{"action": "click", "x": x, "y": y, "button": button}]
- if action == "hover":
- x, y = self._point(start, "start_box required for hover")
- return [{"action": "move", "x": x, "y": y}]
- if action == "left_double_click":
- x, y = self._point(start, "start_box required for left_double_click")
- return [{"action": "click", "x": x, "y": y, "button": "left", "pattern": [100]}]
- if action == "left_drag":
- start_x, start_y = self._point(start, "start_box required for left_drag")
- end_x, end_y = self._point(end, "end_box required for left_drag")
- return [
- {
- "action": "drag",
- "path": [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}],
- }
- ]
- if action == "key":
- keys = _parse_glm_keys(arguments.get("keys"))
- if not keys:
- raise ValueError("keys required for key action")
- return [{"action": "press", "keys": keys}]
- if action == "type":
- content = arguments.get("content")
- if not isinstance(content, str) or not content:
- raise ValueError("content required for type")
- return [{"action": "write", "text": content, "enter_after": False}]
- if action == "scroll":
- direction = arguments.get("direction")
- if direction not in {"up", "down"}:
- raise ValueError("direction must be 'up' or 'down'")
- point = start or (GLM_COORDINATE_SPACE // 2, GLM_COORDINATE_SPACE // 2)
- x, y = self._scale_normalized_point(point)
- step = arguments.get("step") or 5
- scroll_y = int(step) * 100 if direction == "down" else -int(step) * 100
- return [{"action": "scroll", "x": x, "y": y, "scroll_y": scroll_y}]
- raise ValueError(f"Unknown action: {action}")
-
- def _point(self, point: tuple[int, int] | None, message: str) -> tuple[int, int]:
- if point is None:
- raise ValueError(message)
- return self._scale_normalized_point(point)
-
- def _scale_normalized_point(self, point: tuple[int, int]) -> tuple[int, int]:
- x, y = point
- scaled_x = round(x / GLM_COORDINATE_SPACE * (self.display_width - 1))
- scaled_y = round(y / GLM_COORDINATE_SPACE * (self.display_height - 1))
- return scaled_x, scaled_y
-
-
-class QwenComputerTool(AgentTool[OpenAICompatibleToolParam]):
- """Translate Qwen computer_use calls into generic environment computer calls."""
-
- name = "computer_use"
- capability = "computer"
-
- @classmethod
- def default_spec(cls, model: str) -> AgentToolSpec | None:
- if QWEN_COMPUTER_SPEC.supports_model(model):
- return QWEN_COMPUTER_SPEC
- return None
-
- def __init__(
- self,
- *,
- env_tool_name: str,
- spec: AgentToolSpec,
- display_width: int,
- display_height: int,
- description: str,
- ) -> None:
- super().__init__(env_tool_name=env_tool_name, spec=spec)
- self.display_width = display_width
- self.display_height = display_height
- self.description = description
-
- @classmethod
- def from_capability(
- cls,
- capability: EnvironmentCapability,
- spec: AgentToolSpec,
- model: str,
- ) -> QwenComputerTool:
- del model
- width, height = _resolution_from_capability(
- capability,
- default_width=computer_settings.QWEN_COMPUTER_WIDTH,
- default_height=computer_settings.QWEN_COMPUTER_HEIGHT,
- )
- return cls(
- env_tool_name=capability.tool_name,
- spec=spec,
- display_width=width,
- display_height=height,
- description=_qwen_description(width, height),
- )
-
- def to_params(self) -> QwenComputerUseToolParam:
- tool: QwenComputerUseToolParam = {
- "type": "computer_use",
- "name": self.name,
- "display_width_px": self.display_width,
- "display_height_px": self.display_height,
- "description": self.description,
- "parameters": QWEN_COMPUTER_PARAMETERS,
- }
- return tool
-
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- action = arguments.get("action")
- if not isinstance(action, str):
- return _error_result("action is required")
- if action == "terminate":
- return _error_result("terminate action is not supported for computer control.")
- if action == "answer":
- return _error_result("answer action is not supported for computer control.")
-
- result = MCPToolResult(content=[], isError=False)
- for call in self._env_calls(action, arguments):
- result = await call_tool(caller, self.env_tool_name, call)
- if result.isError:
- return result
-
- if action not in {"screenshot", "wait"} and not _has_image(result):
- screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"})
- if not screenshot.isError and screenshot.content:
- result = MCPToolResult(
- content=[*result.content, *screenshot.content],
- isError=result.isError,
- )
- return result
-
- def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
- coordinate = _parse_qwen_coordinate(arguments.get("coordinate"))
- if action == "screenshot":
- return [{"action": "screenshot"}]
- if action in {"left_click", "right_click", "middle_click"}:
- x, y = _required_coordinate(coordinate, action)
- button = {"left_click": "left", "right_click": "right", "middle_click": "middle"}[
- action
- ]
- return [{"action": "click", "x": x, "y": y, "button": button}]
- if action == "double_click":
- x, y = _required_coordinate(coordinate, action)
- return [{"action": "click", "x": x, "y": y, "pattern": [100]}]
- if action == "triple_click":
- x, y = _required_coordinate(coordinate, action)
- return [{"action": "click", "x": x, "y": y, "pattern": [100, 100]}]
- if action == "mouse_move":
- x, y = _required_coordinate(coordinate, action)
- return [{"action": "move", "x": x, "y": y}]
- if action == "type":
- text = arguments.get("text")
- if not isinstance(text, str):
- raise ValueError("text is required for type")
- return [{"action": "write", "text": text}]
- if action == "key":
- keys = arguments.get("keys")
- if not isinstance(keys, list):
- raise ValueError("keys is required for key")
- return [{"action": "press", "keys": keys}]
- if action in {"scroll", "hscroll"}:
- pixels = arguments.get("pixels")
- if not isinstance(pixels, int | float):
- raise ValueError("pixels is required for scroll")
- call: dict[str, Any] = {"action": "scroll"}
- if coordinate is not None:
- call.update({"x": coordinate[0], "y": coordinate[1]})
- if action == "scroll":
- call["scroll_y"] = -int(pixels)
- else:
- call["scroll_x"] = int(pixels)
- return [call]
- if action == "left_click_drag":
- x, y = _required_coordinate(coordinate, action)
- return [
- {"action": "mouse_down", "button": "left"},
- {"action": "move", "x": x, "y": y},
- {"action": "mouse_up", "button": "left"},
- ]
- if action == "wait":
- time = arguments.get("time")
- if not isinstance(time, int | float):
- raise ValueError("time is required for wait")
- if time < 0:
- raise ValueError("time must be non-negative")
- return [{"action": "wait", "time": int(time * 1000)}]
- raise ValueError(f"Invalid action: {action}")
-
-
-QWEN_COMPUTER_PARAMETERS: FunctionParameters = {
- "properties": {
- "action": {
- "description": """
-The action to perform. The available actions are:
-* `key`: Performs key down presses on the arguments passed in order, then performs
-key releases in reverse order.
-* `type`: Type a string of text on the keyboard.
-* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.
-* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate.
-* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate.
-* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate.
-* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate.
-* `double_click`: Double-click the left mouse button.
-* `triple_click`: Triple-click the left mouse button.
-* `scroll`: Performs a vertical scroll.
-* `hscroll`: Performs a horizontal scroll.
-* `wait`: Wait specified seconds for the change to happen.
-""".strip(),
- "enum": [
- "key",
- "type",
- "mouse_move",
- "left_click",
- "left_click_drag",
- "right_click",
- "middle_click",
- "double_click",
- "triple_click",
- "scroll",
- "hscroll",
- "wait",
- ],
- "type": "string",
- },
- "keys": {"description": "Required only by `action=key`.", "type": "array"},
- "text": {
- "description": "Required only by `action=type`.",
- "type": "string",
- },
- "coordinate": {
- "description": "(x, y) pixel coordinate to interact with.",
- "type": "array",
- },
- "pixels": {
- "description": "Scroll amount. Positive vertical values scroll up.",
- "type": "number",
- },
- "time": {
- "description": "Seconds to wait. Required only by `action=wait`.",
- "type": "number",
- },
- },
- "required": ["action"],
- "type": "object",
-}
-
-
-def _resolution_from_capability(
- capability: EnvironmentCapability,
- *,
- default_width: int,
- default_height: int,
-) -> tuple[int, int]:
- metadata_resolution = capability.metadata.get("resolution", {})
- if not isinstance(metadata_resolution, dict):
- metadata_resolution = {}
- tool_resolution = (capability.tool.meta or {}).get("resolution", {})
- if not isinstance(tool_resolution, dict):
- tool_resolution = {}
- width = int(metadata_resolution.get("width") or tool_resolution.get("width") or default_width)
- height = int(
- metadata_resolution.get("height") or tool_resolution.get("height") or default_height
- )
- return width, height
-
-
-def _qwen_description(width: int, height: int) -> str:
- return f"""
-Use a mouse and keyboard to interact with a computer, and take screenshots.
-* This is an interface to a desktop GUI. You do not have access to a terminal or
-applications menu. You must click on desktop icons to start applications.
-* Some applications may take time to start or process actions, so you may need to
-wait and take successive screenshots to see the results of your actions.
-* The screen's resolution is {width}x{height}.
-* Whenever you intend to move the cursor to click on an element like an icon, you
-should consult a screenshot to determine the coordinates of the element before
-moving the cursor.
-* Make sure to click buttons, links, and icons with the cursor tip in the center.
-""".strip()
-
-
-def _parse_glm_box(box: Any) -> tuple[int, int] | None:
- if box is None:
- return None
- if isinstance(box, str):
- match = re.match(r"\[?\s*(\d+)\s*,\s*(\d+)\s*\]?", box.strip())
- if match:
- return int(match.group(1)), int(match.group(2))
- return None
- if isinstance(box, list):
- if len(box) == 1 and isinstance(box[0], list):
- box = box[0]
- if len(box) >= 2:
- try:
- return int(box[0]), int(box[1])
- except (TypeError, ValueError):
- return None
- return None
-
-
-def _parse_glm_keys(keys: Any) -> list[str]:
- if not keys:
- return []
- if isinstance(keys, list):
- return [str(key).strip().lower() for key in keys]
- return [key.strip().lower() for key in str(keys).split("+") if key.strip()]
-
-
-def _fix_glm_xml_args(args: dict[str, Any]) -> dict[str, Any]:
- fixed: dict[str, Any] = {}
- for key, value in args.items():
- if not isinstance(value, str) or not re.search(r"?arg_", value):
- fixed[key] = value
- continue
-
- main_value = re.split(r"?arg_", value, maxsplit=1)[0].strip()
- if main_value:
- fixed[key] = main_value
-
- matches = re.findall(r"(\w+)\s*([^\"<]+)", value)
- for arg_name, arg_val in matches:
- if arg_name and arg_val:
- fixed[arg_name.strip()] = arg_val.strip()
-
- if not main_value and not matches:
- fixed[key] = value
- logger.warning("Fixed GLM XML args: %s -> %s", args, fixed)
- return fixed
-
-
-def _parse_qwen_coordinate(coordinate: Any) -> tuple[int, int] | None:
- if isinstance(coordinate, list | tuple) and len(coordinate) >= 2:
- try:
- return int(coordinate[0]), int(coordinate[1])
- except (TypeError, ValueError):
- return None
- return None
-
-
-def _required_coordinate(coordinate: tuple[int, int] | None, action: str) -> tuple[int, int]:
- if coordinate is None:
- raise ValueError(f"coordinate is required for {action}")
- return coordinate
-
-
-def _has_image(result: MCPToolResult) -> bool:
- return any(isinstance(block, ImageContent) for block in result.content)
-
-
-def _error_result(message: str) -> MCPToolResult:
- return MCPToolResult(content=[TextContent(type="text", text=message)], isError=True)
-
-
-__all__ = [
- "GLM_COMPUTER_SPEC",
- "GLM_COORDINATE_SPACE",
- "QWEN_COMPUTER_SPEC",
- "VALID_GLM_ACTIONS",
- "GLMComputerTool",
- "QwenComputerTool",
- "_fix_glm_xml_args",
- "_parse_glm_box",
-]
diff --git a/hud/agents/openai_compatible/tools/filesystem.py b/hud/agents/openai_compatible/tools/filesystem.py
index 4f5ba57f2..a09ed988c 100644
--- a/hud/agents/openai_compatible/tools/filesystem.py
+++ b/hud/agents/openai_compatible/tools/filesystem.py
@@ -4,84 +4,16 @@
from typing import TYPE_CHECKING, ClassVar
-from hud.agents.tools import AgentTool, AgentToolSpec, GroupedCapabilityMixin
+from hud.agents.tools import AgentToolSpec, GroupedCapabilityMixin
-from .types import OpenAICompatibleToolParam
+from .base import OpenAICompatibleTool
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_parameters import FunctionParameters
-READ_PARAMETERS: FunctionParameters = {
- "type": "object",
- "properties": {
- "filePath": {
- "type": "string",
- "description": "Absolute path to the file to read.",
- },
- "offset": {
- "type": "integer",
- "description": "0-based line offset to start reading from.",
- },
- "limit": {
- "type": "integer",
- "description": "Maximum number of lines to read.",
- },
- },
- "required": ["filePath"],
-}
-
-GREP_PARAMETERS: FunctionParameters = {
- "type": "object",
- "properties": {
- "pattern": {
- "type": "string",
- "description": "Regular expression pattern to search for.",
- },
- "path": {
- "type": "string",
- "description": "Directory to search in.",
- },
- "include": {
- "type": "string",
- "description": "Glob pattern for files to include.",
- },
- },
- "required": ["pattern"],
-}
-
-GLOB_PARAMETERS: FunctionParameters = {
- "type": "object",
- "properties": {
- "pattern": {
- "type": "string",
- "description": "Glob pattern to match.",
- },
- "path": {
- "type": "string",
- "description": "Directory to search from.",
- },
- },
- "required": ["pattern"],
-}
-
-LIST_PARAMETERS: FunctionParameters = {
- "type": "object",
- "properties": {
- "path": {
- "type": "string",
- "description": "Directory to list.",
- },
- "ignore": {
- "type": "array",
- "items": {"type": "string"},
- "description": "Glob patterns to ignore.",
- },
- },
-}
-
-class FilesystemTool(GroupedCapabilityMixin, AgentTool[OpenAICompatibleToolParam]):
+class _FilesystemTool(GroupedCapabilityMixin, OpenAICompatibleTool):
"""Function tool backed by a HUD filesystem environment tool."""
description: ClassVar[str]
@@ -104,54 +36,101 @@ def to_params(self) -> ChatCompletionToolParam:
}
-class ReadTool(FilesystemTool):
+class ReadTool(_FilesystemTool):
"""Expose a read function over the environment read tool."""
name = "read"
capability = "filesystem"
env_tool_names = ("read",)
description = "Reads a file from the local filesystem. Use offset and limit for pagination."
- parameters: ClassVar[FunctionParameters] = READ_PARAMETERS
+ parameters: ClassVar[FunctionParameters] = {
+ "type": "object",
+ "properties": {
+ "filePath": {
+ "type": "string",
+ "description": "Absolute path to the file to read.",
+ },
+ "offset": {
+ "type": "integer",
+ "description": "0-based line offset to start reading from.",
+ },
+ "limit": {
+ "type": "integer",
+ "description": "Maximum number of lines to read.",
+ },
+ },
+ "required": ["filePath"],
+ }
-class GrepTool(FilesystemTool):
+class GrepTool(_FilesystemTool):
"""Expose a grep function over the environment grep tool."""
name = "grep"
capability = "filesystem"
env_tool_names = ("grep",)
description = "Searches file contents using a regular expression and returns matching lines."
- parameters: ClassVar[FunctionParameters] = GREP_PARAMETERS
+ parameters: ClassVar[FunctionParameters] = {
+ "type": "object",
+ "properties": {
+ "pattern": {
+ "type": "string",
+ "description": "Regular expression pattern to search for.",
+ },
+ "path": {
+ "type": "string",
+ "description": "Directory to search in.",
+ },
+ "include": {
+ "type": "string",
+ "description": "Glob pattern for files to include.",
+ },
+ },
+ "required": ["pattern"],
+ }
-class GlobTool(FilesystemTool):
+class GlobTool(_FilesystemTool):
"""Expose a glob function over the environment glob tool."""
name = "glob"
capability = "filesystem"
env_tool_names = ("glob",)
description = "Finds files matching a glob pattern."
- parameters: ClassVar[FunctionParameters] = GLOB_PARAMETERS
+ parameters: ClassVar[FunctionParameters] = {
+ "type": "object",
+ "properties": {
+ "pattern": {
+ "type": "string",
+ "description": "Glob pattern to match.",
+ },
+ "path": {
+ "type": "string",
+ "description": "Directory to search from.",
+ },
+ },
+ "required": ["pattern"],
+ }
-class ListTool(FilesystemTool):
+class ListTool(_FilesystemTool):
"""Expose a list function over the environment list tool."""
name = "list"
capability = "filesystem"
env_tool_names = ("list",)
description = "Lists files and directories in a given path."
- parameters: ClassVar[FunctionParameters] = LIST_PARAMETERS
-
-
-__all__ = [
- "GLOB_PARAMETERS",
- "GREP_PARAMETERS",
- "LIST_PARAMETERS",
- "READ_PARAMETERS",
- "FilesystemTool",
- "GlobTool",
- "GrepTool",
- "ListTool",
- "ReadTool",
-]
+ parameters: ClassVar[FunctionParameters] = {
+ "type": "object",
+ "properties": {
+ "path": {
+ "type": "string",
+ "description": "Directory to list.",
+ },
+ "ignore": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Glob patterns to ignore.",
+ },
+ },
+ }
diff --git a/hud/agents/openai_compatible/tools/glm_computer.py b/hud/agents/openai_compatible/tools/glm_computer.py
new file mode 100644
index 000000000..463860a19
--- /dev/null
+++ b/hud/agents/openai_compatible/tools/glm_computer.py
@@ -0,0 +1,294 @@
+"""Agent-side GLM computer tool for OpenAI-compatible chat models."""
+
+from __future__ import annotations
+
+import logging
+import re
+from typing import TYPE_CHECKING, Any, Literal, cast, get_args
+
+from hud.agents.tools import AgentToolSpec
+from hud.agents.tools.computer import (
+ computer_error_result,
+ computer_tool_info,
+ execute_computer_calls,
+)
+
+from .base import OpenAICompatibleTool
+from .settings import openai_compatible_tool_settings
+
+if TYPE_CHECKING:
+ from openai.types.chat import ChatCompletionToolParam
+ from openai.types.shared_params.function_parameters import FunctionParameters
+
+ from hud.agents.tools import EnvironmentCapability
+ from hud.agents.tools.base import CallTool
+ from hud.types import MCPToolResult
+
+logger = logging.getLogger(__name__)
+
+GLM_COORDINATE_SPACE = 999
+
+GLMAction = Literal[
+ "left_click",
+ "click",
+ "right_click",
+ "middle_click",
+ "hover",
+ "left_double_click",
+ "left_drag",
+ "key",
+ "type",
+ "scroll",
+ "screenshot",
+ "WAIT",
+]
+
+VALID_GLM_ACTIONS: set[str] = set(get_args(GLMAction))
+
+GLM_COMPUTER_SPEC = AgentToolSpec(
+ api_type="function",
+ api_name="computer",
+ supported_models=("glm-*",),
+)
+
+GLM_SYSTEM_INSTRUCTIONS = (
+ "You are a GUI Agent. Your task is to respond accurately to user requests by using "
+ "tools or performing GUI operations until the task is fulfilled. Coordinates are in "
+ "thousandths (0-999). Complete tasks autonomously without asking for confirmation. "
+ "If a task cannot be completed, explain the failure in your final response."
+)
+
+GLM_COMPUTER_DESCRIPTION = """\
+Use this tool to interact with the computer via GLM's PC action space.
+* Coordinates use a 0-999 normalized scale (thousandths of screen dimensions).
+* Always use valid JSON for function arguments. Do NOT use XML tags.
+ Correct: {"action": "left_click", "start_box": "[500, 300]"}
+ Wrong: {"action": "left_clickstart_box..."}
+* Available actions:
+ - left_click/right_click/middle_click(start_box='[x,y]')
+ - hover(start_box='[x,y]'), left_double_click(start_box='[x,y]')
+ - left_drag(start_box='[x,y]', end_box='[x,y]')
+ - key(keys='ctrl+c'), type(content='text')
+ - scroll(start_box='[x,y]', direction='up|down', step=5)
+ - screenshot(), WAIT()
+* If a task cannot be completed, explain the failure in your final response.\
+""".strip()
+
+GLM_COMPUTER_PARAMETERS: FunctionParameters = {
+ "type": "object",
+ "properties": {
+ "action": {
+ "type": "string",
+ "description": (
+ "REQUIRED. Action to perform: left_click, right_click, middle_click, "
+ "hover, left_double_click, left_drag, key, type, scroll, screenshot, "
+ "WAIT"
+ ),
+ "enum": sorted(VALID_GLM_ACTIONS),
+ },
+ "start_box": {
+ "description": (
+ "Position as '[x,y]' string or [x,y] array, coordinates 0-999 normalized"
+ ),
+ },
+ "end_box": {
+ "description": "End position for drag as '[x,y]' string or [x,y] array",
+ },
+ "content": {"type": "string", "description": "Text content to type"},
+ "keys": {"description": "Key(s) to press, e.g. 'enter', 'ctrl+c', 'alt+tab'"},
+ "direction": {"type": "string", "description": "Scroll direction: 'up' or 'down'"},
+ "step": {"type": "integer", "description": "Scroll steps", "default": 5},
+ "element_info": {"type": "string", "description": "Optional UI element description"},
+ },
+ "required": ["action"],
+}
+
+
+class GLMComputerTool(OpenAICompatibleTool):
+ """Translate GLM native GUI calls into generic environment computer calls."""
+
+ name = "computer"
+ capability = "computer"
+
+ @classmethod
+ def default_spec(cls, model: str) -> AgentToolSpec | None:
+ if GLM_COMPUTER_SPEC.supports_model(model):
+ return GLM_COMPUTER_SPEC
+ return None
+
+ def __init__(
+ self,
+ *,
+ env_tool_name: str,
+ spec: AgentToolSpec,
+ display_width: int,
+ display_height: int,
+ coordinate_space: int | None,
+ ) -> None:
+ super().__init__(env_tool_name=env_tool_name, spec=spec)
+ self.display_width = display_width
+ self.display_height = display_height
+ self.coordinate_space = coordinate_space
+
+ @classmethod
+ def from_capability(
+ cls,
+ capability: EnvironmentCapability,
+ model: str,
+ ) -> GLMComputerTool | None:
+ spec = cls.default_spec(model)
+ if spec is None:
+ return None
+
+ computer_info = computer_tool_info(
+ capability.tool,
+ default_width=openai_compatible_tool_settings.GLM_COMPUTER_WIDTH,
+ default_height=openai_compatible_tool_settings.GLM_COMPUTER_HEIGHT,
+ )
+ return cls(
+ env_tool_name=capability.tool_name,
+ spec=spec,
+ display_width=computer_info.display_width,
+ display_height=computer_info.display_height,
+ coordinate_space=computer_info.coordinate_space,
+ )
+
+ def to_params(self) -> ChatCompletionToolParam:
+ return {
+ "type": "function",
+ "function": {
+ "name": self.name,
+ "description": (
+ f"{GLM_COMPUTER_DESCRIPTION}\n* The screen's resolution is "
+ f"{self.display_width}x{self.display_height}."
+ ),
+ "parameters": GLM_COMPUTER_PARAMETERS,
+ },
+ }
+
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ arguments = _normalize_glm_args(arguments)
+ action = arguments.get("action")
+ if not isinstance(action, str):
+ return computer_error_result("'action' is required")
+
+ return await execute_computer_calls(
+ call_tool,
+ env_tool_name=self.env_tool_name,
+ calls=self._env_calls(action, arguments),
+ ensure_screenshot=action not in {"screenshot", "WAIT"},
+ )
+
+ def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
+ start = _parse_glm_box(arguments.get("start_box"))
+ end = _parse_glm_box(arguments.get("end_box"))
+
+ if action == "screenshot":
+ return [{"action": "screenshot"}]
+ if action == "WAIT":
+ return [{"action": "wait", "time": 5000}]
+ if action in ("left_click", "click", "right_click", "middle_click"):
+ x, y = self._point(start, f"start_box required for {action}")
+ button = {
+ "left_click": "left",
+ "click": "left",
+ "right_click": "right",
+ "middle_click": "middle",
+ }[action]
+ return [{"action": "click", "x": x, "y": y, "button": button}]
+ if action == "hover":
+ x, y = self._point(start, "start_box required for hover")
+ return [{"action": "move", "x": x, "y": y}]
+ if action == "left_double_click":
+ x, y = self._point(start, "start_box required for left_double_click")
+ return [{"action": "click", "x": x, "y": y, "button": "left", "pattern": [100]}]
+ if action == "left_drag":
+ start_x, start_y = self._point(start, "start_box required for left_drag")
+ end_x, end_y = self._point(end, "end_box required for left_drag")
+ return [
+ {
+ "action": "drag",
+ "path": [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}],
+ }
+ ]
+ if action == "key":
+ raw_keys = arguments.get("keys")
+ if isinstance(raw_keys, list):
+ keys = [str(key).strip().lower() for key in cast("list[Any]", raw_keys)]
+ else:
+ keys = [
+ key.strip().lower() for key in str(raw_keys or "").split("+") if key.strip()
+ ]
+ if not keys:
+ raise ValueError("keys required for key action")
+ return [{"action": "press", "keys": keys}]
+ if action == "type":
+ content = arguments.get("content")
+ if not isinstance(content, str) or not content:
+ raise ValueError("content required for type")
+ return [{"action": "write", "text": content, "enter_after": False}]
+ if action == "scroll":
+ direction = arguments.get("direction")
+ if direction not in {"up", "down"}:
+ raise ValueError("direction must be 'up' or 'down'")
+ point = start or (GLM_COORDINATE_SPACE // 2, GLM_COORDINATE_SPACE // 2)
+ x, y = self._scale_normalized_point(point)
+ step = arguments.get("step") or 5
+ scroll_y = int(step) * 100 if direction == "down" else -int(step) * 100
+ return [{"action": "scroll", "x": x, "y": y, "scroll_y": scroll_y}]
+ raise ValueError(f"Unknown action: {action}")
+
+ def _point(self, point: tuple[int, int] | None, message: str) -> tuple[int, int]:
+ if point is None:
+ raise ValueError(message)
+ return self._scale_normalized_point(point)
+
+ def _scale_normalized_point(self, point: tuple[int, int]) -> tuple[int, int]:
+ if self.coordinate_space == GLM_COORDINATE_SPACE:
+ return point
+ x, y = point
+ scaled_x = round(x / GLM_COORDINATE_SPACE * (self.display_width - 1))
+ scaled_y = round(y / GLM_COORDINATE_SPACE * (self.display_height - 1))
+ return scaled_x, scaled_y
+
+
+def _parse_glm_box(box: Any) -> tuple[int, int] | None:
+ if box is None:
+ return None
+ if isinstance(box, str):
+ match = re.match(r"\[?\s*(\d+)\s*,\s*(\d+)\s*\]?", box.strip())
+ if match:
+ return int(match.group(1)), int(match.group(2))
+ return None
+ if isinstance(box, list):
+ nested = cast("list[Any]", box)
+ if len(nested) == 1 and isinstance(nested[0], list):
+ nested = cast("list[Any]", nested[0])
+ if len(nested) >= 2:
+ try:
+ return int(nested[0]), int(nested[1])
+ except (TypeError, ValueError):
+ return None
+ return None
+
+
+def _normalize_glm_args(args: dict[str, Any]) -> dict[str, Any]:
+ fixed: dict[str, Any] = {}
+ for key, value in args.items():
+ if not isinstance(value, str) or not re.search(r"?arg_", value):
+ fixed[key] = value
+ continue
+
+ main_value = re.split(r"?arg_", value, maxsplit=1)[0].strip()
+ if main_value:
+ fixed[key] = main_value
+
+ matches = re.findall(r"(\w+)\s*([^\"<]+)", value)
+ for arg_name, arg_val in matches:
+ if arg_name and arg_val:
+ fixed[arg_name.strip()] = arg_val.strip()
+
+ if not main_value and not matches:
+ fixed[key] = value
+ logger.warning("Fixed GLM XML args: %s -> %s", args, fixed)
+ return fixed
diff --git a/hud/agents/openai_compatible/tools/qwen_computer.py b/hud/agents/openai_compatible/tools/qwen_computer.py
new file mode 100644
index 000000000..425e5f844
--- /dev/null
+++ b/hud/agents/openai_compatible/tools/qwen_computer.py
@@ -0,0 +1,266 @@
+"""Agent-side Qwen computer tool for OpenAI-compatible chat models."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
+
+from hud.agents.tools import AgentToolSpec
+from hud.agents.tools.computer import (
+ computer_error_result,
+ computer_tool_info,
+ execute_computer_calls,
+)
+
+from .base import OpenAICompatibleTool
+from .settings import openai_compatible_tool_settings
+
+if TYPE_CHECKING:
+ from openai.types.shared_params.function_parameters import FunctionParameters
+
+ from hud.agents.tools import EnvironmentCapability
+ from hud.agents.tools.base import CallTool
+ from hud.types import MCPToolResult
+
+QWEN_COMPUTER_SPEC = AgentToolSpec(
+ api_type="computer_use",
+ api_name="computer_use",
+ supported_models=("qwen*",),
+)
+
+
+class QwenComputerUseToolParam(TypedDict):
+ """Qwen's OpenAI-compatible computer_use extension."""
+
+ type: Literal["computer_use"]
+ name: str
+ display_width_px: int
+ display_height_px: int
+ description: str
+ parameters: FunctionParameters
+
+
+class QwenComputerTool(OpenAICompatibleTool):
+ """Translate Qwen computer_use calls into generic environment computer calls."""
+
+ name = "computer_use"
+ capability = "computer"
+
+ @classmethod
+ def default_spec(cls, model: str) -> AgentToolSpec | None:
+ if QWEN_COMPUTER_SPEC.supports_model(model):
+ return QWEN_COMPUTER_SPEC
+ return None
+
+ def __init__(
+ self,
+ *,
+ env_tool_name: str,
+ spec: AgentToolSpec,
+ display_width: int,
+ display_height: int,
+ description: str,
+ ) -> None:
+ super().__init__(env_tool_name=env_tool_name, spec=spec)
+ self.display_width = display_width
+ self.display_height = display_height
+ self.description = description
+
+ @classmethod
+ def from_capability(
+ cls,
+ capability: EnvironmentCapability,
+ model: str,
+ ) -> QwenComputerTool | None:
+ spec = cls.default_spec(model)
+ if spec is None:
+ return None
+
+ computer_info = computer_tool_info(
+ capability.tool,
+ default_width=openai_compatible_tool_settings.QWEN_COMPUTER_WIDTH,
+ default_height=openai_compatible_tool_settings.QWEN_COMPUTER_HEIGHT,
+ )
+ return cls(
+ env_tool_name=capability.tool_name,
+ spec=spec,
+ display_width=computer_info.display_width,
+ display_height=computer_info.display_height,
+ description=_qwen_description(
+ computer_info.display_width, computer_info.display_height
+ ),
+ )
+
+ def to_params(self) -> QwenComputerUseToolParam:
+ tool: QwenComputerUseToolParam = {
+ "type": "computer_use",
+ "name": self.name,
+ "display_width_px": self.display_width,
+ "display_height_px": self.display_height,
+ "description": self.description,
+ "parameters": QWEN_COMPUTER_PARAMETERS,
+ }
+ return tool
+
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ action = arguments.get("action")
+ if not isinstance(action, str):
+ return computer_error_result("action is required")
+ if action == "terminate":
+ return computer_error_result("terminate action is not supported for computer control.")
+ if action == "answer":
+ return computer_error_result("answer action is not supported for computer control.")
+
+ return await execute_computer_calls(
+ call_tool,
+ env_tool_name=self.env_tool_name,
+ calls=self._env_calls(action, arguments),
+ ensure_screenshot=action not in {"screenshot", "wait"},
+ )
+
+ def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
+ coordinate = _parse_qwen_coordinate(arguments.get("coordinate"))
+ if action == "screenshot":
+ return [{"action": "screenshot"}]
+ if action in {"left_click", "right_click", "middle_click"}:
+ x, y = _required_coordinate(coordinate, action)
+ button = {"left_click": "left", "right_click": "right", "middle_click": "middle"}[
+ action
+ ]
+ return [{"action": "click", "x": x, "y": y, "button": button}]
+ if action == "double_click":
+ x, y = _required_coordinate(coordinate, action)
+ return [{"action": "click", "x": x, "y": y, "pattern": [100]}]
+ if action == "triple_click":
+ x, y = _required_coordinate(coordinate, action)
+ return [{"action": "click", "x": x, "y": y, "pattern": [100, 100]}]
+ if action == "mouse_move":
+ x, y = _required_coordinate(coordinate, action)
+ return [{"action": "move", "x": x, "y": y}]
+ if action == "type":
+ text = arguments.get("text")
+ if not isinstance(text, str):
+ raise ValueError("text is required for type")
+ return [{"action": "write", "text": text}]
+ if action == "key":
+ keys = arguments.get("keys")
+ if not isinstance(keys, list):
+ raise ValueError("keys is required for key")
+ return [{"action": "press", "keys": keys}]
+ if action in {"scroll", "hscroll"}:
+ pixels = arguments.get("pixels")
+ if not isinstance(pixels, int | float):
+ raise ValueError("pixels is required for scroll")
+ call: dict[str, Any] = {"action": "scroll"}
+ if coordinate is not None:
+ call.update({"x": coordinate[0], "y": coordinate[1]})
+ if action == "scroll":
+ call["scroll_y"] = -int(pixels)
+ else:
+ call["scroll_x"] = int(pixels)
+ return [call]
+ if action == "left_click_drag":
+ x, y = _required_coordinate(coordinate, action)
+ return [
+ {"action": "mouse_down", "button": "left"},
+ {"action": "move", "x": x, "y": y},
+ {"action": "mouse_up", "button": "left"},
+ ]
+ if action == "wait":
+ time = arguments.get("time")
+ if not isinstance(time, int | float):
+ raise ValueError("time is required for wait")
+ if time < 0:
+ raise ValueError("time must be non-negative")
+ return [{"action": "wait", "time": int(time * 1000)}]
+ raise ValueError(f"Invalid action: {action}")
+
+
+QWEN_COMPUTER_PARAMETERS: FunctionParameters = {
+ "properties": {
+ "action": {
+ "description": """
+The action to perform. The available actions are:
+* `key`: Performs key down presses on the arguments passed in order, then performs
+key releases in reverse order.
+* `type`: Type a string of text on the keyboard.
+* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.
+* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate.
+* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate.
+* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate.
+* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate.
+* `double_click`: Double-click the left mouse button.
+* `triple_click`: Triple-click the left mouse button.
+* `scroll`: Performs a vertical scroll.
+* `hscroll`: Performs a horizontal scroll.
+* `wait`: Wait specified seconds for the change to happen.
+""".strip(),
+ "enum": [
+ "key",
+ "type",
+ "mouse_move",
+ "left_click",
+ "left_click_drag",
+ "right_click",
+ "middle_click",
+ "double_click",
+ "triple_click",
+ "scroll",
+ "hscroll",
+ "wait",
+ ],
+ "type": "string",
+ },
+ "keys": {"description": "Required only by `action=key`.", "type": "array"},
+ "text": {
+ "description": "Required only by `action=type`.",
+ "type": "string",
+ },
+ "coordinate": {
+ "description": "(x, y) pixel coordinate to interact with.",
+ "type": "array",
+ },
+ "pixels": {
+ "description": "Scroll amount. Positive vertical values scroll up.",
+ "type": "number",
+ },
+ "time": {
+ "description": "Seconds to wait. Required only by `action=wait`.",
+ "type": "number",
+ },
+ },
+ "required": ["action"],
+ "type": "object",
+}
+
+
+def _qwen_description(width: int, height: int) -> str:
+ return f"""
+Use a mouse and keyboard to interact with a computer, and take screenshots.
+* This is an interface to a desktop GUI. You do not have access to a terminal or
+applications menu. You must click on desktop icons to start applications.
+* Some applications may take time to start or process actions, so you may need to
+wait and take successive screenshots to see the results of your actions.
+* The screen's resolution is {width}x{height}.
+* Whenever you intend to move the cursor to click on an element like an icon, you
+should consult a screenshot to determine the coordinates of the element before
+moving the cursor.
+* Make sure to click buttons, links, and icons with the cursor tip in the center.
+""".strip()
+
+
+def _parse_qwen_coordinate(coordinate: Any) -> tuple[int, int] | None:
+ if not isinstance(coordinate, list | tuple):
+ return None
+ coord = cast("list[Any] | tuple[Any, ...]", coordinate)
+ if len(coord) < 2:
+ return None
+ try:
+ return int(coord[0]), int(coord[1])
+ except (TypeError, ValueError):
+ return None
+
+
+def _required_coordinate(coordinate: tuple[int, int] | None, action: str) -> tuple[int, int]:
+ if coordinate is None:
+ raise ValueError(f"coordinate is required for {action}")
+ return coordinate
diff --git a/hud/agents/openai_compatible/tools/settings.py b/hud/agents/openai_compatible/tools/settings.py
new file mode 100644
index 000000000..8ec3dbe71
--- /dev/null
+++ b/hud/agents/openai_compatible/tools/settings.py
@@ -0,0 +1,36 @@
+"""OpenAI-compatible native tool settings owned by the agent."""
+
+from __future__ import annotations
+
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class OpenAICompatibleToolSettings(BaseSettings):
+ """Provider defaults for OpenAI-compatible agent-owned native tools."""
+
+ model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="allow")
+
+ GLM_COMPUTER_WIDTH: int = Field(
+ default=1024,
+ description="Default GLM computer-use display width",
+ validation_alias="GLM_COMPUTER_WIDTH",
+ )
+ GLM_COMPUTER_HEIGHT: int = Field(
+ default=768,
+ description="Default GLM computer-use display height",
+ validation_alias="GLM_COMPUTER_HEIGHT",
+ )
+ QWEN_COMPUTER_WIDTH: int = Field(
+ default=700,
+ description="Default Qwen computer-use display width",
+ validation_alias="QWEN_COMPUTER_WIDTH",
+ )
+ QWEN_COMPUTER_HEIGHT: int = Field(
+ default=448,
+ description="Default Qwen computer-use display height",
+ validation_alias="QWEN_COMPUTER_HEIGHT",
+ )
+
+
+openai_compatible_tool_settings = OpenAICompatibleToolSettings()
diff --git a/hud/agents/openai_compatible/tools/types.py b/hud/agents/openai_compatible/tools/types.py
deleted file mode 100644
index 2bded858a..000000000
--- a/hud/agents/openai_compatible/tools/types.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Type definitions for OpenAI-compatible chat tools."""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Literal, TypeAlias, TypedDict
-
-if TYPE_CHECKING:
- from openai.types.chat import ChatCompletionToolParam
- from openai.types.shared_params.function_parameters import FunctionParameters
-
-
-class QwenComputerUseToolParam(TypedDict):
- """Qwen's OpenAI-compatible computer_use extension."""
-
- type: Literal["computer_use"]
- name: str
- display_width_px: int
- display_height_px: int
- description: str
- parameters: FunctionParameters
-
-
-OpenAICompatibleToolParam: TypeAlias = "ChatCompletionToolParam | QwenComputerUseToolParam"
-
-
-__all__ = ["OpenAICompatibleToolParam", "QwenComputerUseToolParam"]
diff --git a/hud/agents/resolver.py b/hud/agents/resolver.py
deleted file mode 100644
index ae9bd8b89..000000000
--- a/hud/agents/resolver.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""Model resolution - maps model strings to agent classes."""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any
-
-if TYPE_CHECKING:
- from hud.agents.base import MCPAgent
-
-__all__ = ["resolve_cls"]
-
-_models_cache: list[dict[str, Any]] | None = None
-
-
-def _fetch_gateway_models() -> list[dict[str, Any]]:
- """Fetch available models from HUD API (cached)."""
- global _models_cache
- if _models_cache is not None:
- return _models_cache
-
- import httpx
-
- from hud.settings import settings
-
- if not settings.api_key:
- return []
-
- try:
- resp = httpx.get(
- f"{settings.hud_api_url}/models/",
- headers={"Authorization": f"Bearer {settings.api_key}"},
- timeout=10.0,
- )
- resp.raise_for_status()
- data = resp.json()
- models = data.get("models") or []
- _models_cache = models
- return models
- except Exception:
- return []
-
-
-def resolve_cls(model: str) -> tuple[type[MCPAgent], dict[str, Any] | None]:
- """Resolve model string to (agent_class, gateway_info).
-
- Returns:
- (agent_class, None) for known AgentTypes
- (agent_class, gateway_model_info) for gateway models
- """
- from hud.types import AgentType
-
- # Known AgentType → no gateway info
- try:
- return AgentType(model).cls, None
- except ValueError:
- pass
-
- # Gateway lookup
- for m in _fetch_gateway_models():
- if model in (m.get("id"), m.get("name"), m.get("model_name")):
- agent_str = m.get("sdk_agent_type") or m["provider"]["default_sdk_agent_type"]
- if agent_str == "operator":
- raise ValueError(
- "Operator agent is no longer supported; use openai with a supported "
- "OpenAI computer model."
- )
- if agent_str == "gemini_cua":
- raise ValueError(
- "Gemini CUA agent is no longer supported; use gemini with a supported "
- "Gemini computer-use model."
- )
- return AgentType(agent_str).cls, m
-
- raise ValueError(f"Model '{model}' not found")
diff --git a/hud/agents/tests/conftest.py b/hud/agents/tests/conftest.py
index eb4880f4b..2bfd37b0b 100644
--- a/hud/agents/tests/conftest.py
+++ b/hud/agents/tests/conftest.py
@@ -1,42 +1,218 @@
-"""Shared test fixtures for agent tests."""
+# pyright: reportPrivateUsage=false
+"""Shared behavioral harness for agent tests."""
from __future__ import annotations
-from typing import Any
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, ClassVar, cast
import pytest
from mcp import types
+from hud.agents.base import MCPAgent
+from hud.agents.tools import (
+ AgentTool,
+ AgentTools,
+ AgentToolSpec,
+ GroupedCapabilityMixin,
+ ToolMetadata,
+)
+from hud.agents.tools.base import ToolClient
+from hud.agents.types import AgentConfig
from hud.environment.router import ToolRouter
+from hud.environment.scenarios import ScenarioSession
from hud.eval.context import EvalContext
-from hud.types import MCPToolCall, MCPToolResult
+from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace
+if TYPE_CHECKING:
+ from collections.abc import Callable, Mapping
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing agents.
- This provides a minimal EvalContext implementation that can be used
- to test agent initialization and tool calling without a real environment.
- """
+class HarnessConfig(AgentConfig):
+ model_name: str = "HarnessAgent"
+ model: str = "harness-model"
+
+
+def mcp_tool(name: str, *, description: str | None = None) -> types.Tool:
+ return types.Tool(
+ name=name,
+ description=description or f"{name} tool",
+ inputSchema={"type": "object", "properties": {}},
+ )
+
+
+def text_prompt(text: str, *, role: types.Role = "user") -> types.PromptMessage:
+ return types.PromptMessage(
+ role=role,
+ content=types.TextContent(type="text", text=text),
+ )
+
+
+def text_result(text: str, *, is_error: bool = False) -> MCPToolResult:
+ return MCPToolResult(
+ content=[types.TextContent(type="text", text=text)],
+ isError=is_error,
+ )
+
+
+def result_text(result: MCPToolResult) -> str:
+ return "\n".join(block.text for block in result.content if isinstance(block, types.TextContent))
+
+
+class HarnessTool(AgentTool[dict[str, Any]]):
+ name = "function"
+ capability = "function"
+
+ @classmethod
+ def from_tool(cls, tool: types.Tool) -> HarnessTool:
+ return cls(
+ env_tool_name=tool.name,
+ spec=AgentToolSpec(api_type="function", api_name=tool.name),
+ )
+
+ @property
+ def provider_name(self) -> str:
+ return self.env_tool_name
+
+ def to_params(self) -> dict[str, Any]:
+ return {"name": self.provider_name}
+
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, Any]:
+ return {
+ "role": "tool",
+ "name": call.name,
+ "content": result_text(result),
+ "is_error": result.isError,
+ }
+
+
+class HarnessTools(AgentTools[HarnessTool, dict[str, Any]]):
+ function_tool_class = HarnessTool
+
+
+class HarnessNativeShellTool(HarnessTool):
+ name = "shell"
+ capability = "shell"
+
+ @property
+ def provider_name(self) -> str:
+ return self.name
+
+ @classmethod
+ def default_spec(cls, model: str) -> AgentToolSpec:
+ del model
+ return AgentToolSpec(api_type="shell", api_name="shell")
+
+
+class HarnessFilesystemReadTool(GroupedCapabilityMixin, HarnessTool):
+ name = "read_file"
+ capability = "filesystem"
+ env_tool_names: ClassVar[tuple[str, ...]] = ("read", "read_file")
+
+ @property
+ def provider_name(self) -> str:
+ return self.name
+
+ @classmethod
+ def default_spec(cls, model: str) -> AgentToolSpec:
+ del model
+ return AgentToolSpec(api_type="function", api_name="read_file")
+
+
+class RoutingHarnessTools(AgentTools[HarnessTool, dict[str, Any]]):
+ native_tool_classes = (HarnessNativeShellTool, HarnessFilesystemReadTool)
+ function_tool_class = HarnessTool
+ name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {"shell": ("bash",)}
+
+
+class ScriptedAgent(MCPAgent[dict[str, Any]]):
+ """Agent fake that exercises the real `MCPAgent.run` loop."""
+
+ def __init__(
+ self,
+ responses: list[AgentResponse | BaseException],
+ *,
+ config: HarnessConfig | None = None,
+ tools_factory: Callable[[], AgentTools[Any, Any]] | None = None,
+ ) -> None:
+ super().__init__(config or HarnessConfig())
+ self.config: HarnessConfig
+ self.responses = list(responses)
+ self.seen_messages: list[list[dict[str, Any]]] = []
+ self._tools_factory = tools_factory or HarnessTools
+
+ @cached_property
+ def tools(self) -> AgentTools[Any, Any]:
+ return self._tools_factory()
+
+ async def format_messages(self, messages: list[types.PromptMessage]) -> list[dict[str, Any]]:
+ formatted: list[dict[str, Any]] = []
+ for message in messages:
+ content = message.content
+ formatted.append(
+ {
+ "role": message.role,
+ "content": content.text if isinstance(content, types.TextContent) else "",
+ }
+ )
+ return formatted
+
+ async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse:
+ self.seen_messages.append([dict(message) for message in messages])
+ response = self.responses.pop(0)
+ if isinstance(response, BaseException):
+ raise response
+ return response
+
+
+class RecordingToolEnvironment:
+ """Records the environment-facing MCP calls made by an agent run."""
+
+ def __init__(
+ self,
+ tools: list[types.Tool] | None = None,
+ *,
+ results: Mapping[str, MCPToolResult | Exception] | None = None,
+ tool_metadata: ToolMetadata | None = None,
+ ) -> None:
+ self.tools = tools or []
+ self.results = dict(results or {})
+ self.tool_metadata = tool_metadata
+ self.calls: list[MCPToolCall] = []
+
+ @property
+ def client(self) -> ToolClient:
+ return ToolClient(
+ tools=self.tools,
+ tool_handler=self.call_tool,
+ tool_metadata=self.tool_metadata,
+ )
+
+ async def call_tool(self, call: MCPToolCall) -> MCPToolResult:
+ self.calls.append(call)
+ result = self.results.get(call.name, text_result(f"result from {call.name}"))
+ if isinstance(result, Exception):
+ raise result
+ return result
+
+
+class HarnessEvalContext(EvalContext):
+ """Small EvalContext double that keeps the real `_run` and prompt behavior."""
def __init__(
self,
prompt: str = "Test prompt",
+ *,
tools: list[types.Tool] | None = None,
- call_tool_handler: Any = None,
+ tool_results: Mapping[str, MCPToolResult | Exception] | None = None,
+ metadata: dict[str, Any] | None = None,
) -> None:
- # Core attributes
self.prompt = prompt
- self._tools = tools or []
+ self.environment = RecordingToolEnvironment(tools or [], results=tool_results)
self._submitted: str | dict[str, Any] | None = None
self.reward: float | None = None
- self._call_tool_handler = call_tool_handler
- self.tool_calls: list[tuple[str, dict[str, Any]]] = []
-
- # Environment attributes
self._router = ToolRouter()
-
- # EvalContext attributes
+ self._scenario_sessions = {}
self._task = None
self.trace_id = "test-trace-id"
self.eval_name = "test-eval"
@@ -47,85 +223,61 @@ def __init__(
self.answer: str | dict[str, Any] | None = None
self.system_prompt: str | None = None
self.error: BaseException | None = None
- self.metadata: dict[str, Any] = {}
+ self.metadata = metadata or {}
self.results: list[Any] = []
self._is_summary = False
+ self._eval_api_key: str | None = None
+ self._trace_enabled = False
def as_tools(self) -> list[types.Tool]:
- return self._tools
+ return self.environment.tools
@property
- def has_scenario(self) -> bool:
- return False
+ def submitted(self) -> str | dict[str, Any] | None:
+ return self._submitted
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
+ def set_scenario_messages(self, messages: list[types.PromptMessage]) -> None:
+ self._scenario_sessions["__client__"] = ScenarioSession(
+ local_name="chat",
+ full_name="test-env:chat",
+ is_local=True,
+ connection_name=None,
+ resource_uri="test-env:chat",
+ prompt_messages=messages,
+ )
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- # Parse the call
- if isinstance(call, tuple):
- name, args = call[0], call[1] if len(call) > 1 else {}
- elif hasattr(call, "name"):
- name, args = call.name, getattr(call, "arguments", {}) or {}
- else:
- name, args = str(call), kwargs
+ def tool_metadata_for_run(self) -> ToolMetadata | None:
+ return self._tool_metadata()
- self.tool_calls.append((name, args))
+ async def run_agent(self, agent: Any, *, max_steps: int = 10) -> Trace:
+ return await self._run(agent, max_steps=max_steps)
- if self._call_tool_handler:
- tc = MCPToolCall(name=name, arguments=args)
- return self._call_tool_handler(tc)
+ async def list_tools(self, **kwargs: Any) -> list[types.Tool]:
+ del kwargs
+ return self.environment.tools
- return MCPToolResult(
- content=[types.TextContent(type="text", text=f"Result from {name}")],
- isError=False,
- )
+ async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
+ if isinstance(call, MCPToolCall):
+ tool_call = call
+ elif isinstance(call, tuple):
+ call_tuple = cast("tuple[Any, ...]", call)
+ tool_call = MCPToolCall(
+ name=str(call_tuple[0]),
+ arguments=cast("dict[str, Any]", call_tuple[1] if len(call_tuple) > 1 else {}),
+ )
+ else:
+ tool_call = MCPToolCall(name=str(call), arguments=kwargs)
+ return await self.environment.call_tool(tool_call)
async def submit(self, answer: str | dict[str, Any]) -> None:
self._submitted = answer
@pytest.fixture
-def mock_eval_context() -> MockEvalContext:
- """Create a basic mock EvalContext."""
- return MockEvalContext()
+def basic_tool() -> types.Tool:
+ return mcp_tool("lookup")
@pytest.fixture
-def mock_eval_context_with_tools() -> MockEvalContext:
- """Create a mock EvalContext with test tools."""
- return MockEvalContext(
- tools=[
- types.Tool(
- name="test_tool",
- description="A test tool",
- inputSchema={"type": "object", "properties": {}},
- )
- ]
- )
-
-
-@pytest.fixture
-def mock_eval_context_computer() -> MockEvalContext:
- """Create a mock EvalContext with computer tool."""
- return MockEvalContext(
- tools=[
- types.Tool(
- name="computer",
- description="Computer use tool",
- inputSchema={"type": "object"},
- )
- ]
- )
-
-
-@pytest.fixture
-def mock_eval_context_browser_tools() -> MockEvalContext:
- """Create a mock EvalContext with browser-like tools."""
- return MockEvalContext(
- tools=[
- types.Tool(name="screenshot", description="Take screenshot", inputSchema={}),
- types.Tool(name="click", description="Click at coordinates", inputSchema={}),
- types.Tool(name="type", description="Type text", inputSchema={}),
- ]
- )
+def recording_environment(basic_tool: types.Tool) -> RecordingToolEnvironment:
+ return RecordingToolEnvironment([basic_tool])
diff --git a/hud/agents/tests/test_base.py b/hud/agents/tests/test_base.py
deleted file mode 100644
index ef6fa7d0f..000000000
--- a/hud/agents/tests/test_base.py
+++ /dev/null
@@ -1,537 +0,0 @@
-"""Tests for MCPAgent base class with the EvalContext pattern."""
-
-from __future__ import annotations
-
-from typing import Any, ClassVar
-
-import pytest
-from mcp import types
-
-from hud.agents import MCPAgent
-from hud.agents.base import BaseCreateParams
-from hud.environment.router import ToolRouter
-from hud.eval.context import EvalContext
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult
-
-
-class MockConfig(BaseAgentConfig):
- model_name: str = "MockAgent"
- model: str = "mock-model"
-
-
-class MockCreateParams(BaseCreateParams, MockConfig):
- pass
-
-
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing."""
-
- def __init__(
- self,
- prompt: str = "Test prompt",
- tools: list[types.Tool] | None = None,
- ) -> None:
- # Core attributes
- self.prompt = prompt
- self._tools = tools or [
- types.Tool(name="test_tool", description="A test tool", inputSchema={}),
- types.Tool(name="another_tool", description="Another tool", inputSchema={}),
- ]
- self._submitted: str | dict[str, Any] | None = None
- self.reward: float | None = None
- self._tool_calls: list[tuple[str, dict[str, Any]]] = []
-
- # Environment attributes
- self._router = ToolRouter()
-
- # EvalContext attributes
- self._task = None
- self.trace_id = "test-trace-id"
- self.eval_name = "test-eval"
- self.job_id: str | None = None
- self.group_id: str | None = None
- self.index = 0
- self.variants: dict[str, Any] = {}
- self.answer: str | dict[str, Any] | None = None
- self.system_prompt: str | None = None
- self.error: BaseException | None = None
- self.metadata: dict[str, Any] = {}
- self.results: list[Any] = []
- self._is_summary = False
-
- def as_tools(self) -> list[types.Tool]:
- return self._tools
-
- @property
- def has_scenario(self) -> bool:
- return True
-
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
-
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- # Parse the call
- if isinstance(call, tuple):
- name, args = call[0], call[1] if len(call) > 1 else {}
- elif hasattr(call, "name"):
- name, args = call.name, getattr(call, "arguments", {}) or {}
- else:
- name, args = str(call), kwargs
- self._tool_calls.append((name, args))
- return MCPToolResult(
- content=[types.TextContent(type="text", text=f"Result from {name}")],
- isError=False,
- )
-
- async def submit(self, answer: str | dict[str, Any]) -> None:
- self._submitted = answer
-
-
-class MockMCPAgent(MCPAgent):
- """Concrete implementation of MCPAgent for testing."""
-
- metadata: ClassVar[dict[str, Any] | None] = {}
- config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for the mock agent."""
- return AgentType.OPENAI
-
- def __init__(self, **kwargs: Any) -> None:
- params = MockCreateParams(**kwargs)
- super().__init__(params)
- self._response = InferenceResult(content="Mock response", tool_calls=[], done=True)
-
- def set_response(self, response: InferenceResult) -> None:
- self._response = response
-
- async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
- return self._response
-
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[dict[str, Any]]:
- formatted = []
- for tool_call, result in zip(tool_calls, tool_results, strict=True):
- formatted.append({"role": "tool", "name": tool_call.name, "content": str(result)})
- return formatted
-
- async def get_system_messages(self) -> list[Any]:
- return []
-
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]:
- return [{"type": "text", "text": getattr(b, "text", "")} for b in blocks]
-
-
-class TestMCPAgentInit:
- """Tests for MCPAgent initialization."""
-
- def test_init_defaults(self) -> None:
- """Test agent initializes with default config."""
- agent = MockMCPAgent()
- assert agent.ctx is None
- assert agent._initialized is False
- assert agent.system_prompt is None
-
- def test_init_with_system_prompt(self) -> None:
- """Test agent with custom system prompt."""
- agent = MockMCPAgent(system_prompt="Custom prompt")
- assert agent.system_prompt == "Custom prompt"
-
-
-class TestMCPAgentRun:
- """Tests for MCPAgent.run() with EvalContext."""
-
- @pytest.mark.asyncio
- async def test_run_basic(self) -> None:
- """Test basic run flow with EvalContext."""
- ctx = MockEvalContext(prompt="Do something")
- agent = MockMCPAgent()
-
- result = await agent.run(ctx)
-
- assert result.done is True
- assert result.content == "Mock response"
- assert ctx._submitted == "Mock response"
-
- @pytest.mark.asyncio
- async def test_run_initializes_agent(self) -> None:
- """Test run() initializes the agent with context."""
- ctx = MockEvalContext(prompt="Do something")
- agent = MockMCPAgent()
-
- assert not agent._initialized
- await agent.run(ctx)
- assert agent._initialized
-
- @pytest.mark.asyncio
- async def test_run_discovers_tools(self) -> None:
- """Test run() discovers tools from context."""
- tools = [
- types.Tool(name="tool1", description="Tool 1", inputSchema={}),
- types.Tool(name="tool2", description="Tool 2", inputSchema={}),
- ]
- ctx = MockEvalContext(prompt="Do something", tools=tools)
- agent = MockMCPAgent()
-
- # We need to check tools before cleanup
- # Store a reference to check
- discovered_tools = []
-
- original_run = agent._run_context
-
- async def capture_tools(*args: Any, **kwargs: Any) -> Any:
- discovered_tools.extend(agent.get_available_tools())
- return await original_run(*args, **kwargs)
-
- agent._run_context = capture_tools # type: ignore
- await agent.run(ctx)
-
- assert len(discovered_tools) == 2
- assert discovered_tools[0].name == "tool1"
- assert discovered_tools[1].name == "tool2"
-
- @pytest.mark.asyncio
- async def test_run_requires_eval_context(self) -> None:
- """Test run() raises TypeError for non-EvalContext."""
- agent = MockMCPAgent()
-
- with pytest.raises(TypeError, match="must be EvalContext"):
- await agent.run("not a context") # type: ignore
-
- @pytest.mark.asyncio
- async def test_run_requires_prompt(self) -> None:
- """Test run() raises ValueError when prompt is empty."""
- ctx = MockEvalContext(prompt="")
- agent = MockMCPAgent()
-
- with pytest.raises(ValueError, match="prompt is not set"):
- await agent.run(ctx)
-
- @pytest.mark.asyncio
- async def test_run_clears_context_after(self) -> None:
- """Test run() clears ctx after completion."""
- ctx = MockEvalContext(prompt="Do something")
- agent = MockMCPAgent()
-
- await agent.run(ctx)
- assert agent.ctx is None
-
- @pytest.mark.asyncio
- async def test_run_no_submit_on_empty_content(self) -> None:
- """Test run() doesn't submit when content is empty."""
- ctx = MockEvalContext(prompt="Do something")
- agent = MockMCPAgent()
- agent.set_response(InferenceResult(content="", tool_calls=[], done=True))
-
- await agent.run(ctx)
- assert ctx._submitted is None
-
-
-class TestMCPAgentToolCalling:
- """Tests for tool calling through context."""
-
- @pytest.mark.asyncio
- async def test_call_tools_uses_context(self) -> None:
- """Test call_tools routes through ctx.call_tool."""
- ctx = MockEvalContext(prompt="Do something")
- agent = MockMCPAgent()
-
- # Bind context manually
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- # Call a tool
- results = await agent.call_tools(MCPToolCall(name="test_tool", arguments={"arg": "value"}))
-
- assert len(results) == 1
- assert not results[0].isError
- assert ("test_tool", {"arg": "value"}) in ctx._tool_calls
-
- @pytest.mark.asyncio
- async def test_call_tools_without_context_raises(self) -> None:
- """Test call_tools raises when no context bound."""
- agent = MockMCPAgent()
-
- with pytest.raises(ValueError, match="not bound to context"):
- await agent.call_tools(MCPToolCall(name="test_tool", arguments={}))
-
-
-class TestMCPAgentRequiredTools:
- """Tests for required_tools validation."""
-
- @pytest.mark.asyncio
- async def test_missing_required_tools_raises(self) -> None:
- """Test run() raises when required tools are missing."""
-
- class AgentWithRequiredTools(MockMCPAgent):
- required_tools: ClassVar[list[str]] = ["must_have_tool"]
-
- ctx = MockEvalContext(prompt="Do something", tools=[])
- agent = AgentWithRequiredTools()
-
- with pytest.raises(ValueError, match="Required tools are missing"):
- await agent.run(ctx)
-
- @pytest.mark.asyncio
- async def test_required_tools_present_succeeds(self) -> None:
- """Test run() succeeds when required tools are present."""
-
- class AgentWithRequiredTools(MockMCPAgent):
- required_tools: ClassVar[list[str]] = ["required_tool"]
-
- tools = [types.Tool(name="required_tool", description="Required", inputSchema={})]
- ctx = MockEvalContext(prompt="Do something", tools=tools)
- agent = AgentWithRequiredTools()
-
- result = await agent.run(ctx)
- assert result.done
-
-
-class TestMCPAgentOnToolsReady:
- """Tests for _on_tools_ready hook."""
-
- @pytest.mark.asyncio
- async def test_on_tools_ready_called(self) -> None:
- """Test _on_tools_ready is called during initialization."""
- hook_called = [False]
-
- class AgentWithHook(MockMCPAgent):
- def _on_tools_ready(self) -> None:
- hook_called[0] = True
-
- ctx = MockEvalContext(prompt="Do something")
- agent = AgentWithHook()
-
- await agent.run(ctx)
- assert hook_called[0]
-
- @pytest.mark.asyncio
- async def test_on_tools_ready_has_access_to_tools(self) -> None:
- """Test _on_tools_ready can access discovered tools."""
- captured_tools: list[types.Tool] = []
-
- class AgentWithHook(MockMCPAgent):
- def _on_tools_ready(self) -> None:
- captured_tools.extend(self.get_available_tools())
-
- tools = [
- types.Tool(name="tool1", description="Tool 1", inputSchema={}),
- types.Tool(name="tool2", description="Tool 2", inputSchema={}),
- ]
- ctx = MockEvalContext(prompt="Do something", tools=tools)
- agent = AgentWithHook()
-
- await agent.run(ctx)
-
- assert len(captured_tools) == 2
- assert captured_tools[0].name == "tool1"
-
-
-class TestMCPAgentToolSchemas:
- """Tests for tool schema generation."""
-
- @pytest.mark.asyncio
- async def test_get_tool_schemas(self) -> None:
- """Test get_tool_schemas returns correct format."""
- tools = [
- types.Tool(
- name="my_tool",
- description="My tool description",
- inputSchema={"type": "object", "properties": {"x": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(prompt="Do something", tools=tools)
- agent = MockMCPAgent()
-
- # Initialize agent
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- schemas = agent.get_tool_schemas()
- assert len(schemas) == 1
- assert schemas[0]["name"] == "my_tool"
- assert schemas[0]["description"] == "My tool description"
-
-
-class TestMCPAgentErrorPropagation:
- """Tests for error propagation to EvalContext."""
-
- @pytest.mark.asyncio
- async def test_exception_propagates_to_ctx_error(self) -> None:
- """Test that exceptions during run() set ctx.error for platform visibility."""
-
- class FailingAgent(MockMCPAgent):
- async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
- raise RuntimeError("Agent crashed")
-
- ctx = MockEvalContext(prompt="Do something")
- agent = FailingAgent()
-
- result = await agent.run(ctx)
-
- # Should return error trace
- assert result.isError is True
- assert result.content is not None
- assert "Agent crashed" in result.content
-
- assert ctx.error is not None
- assert isinstance(ctx.error, BaseException)
- assert "Agent crashed" in str(ctx.error)
-
- @pytest.mark.asyncio
- async def test_step_error_propagates_to_ctx_error(self) -> None:
- """Test that step-level errors (caught internally) set ctx.error."""
- step_count = [0]
-
- class FailOnSecondStepAgent(MockMCPAgent):
- async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
- step_count[0] += 1
- if step_count[0] == 1:
- return InferenceResult(
- content="",
- tool_calls=[MCPToolCall(name="test_tool", arguments={})],
- done=False,
- )
- else:
- raise ValueError("Step 2 failed")
-
- ctx = MockEvalContext(prompt="Do something")
- agent = FailOnSecondStepAgent()
-
- result = await agent.run(ctx)
-
- # Should return error trace
- assert result.isError is True
- assert ctx.error is not None
- assert "Step 2 failed" in str(ctx.error)
-
- @pytest.mark.asyncio
- async def test_no_error_when_successful(self) -> None:
- """Test that ctx.error remains None on successful run."""
- ctx = MockEvalContext(prompt="Do something")
- agent = MockMCPAgent()
-
- result = await agent.run(ctx)
-
- assert result.isError is False
- assert ctx.error is None
-
-
-class TestMCPAgentCategorizeTools:
- """Tests for the categorize_tools method."""
-
- @pytest.mark.asyncio
- async def test_categorize_generic_tools(self) -> None:
- """All MCP tools are generic unless a provider agent filters them."""
- tools = [
- types.Tool(name="tool1", description="Tool 1", inputSchema={}),
- types.Tool(name="tool2", description="Tool 2", inputSchema={}),
- ]
- ctx = MockEvalContext(prompt="Test", tools=tools)
- agent = MockMCPAgent()
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- categorized = agent.categorize_tools()
-
- assert len(categorized.generic) == 2
- assert len(categorized.skipped) == 0
-
- @pytest.mark.asyncio
- async def test_ignores_legacy_native_tool_metadata(self) -> None:
- """Legacy native metadata no longer affects base categorization."""
- tool_with_metadata = types.Tool(
- name="tool_with_metadata",
- description="Tool with ignored metadata",
- inputSchema={},
- _meta={
- "native_tools": {
- "openai": {
- "api_type": "test_type",
- "role": "test_role",
- }
- }
- },
- )
- tools = [tool_with_metadata]
-
- ctx = MockEvalContext(prompt="Test", tools=tools)
- agent = MockMCPAgent()
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- categorized = agent.categorize_tools()
-
- assert len(categorized.generic) == 1
- assert categorized.generic[0].name == "tool_with_metadata"
- assert len(categorized.skipped) == 0
-
- @pytest.mark.asyncio
- async def test_no_role_exclusion_from_legacy_metadata(self) -> None:
- """Tool role metadata is not a control plane anymore."""
- first_tool = types.Tool(
- name="claude_computer",
- description="Claude computer",
- inputSchema={},
- _meta={
- "native_tools": {
- "openai": {
- "api_type": "computer_test",
- "role": "computer",
- }
- }
- },
- )
- second_tool = types.Tool(
- name="gemini_computer",
- description="Gemini computer",
- inputSchema={},
- _meta={
- "native_tools": {
- "gemini": {
- "api_type": "computer_use",
- "role": "computer",
- }
- }
- },
- )
- tools = [first_tool, second_tool]
-
- ctx = MockEvalContext(prompt="Test", tools=tools)
- agent = MockMCPAgent()
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- categorized = agent.categorize_tools()
-
- assert [tool.name for tool in categorized.generic] == ["claude_computer", "gemini_computer"]
- assert len(categorized.skipped) == 0
-
- @pytest.mark.asyncio
- async def test_hosted_metadata_stays_generic(self) -> None:
- """Hosted tools are configured on agents, not environment metadata."""
- hosted_tool = types.Tool(
- name="google_search",
- description="Google Search",
- inputSchema={},
- _meta={
- "native_tools": {
- "openai": {
- "api_type": "google_search",
- "hosted": True,
- }
- }
- },
- )
- tools = [hosted_tool]
-
- ctx = MockEvalContext(prompt="Test", tools=tools)
- agent = MockMCPAgent()
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- categorized = agent.categorize_tools()
-
- assert [tool.name for tool in categorized.generic] == ["google_search"]
diff --git a/hud/agents/tests/test_base_runtime.py b/hud/agents/tests/test_base_runtime.py
deleted file mode 100644
index 1a4eec41a..000000000
--- a/hud/agents/tests/test_base_runtime.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""Runtime tests for MCPAgent base class."""
-
-from __future__ import annotations
-
-from typing import Any
-
-import mcp.types as types
-import pytest
-
-from hud.agents.base import BaseCreateParams, MCPAgent, text_to_blocks
-from hud.environment.router import ToolRouter
-from hud.eval.context import EvalContext
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult
-
-
-class DummyConfig(BaseAgentConfig):
- model_name: str = "DummyAgent"
- model: str = "dummy-model"
-
-
-class DummyCreateParams(BaseCreateParams, DummyConfig):
- pass
-
-
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing."""
-
- def __init__(
- self,
- prompt: str = "Test prompt",
- tools: list[types.Tool] | None = None,
- ) -> None:
- # Core attributes
- self.prompt = prompt
- self._tools = tools or []
- self._submitted: str | dict[str, Any] | None = None
- self.reward: float | None = None
- self._call_tool_handler: Any = None
-
- self._router = ToolRouter()
-
- # EvalContext attributes
- self._task = None
- self.trace_id = "test-trace-id"
- self.eval_name = "test-eval"
- self.job_id: str | None = None
- self.group_id: str | None = None
- self.index = 0
- self.variants: dict[str, Any] = {}
- self.answer: str | dict[str, Any] | None = None
- self.system_prompt: str | None = None
- self.error: BaseException | None = None
- self.metadata: dict[str, Any] = {}
- self.results: list[Any] = []
- self._is_summary = False
-
- def as_tools(self) -> list[types.Tool]:
- return self._tools
-
- @property
- def has_scenario(self) -> bool:
- return False
-
- def set_call_tool_handler(self, handler: Any) -> None:
- self._call_tool_handler = handler
-
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
-
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- if self._call_tool_handler:
- # Parse the call
- if isinstance(call, tuple):
- tc = MCPToolCall(name=call[0], arguments=call[1] if len(call) > 1 else {})
- elif hasattr(call, "name"):
- tc = call
- else:
- tc = MCPToolCall(name=str(call), arguments=kwargs)
- return self._call_tool_handler(tc)
- return MCPToolResult(
- content=[types.TextContent(type="text", text="ok")],
- isError=False,
- )
-
- async def submit(self, answer: str | dict[str, Any]) -> None:
- self._submitted = answer
-
-
-class DummyAgent(MCPAgent):
- config_cls = DummyConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for the dummy agent."""
- return AgentType.OPENAI
-
- def __init__(self, **kwargs: Any) -> None:
- params = DummyCreateParams(**kwargs)
- super().__init__(params)
-
- async def get_system_messages(self) -> list[types.ContentBlock]:
- return [types.TextContent(type="text", text="sys")]
-
- async def get_response(self, messages: list[Any]) -> InferenceResult:
- return InferenceResult(content="ok", tool_calls=[], done=True)
-
- async def format_blocks(self, blocks: list[Any]) -> list[Any]:
- return blocks
-
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[Any]:
- return [types.TextContent(text="tools", type="text")]
-
-
-def test_get_available_tools_before_run_raises() -> None:
- """Test that get_available_tools raises before initialization."""
- agent = DummyAgent()
- with pytest.raises(RuntimeError):
- agent.get_available_tools()
-
-
-@pytest.mark.asyncio
-async def test_format_message_invalid_type_raises() -> None:
- """Test that format_message raises for invalid types."""
- agent = DummyAgent()
- with pytest.raises(ValueError):
- await agent.format_message({"oops": 1}) # type: ignore
-
-
-def test_text_to_blocks_shapes() -> None:
- """Test text_to_blocks returns correct structure."""
- blocks = text_to_blocks("x")
- assert isinstance(blocks, list) and blocks and isinstance(blocks[0], types.TextContent)
-
-
-@pytest.mark.asyncio
-async def test_run_with_eval_context() -> None:
- """Test basic run() with EvalContext."""
- ctx = MockEvalContext(prompt="hello")
- agent = DummyAgent()
- result = await agent.run(ctx, max_steps=1)
- assert result.done is True
- assert result.isError is False
-
-
-@pytest.mark.asyncio
-async def test_run_requires_eval_context() -> None:
- """Test run() raises TypeError for non-EvalContext."""
- agent = DummyAgent()
- with pytest.raises(TypeError, match="must be EvalContext"):
- await agent.run("hello") # type: ignore
-
-
-@pytest.mark.asyncio
-async def test_run_requires_prompt() -> None:
- """Test run() raises ValueError when prompt is empty."""
- ctx = MockEvalContext(prompt="")
- agent = DummyAgent()
- with pytest.raises(ValueError, match="prompt is not set"):
- await agent.run(ctx)
-
-
-@pytest.mark.asyncio
-async def test_call_tools_error_paths() -> None:
- """Test call_tools handles errors correctly."""
- call_count = [0]
- ok_result = MCPToolResult(content=text_to_blocks("ok"), isError=False)
-
- def handler(tool_call: MCPToolCall) -> MCPToolResult:
- call_count[0] += 1
- if call_count[0] == 1:
- return ok_result
- raise RuntimeError("boom")
-
- ctx = MockEvalContext(prompt="test")
- ctx.set_call_tool_handler(handler)
- agent = DummyAgent()
-
- # Initialize the agent with context
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- results = await agent.call_tools(
- [MCPToolCall(name="a", arguments={}), MCPToolCall(name="b", arguments={})]
- )
- assert results[0].isError is False
- assert results[1].isError is True
-
-
-@pytest.mark.asyncio
-async def test_call_tools_timeout_raises() -> None:
- """Test call_tools raises TimeoutError."""
-
- def handler(tool_call: MCPToolCall) -> MCPToolResult:
- raise TimeoutError("timeout")
-
- ctx = MockEvalContext(prompt="test")
- ctx.set_call_tool_handler(handler)
- agent = DummyAgent()
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- with pytest.raises(TimeoutError):
- await agent.call_tools(MCPToolCall(name="x", arguments={}))
-
-
-@pytest.mark.asyncio
-async def test_get_available_tools_after_run() -> None:
- """Test get_available_tools works after initialization."""
- tools = [types.Tool(name="test_tool", description="Test", inputSchema={})]
- ctx = MockEvalContext(prompt="hello", tools=tools)
- agent = DummyAgent()
-
- # Run initializes the agent
- await agent.run(ctx, max_steps=1)
-
- # After cleanup, we can't access tools (ctx is cleared)
- # But during run, tools were available
- assert agent._initialized is True
diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py
deleted file mode 100644
index fb3dab557..000000000
--- a/hud/agents/tests/test_claude.py
+++ /dev/null
@@ -1,1605 +0,0 @@
-"""Tests for Claude MCP Agent implementation."""
-
-from __future__ import annotations
-
-from types import SimpleNamespace
-from typing import TYPE_CHECKING, Any, cast
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
-from mcp import types
-
-from hud.agents.claude import (
- ClaudeAgent,
- base64_to_content_block,
- text_to_content_block,
- tool_use_content_block,
-)
-from hud.environment import Environment
-from hud.environment.router import ToolRouter
-from hud.eval.context import EvalContext
-from hud.eval.task import Task
-from hud.types import MCPToolCall, MCPToolResult
-
-if TYPE_CHECKING:
- from collections.abc import Generator
-
- from anthropic.types.beta import BetaMessageParam
-
-
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing."""
-
- def __init__(
- self,
- tools: list[types.Tool] | None = None,
- environment_capabilities: dict[str, Any] | None = None,
- ) -> None:
- # Core attributes
- self.prompt = "Test prompt"
- self._tools = tools or []
- self._submitted: str | dict[str, Any] | None = None
- self.reward: float | None = None
-
- # Environment attributes
- self._router = ToolRouter()
-
- # EvalContext attributes
- self._task = None
- self.trace_id = "test-trace-id"
- self.eval_name = "test-eval"
- self.job_id: str | None = None
- self.group_id: str | None = None
- self.index = 0
- self.variants: dict[str, Any] = {}
- self.answer: str | dict[str, Any] | None = None
- self.system_prompt: str | None = None
- self.scenario_enable_citations: bool = False
- self.scenario_returns_schema: dict[str, Any] | None = None
- self.error: BaseException | None = None
- self.metadata: dict[str, Any] = {}
- self.environment_capabilities = environment_capabilities
- self.results: list[Any] = []
- self._is_summary = False
-
- def as_tools(self) -> list[types.Tool]:
- return self._tools
-
- @property
- def has_scenario(self) -> bool:
- return False
-
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
-
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- return MCPToolResult(
- content=[types.TextContent(type="text", text="ok")],
- isError=False,
- )
-
- async def submit(self, answer: str | dict[str, Any]) -> None:
- self._submitted = answer
-
-
-class MockStreamContextManager:
- """Mock for Claude's streaming context manager."""
-
- def __init__(self, response: MagicMock) -> None:
- self.response = response
-
- async def __aenter__(self) -> MockStreamContextManager:
- return self
-
- async def __aexit__(
- self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any
- ) -> bool:
- return False
-
- def __aiter__(self) -> MockStreamContextManager:
- return self
-
- async def __anext__(self) -> None:
- raise StopAsyncIteration
-
- async def get_final_message(self) -> MagicMock:
- return self.response
-
-
-class MockErrorStreamContextManager:
- """Mock stream context manager that raises a fixed error while streaming."""
-
- def __init__(self, error: Exception) -> None:
- self.error = error
-
- async def __aenter__(self) -> MockErrorStreamContextManager:
- return self
-
- async def __aexit__(
- self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any
- ) -> bool:
- return False
-
- def __aiter__(self) -> MockErrorStreamContextManager:
- return self
-
- async def __anext__(self) -> None:
- raise self.error
-
- async def get_final_message(self) -> MagicMock:
- raise AssertionError("get_final_message should not be called when stream iteration fails")
-
-
-class TestClaudeHelperFunctions:
- """Test helper functions for Claude message formatting."""
-
- def test_base64_to_content_block(self) -> None:
- """Test base64 image conversion."""
- base64_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk"
- result = base64_to_content_block(base64_data)
-
- assert result["type"] == "image"
- assert result["source"]["type"] == "base64"
- assert result["source"]["media_type"] == "image/png"
- assert result["source"]["data"] == base64_data
-
- def test_text_to_content_block(self) -> None:
- """Test text conversion."""
- text = "Hello, world!"
- result = text_to_content_block(text)
-
- assert result["type"] == "text"
- assert result["text"] == text
-
- def test_tool_use_content_block(self) -> None:
- """Test tool result content block creation."""
- tool_use_id = "tool_123"
- content = [text_to_content_block("Result text")]
-
- result = tool_use_content_block(tool_use_id, content)
-
- assert result["type"] == "tool_result"
- assert result["tool_use_id"] == tool_use_id
- assert result["content"] == content # type: ignore
-
-
-class TestClaudeAgent:
- """Test ClaudeAgent class."""
-
- @pytest.fixture
- def mock_anthropic(self) -> Generator[AsyncAnthropic, None, None]: # type: ignore[misc]
- """Create a stub Anthropic client."""
- with patch("hud.agents.claude.agent.AsyncAnthropic") as mock_class:
- client = MagicMock(spec=AsyncAnthropic)
- client.api_key = "test-key"
- mock_class.return_value = client
- yield client # type: ignore[misc]
-
- @pytest.mark.asyncio
- async def test_init_with_client(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test agent initialization with provided client."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-sonnet-4-6",
- validate_api_key=False,
- )
-
- assert agent.model_name == "Claude"
- assert agent.config.model == "claude-sonnet-4-6"
- assert agent.anthropic_client == mock_anthropic
-
- @pytest.mark.asyncio
- async def test_init_with_parameters(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test agent initialization with various parameters."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-sonnet-4-6",
- max_tokens=4096,
- validate_api_key=False,
- )
-
- assert agent.max_tokens == 4096
-
- @pytest.mark.asyncio
- async def test_format_blocks_text_only(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test formatting text content blocks."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- blocks: list[types.ContentBlock] = [
- types.TextContent(type="text", text="Hello, world!"),
- types.TextContent(type="text", text="How are you?"),
- ]
-
- messages = await agent.format_blocks(blocks)
- assert len(messages) == 1
- assert messages[0]["role"] == "user"
- content = messages[0]["content"]
- assert isinstance(content, list)
- assert len(content) == 2
- assert content[0]["type"] == "text" # type: ignore[index]
- assert content[0]["text"] == "Hello, world!" # type: ignore[index]
-
- @pytest.mark.asyncio
- async def test_format_blocks_with_image(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test formatting image content blocks."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- blocks: list[types.ContentBlock] = [
- types.TextContent(type="text", text="Look at this:"),
- types.ImageContent(type="image", data="base64data", mimeType="image/png"),
- ]
-
- messages = await agent.format_blocks(blocks)
- assert len(messages) == 1
- content = messages[0]["content"]
- assert isinstance(content, list)
- assert len(content) == 2
- assert content[1]["type"] == "image" # type: ignore[index]
-
- @pytest.mark.asyncio
- async def test_format_tool_results_text(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test formatting tool results with text content."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Tool output")],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- assert len(messages) == 1
- assert messages[0]["role"] == "user"
- content = messages[0]["content"]
- assert isinstance(content, list)
- assert len(content) == 1
- assert content[0]["type"] == "tool_result" # type: ignore[index]
- assert content[0]["tool_use_id"] == "call_123" # type: ignore[index]
-
- @pytest.mark.asyncio
- async def test_format_tool_results_with_error(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test formatting tool results with error."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Error message")],
- isError=True,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- assert len(messages) == 1
- content = messages[0]["content"]
- # Error content should include "Error:" prefix
- assert any("Error" in str(block) for block in content[0]["content"]) # type: ignore[index]
-
- @pytest.mark.asyncio
- async def test_get_system_messages(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test that system messages return empty (Claude uses system param)."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- system_prompt="You are a helpful assistant.",
- validate_api_key=False,
- )
-
- messages = await agent.get_system_messages()
- # Claude doesn't use system messages in the message list
- assert messages == []
-
- @pytest.mark.asyncio
- async def test_get_response_with_thinking(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test getting model response with thinking content."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- # Set up agent as initialized
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- mock_response = MagicMock()
-
- thinking_block = MagicMock()
- thinking_block.type = "thinking"
- thinking_block.thinking = "Let me analyze this problem..."
-
- text_block = MagicMock()
- text_block.type = "text"
- text_block.text = "Here is the answer"
-
- mock_response.content = [thinking_block, text_block]
- mock_response.usage = MagicMock(input_tokens=10, output_tokens=30)
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- messages = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Hard question"}]},
- )
- ]
- response = await agent.get_response(messages)
-
- assert response.content == "Here is the answer"
- assert response.reasoning == "Let me analyze this problem..."
-
- @pytest.mark.asyncio
- async def test_convert_tools_for_claude(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test converting MCP tools to Claude format."""
- tools = [
- types.Tool(
- name="my_tool",
- description="A test tool",
- inputSchema={"type": "object", "properties": {"x": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- # Check that tools were converted
- assert len(agent.claude_tools) == 1
- assert agent.claude_tools[0]["name"] == "my_tool" # type: ignore[typeddict-item]
-
- @pytest.mark.asyncio
- async def test_computer_tool_detection(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test that computer tools are detected for beta API."""
- tools = [
- types.Tool(
- name="computer",
- description="Control computer",
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.has_computer_tool is True
-
- @pytest.mark.asyncio
- async def test_computer_name_activates_agent_side_tool(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude native computer calls route through the agent-side tool."""
- tools = [
- types.Tool(
- name="computer",
- description="HUD computer",
- inputSchema={
- "type": "object",
- "properties": {"action": {"type": "string"}, "x": {"type": "integer"}},
- },
- _meta={"resolution": {"width": 1280, "height": 720}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock(
- return_value=MCPToolResult(
- content=[types.TextContent(type="text", text="clicked")],
- isError=False,
- )
- )
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
- results = await agent.call_tools(
- MCPToolCall(
- name="computer",
- arguments={"action": "left_click", "coordinate": [10, 20]},
- )
- )
-
- assert results[0].isError is False
- called = ctx.call_tool.call_args.args[0]
- assert called.name == "computer"
- assert called.arguments == {
- "action": "click",
- "x": 10,
- "y": 20,
- "hold_keys": None,
- }
-
- @pytest.mark.asyncio
- async def test_env_level_capability_activates_agent_side_tool(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Env-level capabilities are the preferred binding source."""
- tools = [
- types.Tool(
- name="desktop",
- description="Computer",
- inputSchema={"type": "object", "properties": {"action": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(
- tools=tools,
- environment_capabilities={
- "capabilities": {
- "computer": {
- "tool": "desktop",
- "resolution": {"width": 1600, "height": 900},
- }
- }
- },
- )
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.claude_tools[0]["name"] == "computer" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["display_width_px"] == 1600 # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["display_height_px"] == 900 # type: ignore[typeddict-item]
-
- @pytest.mark.asyncio
- async def test_anthropic_computer_registration_uses_role_as_capability(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Old Claude native metadata acts only as a capability signal."""
- tools = [
- types.Tool(
- name="anthropic_computer",
- description="Anthropic computer",
- inputSchema={
- "type": "object",
- "properties": {
- "action": {"type": "string"},
- "x": {"type": "integer"},
- "y": {"type": "integer"},
- },
- },
- _meta={
- "native_tools": {
- "claude": {
- "api_type": "stale_env_computer_spec",
- "api_name": "computer",
- "beta": "stale-env-beta",
- "role": "computer",
- "display_width": 1920,
- "display_height": 1080,
- }
- }
- },
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock(
- return_value=MCPToolResult(
- content=[types.TextContent(type="text", text="clicked")],
- isError=False,
- )
- )
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-sonnet-4-6",
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
- assert agent.claude_tools[0]["name"] == "computer" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["type"] == "computer_20251124" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["display_width_px"] != 1920 # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["display_height_px"] != 1080 # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["display_number"] == 1 # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["enable_zoom"] is True # type: ignore[typeddict-item]
- assert agent._required_betas == {"computer-use-2025-11-24"}
-
- await agent.call_tools(
- MCPToolCall(
- name="computer",
- arguments={"action": "left_click", "coordinate": [10, 20]},
- )
- )
-
- called = ctx.call_tool.call_args.args[0]
- assert called.name == "anthropic_computer"
- assert called.arguments == {
- "action": "click",
- "x": 10,
- "y": 20,
- "hold_keys": None,
- }
-
- @pytest.mark.asyncio
- async def test_computer_translates_modifiers_drag_and_hold_key(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude computer actions translate to valid generic environment calls."""
- tools = [
- types.Tool(
- name="computer",
- description="Computer",
- inputSchema={"type": "object", "properties": {"action": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- calls: list[MCPToolCall] = []
-
- async def call_tool(call: MCPToolCall) -> MCPToolResult:
- calls.append(call)
- return MCPToolResult(
- content=[types.TextContent(type="text", text="ok")],
- isError=False,
- )
-
- ctx.call_tool = call_tool # type: ignore[method-assign]
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
- await agent.call_tools(
- [
- MCPToolCall(
- name="computer",
- arguments={
- "action": "right_click",
- "coordinate": [10, 20],
- "text": "Shift",
- },
- ),
- MCPToolCall(
- name="computer",
- arguments={"action": "left_click_drag", "coordinate": [30, 40]},
- ),
- MCPToolCall(
- name="computer",
- arguments={"action": "hold_key", "text": "Control", "duration": 0.5},
- ),
- ]
- )
-
- assert [call.arguments for call in calls] == [
- {
- "action": "click",
- "x": 10,
- "y": 20,
- "button": "right",
- "hold_keys": ["shift"],
- },
- {"action": "mouse_down", "button": "left"},
- {"action": "move", "x": 30, "y": 40},
- {"action": "mouse_up", "button": "left"},
- {"action": "hold_key", "text": "ctrl", "duration": 0.5},
- ]
-
- @pytest.mark.asyncio
- async def test_bash_name_activates_agent_side_tool(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude native bash calls route through the agent-side tool."""
- tools = [
- types.Tool(
- name="bash",
- description="Bash shell",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock(
- return_value=MCPToolResult(
- content=[types.TextContent(type="text", text="ok")],
- isError=False,
- )
- )
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.claude_tools[0]["name"] == "bash" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["type"] == "bash_20250124" # type: ignore[typeddict-item]
-
- results = await agent.call_tools(MCPToolCall(name="bash", arguments={"command": "echo ok"}))
-
- assert results[0].isError is False
- called = ctx.call_tool.call_args.args[0]
- assert called.name == "bash"
- assert called.arguments == {"command": "echo ok"}
-
- @pytest.mark.asyncio
- async def test_bash_restart_matches_anthropic_contract(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude bash supports restart without command."""
- tools = [
- types.Tool(
- name="bash",
- description="Bash shell",
- inputSchema={"type": "object", "properties": {}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock(
- return_value=MCPToolResult(
- content=[types.TextContent(type="text", text="Bash session restarted.")],
- isError=False,
- )
- )
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
- results = await agent.call_tools(MCPToolCall(name="bash", arguments={"restart": True}))
-
- assert results[0].isError is False
- called = ctx.call_tool.call_args.args[0]
- assert called.name == "bash"
- assert called.arguments == {"restart": True}
-
- @pytest.mark.asyncio
- async def test_bash_requires_command_unless_restart(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Malformed Claude bash calls fail before reaching the environment."""
- tools = [
- types.Tool(
- name="bash",
- description="Bash shell",
- inputSchema={"type": "object", "properties": {}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock()
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
- results = await agent.call_tools(MCPToolCall(name="bash", arguments={}))
-
- assert results[0].isError is True
- assert "command is required" in results[0].content[0].text # type: ignore[attr-defined]
- ctx.call_tool.assert_not_called()
-
- @pytest.mark.asyncio
- async def test_edit_name_activates_agent_side_tool(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude native editor calls route through the environment edit tool."""
- tools = [
- types.Tool(
- name="edit",
- description="File editor",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock(
- return_value=MCPToolResult(
- content=[types.TextContent(type="text", text="edited")],
- isError=False,
- )
- )
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.claude_tools[0]["name"] == "str_replace_based_edit_tool" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["type"] == "text_editor_20250728" # type: ignore[typeddict-item]
-
- results = await agent.call_tools(
- MCPToolCall(
- name="str_replace_based_edit_tool",
- arguments={
- "command": "str_replace",
- "path": "/tmp/file.txt",
- "old_str": "old",
- "new_str": "new",
- },
- )
- )
-
- assert results[0].isError is False
- called = ctx.call_tool.call_args.args[0]
- assert called.name == "edit"
- assert called.arguments == {
- "command": "replace",
- "path": "/tmp/file.txt",
- "old_text": "old",
- "new_text": "new",
- }
-
- @pytest.mark.asyncio
- async def test_claude_3_7_sonnet_editor_stays_generic(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude 3.7 Sonnet editor support is intentionally not advertised."""
- tools = [
- types.Tool(
- name="edit",
- description="File editor",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-3-7-sonnet-20250219",
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert "str_replace_editor" not in agent._claude_native_tools
- assert "str_replace_based_edit_tool" not in agent._claude_native_tools
- assert agent.claude_tools[0]["name"] == "edit" # type: ignore[typeddict-item]
-
- @pytest.mark.asyncio
- async def test_sonnet_4_5_uses_current_native_coding_tools(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Sonnet 4.5 keeps native bash and editor support for compatibility."""
- tools = [
- types.Tool(
- name="bash",
- description="Bash shell",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- ),
- types.Tool(
- name="edit",
- description="File editor",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- ),
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-sonnet-4-5",
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- tool_types = {tool["name"]: tool.get("type") for tool in agent.claude_tools} # type: ignore[index]
- assert tool_types["bash"] == "bash_20250124"
- assert tool_types["str_replace_based_edit_tool"] == "text_editor_20250728"
-
- @pytest.mark.asyncio
- async def test_sonnet_4_5_uses_20250124_native_computer_tool(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Sonnet 4.5 keeps native computer support on its compatible spec."""
- tools = [
- types.Tool(
- name="computer",
- description="Computer",
- inputSchema={"type": "object", "properties": {"action": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-sonnet-4-5",
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.claude_tools[0]["name"] == "computer" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["type"] == "computer_20250124" # type: ignore[typeddict-item]
-
- @pytest.mark.asyncio
- async def test_20250728_editor_rejects_unsupported_commands(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude 4 editor shape only forwards commands supported by the provider spec."""
- tools = [
- types.Tool(
- name="edit",
- description="File editor",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock()
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
- results = await agent.call_tools(
- MCPToolCall(
- name="str_replace_based_edit_tool",
- arguments={"command": "undo_edit", "path": "/tmp/file.txt"},
- )
- )
-
- assert results[0].isError is True
- assert "does not support command 'undo_edit'" in results[0].content[0].text # type: ignore[attr-defined]
- results = await agent.call_tools(
- MCPToolCall(
- name="str_replace_based_edit_tool",
- arguments={"command": "undo", "path": "/tmp/file.txt"},
- )
- )
- assert results[0].isError is True
- assert "does not support command 'undo'" in results[0].content[0].text # type: ignore[attr-defined]
- ctx.call_tool.assert_not_called()
-
- @pytest.mark.asyncio
- async def test_memory_name_activates_agent_side_tool(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Claude native memory calls route through the environment memory tool."""
- tools = [
- types.Tool(
- name="memory",
- description="Memory",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.call_tool = AsyncMock(
- return_value=MCPToolResult(
- content=[types.TextContent(type="text", text="remembered")],
- isError=False,
- )
- )
- agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.claude_tools[0]["name"] == "memory" # type: ignore[typeddict-item]
- assert agent.claude_tools[0]["type"] == "memory_20250818" # type: ignore[typeddict-item]
- assert agent._required_betas == set()
-
- results = await agent.call_tools(
- MCPToolCall(name="memory", arguments={"command": "view", "path": "/"})
- )
-
- assert results[0].isError is False
- called = ctx.call_tool.call_args.args[0]
- assert called.name == "memory"
- assert called.arguments == {"command": "view", "path": "/"}
-
- @pytest.mark.asyncio
- async def test_old_sonnet_memory_stays_generic(self, mock_anthropic: AsyncAnthropic) -> None:
- """Claude memory is only advertised natively for the supported current models."""
- tools = [
- types.Tool(
- name="memory",
- description="Memory",
- inputSchema={"type": "object", "properties": {"command": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- model="claude-sonnet-4-5",
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent.claude_tools[0]["name"] == "memory" # type: ignore[typeddict-item]
- assert "type" not in agent.claude_tools[0] # type: ignore[operator]
- assert "memory" not in agent._claude_native_tools
-
- @pytest.mark.asyncio
- async def test_get_response_with_text(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test getting response with text output."""
- # Create mock response
- mock_response = MagicMock()
- mock_response.content = [MagicMock(type="text", text="Hello!")]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- response = await agent.get_response([])
- assert response.content == "Hello!"
- assert response.done is True
- assert len(response.tool_calls) == 0
-
- @pytest.mark.asyncio
- async def test_get_response_with_tool_call(self, mock_anthropic: AsyncAnthropic) -> None:
- """Test getting response with tool call."""
- mock_tool_use = MagicMock()
- mock_tool_use.type = "tool_use"
- mock_tool_use.id = "call_123"
- mock_tool_use.name = "my_tool"
- mock_tool_use.input = {"x": "value"}
-
- mock_response = MagicMock()
- mock_response.content = [mock_tool_use]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {"my_tool": "my_tool"}
- agent.has_computer_tool = False
- agent._initialized = True
-
- response = await agent.get_response([])
- assert response.done is False
- assert len(response.tool_calls) == 1
- assert response.tool_calls[0].name == "my_tool"
- assert response.tool_calls[0].arguments == {"x": "value"}
-
- @pytest.mark.asyncio
- async def test_get_response_retries_same_generation_once_on_invalid_streamed_tool_json(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """First invalid streamed tool JSON should retry without adding guidance."""
- invalid_json_error = ValueError(
- "Unable to parse tool parameter JSON from model. Please retry your request or "
- "adjust your "
- 'prompt. Error: expected value at line 1 column 10. JSON: {"labels": bug}'
- )
- first_stream = MockErrorStreamContextManager(invalid_json_error)
-
- mock_response = MagicMock()
- mock_response.content = [MagicMock(type="text", text="Recovered")]
- second_stream = MockStreamContextManager(mock_response)
-
- mock_anthropic.beta.messages.stream = MagicMock(side_effect=[first_stream, second_stream])
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- messages: list[BetaMessageParam] = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Create a Linear ticket"}]},
- )
- ]
-
- response = await agent.get_response(messages)
-
- assert response.content == "Recovered"
- assert mock_anthropic.beta.messages.stream.call_count == 2
- # Original user message + assistant response (no guidance message needed)
- assert len(messages) == 2
- assert messages[1]["role"] == "assistant"
-
- @pytest.mark.asyncio
- async def test_get_response_adds_invalid_json_guidance_after_second_failure(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Second consecutive invalid JSON failure should add INVALID_JSON guidance."""
- invalid_json_error = ValueError(
- "Unable to parse tool parameter JSON from model. Please retry your request or "
- "adjust your "
- 'prompt. Error: expected value at line 1 column 10. JSON: {"labels": bug}'
- )
- first_stream = MockErrorStreamContextManager(invalid_json_error)
- second_stream = MockErrorStreamContextManager(invalid_json_error)
-
- mock_response = MagicMock()
- mock_response.content = [MagicMock(type="text", text="Recovered after guidance")]
- third_stream = MockStreamContextManager(mock_response)
-
- mock_anthropic.beta.messages.stream = MagicMock(
- side_effect=[first_stream, second_stream, third_stream]
- )
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- messages: list[BetaMessageParam] = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Create a Linear ticket"}]},
- )
- ]
-
- response = await agent.get_response(messages)
-
- assert response.content == "Recovered after guidance"
- assert mock_anthropic.beta.messages.stream.call_count == 3
- # Original user message + INVALID_JSON guidance + assistant response
- assert len(messages) == 3
- retry_message = messages[1]
- assert retry_message["role"] == "user"
- retry_content = cast("list[dict[str, Any]]", retry_message["content"])
- assert "INVALID_JSON" in retry_content[0]["text"]
-
- @pytest.mark.asyncio
- async def test_get_response_does_not_retry_unrelated_value_error(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Non-tool-json ValueErrors should propagate immediately."""
- unrelated_error = ValueError("stream exploded for unrelated reason")
- mock_anthropic.beta.messages.stream = MagicMock(
- return_value=MockErrorStreamContextManager(unrelated_error)
- )
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- with pytest.raises(ValueError, match="unrelated reason"):
- await agent.get_response([])
-
- assert mock_anthropic.beta.messages.stream.call_count == 1
-
-
-class TestClaudeAgentBedrock:
- """Test ClaudeAgent class with Bedrock."""
-
- @pytest.fixture
- def bedrock_client(self) -> AsyncAnthropicBedrock:
- """Create a real AsyncAnthropicBedrock client and stub networked methods."""
- client = AsyncAnthropicBedrock(
- aws_access_key="AKIATEST",
- aws_secret_key="secret",
- aws_region="us-east-1",
- )
- # Stub the actual Bedrock call so tests are hermetic.
- client.beta.messages.create = AsyncMock()
- return client
-
- @pytest.mark.asyncio
- async def test_init(self, bedrock_client: AsyncAnthropicBedrock) -> None:
- """Test agent initialization."""
- agent = ClaudeAgent.create(
- model_client=bedrock_client,
- model="test-model-arn",
- validate_api_key=False,
- )
-
- assert agent.model_name == "Claude"
- assert agent.config.model == "test-model-arn"
- assert agent.anthropic_client == bedrock_client
-
- @pytest.mark.asyncio
- async def test_get_response_bedrock_uses_create_not_stream(
- self, bedrock_client: AsyncAnthropicBedrock
- ) -> None:
- """Bedrock path must call messages.create() (Bedrock doesn't support stream())."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = ClaudeAgent.create(
- model_client=bedrock_client,
- model="test-model-arn",
- validate_api_key=False,
- )
-
- # Enable computer tool to verify betas list includes computer-use in Bedrock mode.
- # In real usage, this beta is added by _convert_tools_for_claude when it detects
- # a computer tool. Here we manually set both flags to simulate that.
- agent.has_computer_tool = True
- agent._required_betas.add("computer-use-2025-01-24")
-
- mock_response = MagicMock()
- text_block = MagicMock()
- text_block.type = "text"
- text_block.text = "Hello from Bedrock"
- mock_response.content = [text_block]
-
- bedrock_client.beta.messages.create.return_value = mock_response # type: ignore[union-attr]
-
- messages = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
- )
- ]
- response = await agent.get_response(messages)
-
- assert response.content == "Hello from Bedrock"
- assert response.tool_calls == []
-
- # Bedrock-specific behavior: uses create() and appends assistant message directly.
- assert not hasattr(bedrock_client.beta.messages, "stream")
- bedrock_client.beta.messages.create.assert_awaited_once() # type: ignore[union-attr]
- assert len(messages) == 2
- assert messages[-1]["role"] == "assistant"
-
- # Ensure the Bedrock call shape is stable.
- _, kwargs = bedrock_client.beta.messages.create.call_args # type: ignore[union-attr]
- assert kwargs["model"] == "test-model-arn"
- assert kwargs["tool_choice"] == {"type": "auto", "disable_parallel_tool_use": True}
- assert "computer-use-2025-01-24" in kwargs["betas"]
-
- @pytest.mark.asyncio
- async def test_get_response_bedrock_missing_boto3_raises_value_error(
- self, bedrock_client: AsyncAnthropicBedrock
- ) -> None:
- """If boto3 isn't installed, Bedrock client import path should raise a clear ValueError."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = ClaudeAgent.create(
- model_client=bedrock_client,
- model="test-model-arn",
- validate_api_key=False,
- )
-
- bedrock_client.beta.messages.create.side_effect = ModuleNotFoundError("boto3") # type: ignore[union-attr]
- messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}]
-
- with pytest.raises(ValueError, match=r"boto3 is required for AWS Bedrock"):
- await agent.get_response(messages) # type: ignore
-
- def test_init_with_bedrock_client_does_not_require_anthropic_api_key(
- self, bedrock_client: AsyncAnthropicBedrock
- ) -> None:
- """Providing model_client should bypass ANTHROPIC_API_KEY validation."""
- with patch("hud.settings.settings.anthropic_api_key", None):
- agent = ClaudeAgent.create(
- model_client=bedrock_client,
- validate_api_key=False,
- )
- assert agent.anthropic_client == bedrock_client
-
-
-class TestClaudeAgentComputerTool20251124:
- """Test ClaudeAgent with the new computer_20251124 tool type."""
-
- @pytest.fixture
- def mock_anthropic(self) -> Any:
- from unittest.mock import MagicMock
-
- return MagicMock(spec=["messages", "beta"])
-
- def test_no_fine_grained_streaming_beta(self, mock_anthropic: Any) -> None:
- """Test that fine-grained-tool-streaming beta is no longer included."""
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- assert "fine-grained-tool-streaming-2025-05-14" not in agent._required_betas
-
-
-class TestClaudeAgentBetaHeader:
- """Test that the Anthropic-Beta header is handled correctly."""
-
- @pytest.fixture
- def mock_anthropic(self) -> Any:
- return MagicMock(spec=["messages", "beta"])
-
- @pytest.mark.asyncio
- async def test_empty_betas_sends_omit_not_empty_list(self, mock_anthropic: Any) -> None:
- """When no tools require a beta, betas should be Omit() not []."""
- from anthropic import Omit
-
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._required_betas = set() # No betas required
- agent._initialized = True
-
- mock_response = MagicMock()
- mock_response.content = [MagicMock(type="text", text="Hello")]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- messages = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
- )
- ]
- await agent.get_response(messages)
-
- _, kwargs = mock_anthropic.beta.messages.stream.call_args
- assert isinstance(kwargs["betas"], Omit), (
- f"Expected Omit() when no betas required, got {type(kwargs['betas'])}"
- )
-
- @pytest.mark.asyncio
- async def test_nonempty_betas_sends_list(self, mock_anthropic: Any) -> None:
- """When tools require betas, betas should be a list of strings."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = True
- agent._required_betas = {"computer-use-2025-01-24"}
- agent._initialized = True
-
- mock_response = MagicMock()
- mock_response.content = [MagicMock(type="text", text="Hello")]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- messages = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
- )
- ]
- await agent.get_response(messages)
-
- _, kwargs = mock_anthropic.beta.messages.stream.call_args
- assert isinstance(kwargs["betas"], list)
- assert "computer-use-2025-01-24" in kwargs["betas"]
-
- @pytest.mark.asyncio
- async def test_generic_tools_only_no_beta_header(self, mock_anthropic: Any) -> None:
- """Generic function tools should not produce a beta header."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- tools = [
- types.Tool(
- name="my_tool",
- description="A test tool",
- inputSchema={"type": "object", "properties": {"x": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- # Generic tools should not add any betas
- assert len(agent._required_betas) == 0
-
- mock_response = MagicMock()
- mock_response.content = [MagicMock(type="text", text="Hello")]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- from anthropic import Omit
-
- messages = [
- cast(
- "BetaMessageParam",
- {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
- )
- ]
- await agent.get_response(messages)
-
- _, kwargs = mock_anthropic.beta.messages.stream.call_args
- assert isinstance(kwargs["betas"], Omit)
-
-
-class TestCitationExtraction:
- """Test citation extraction from BetaTextBlock.citations (modern SDK path)."""
-
- @pytest.fixture
- def mock_anthropic(self) -> AsyncAnthropic:
- client = MagicMock(spec=AsyncAnthropic)
- client.beta = MagicMock()
- client.beta.messages = MagicMock()
- return client
-
- @pytest.mark.asyncio
- async def test_inline_citations_extracted_from_text_block(
- self, mock_anthropic: AsyncAnthropic
- ) -> None:
- """Text blocks with inline citations should populate result.citations."""
- cit1 = MagicMock()
- cit1.cited_text = "Revenue was $1M"
- cit1.document_index = 0
- cit1.document_title = "financials.pdf"
- cit1.start_char_index = 0
- cit1.end_char_index = 15
-
- text_block = MagicMock()
- text_block.type = "text"
- text_block.text = "Revenue was $1M last quarter."
- text_block.citations = [cit1]
-
- mock_response = MagicMock()
- mock_response.content = [text_block]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- result = await agent.get_response([])
-
- assert result.content == "Revenue was $1M last quarter."
- assert len(result.citations) == 1
- assert result.citations[0]["text"] == "Revenue was $1M"
- assert result.citations[0]["source"] == "0"
- assert result.citations[0]["title"] == "financials.pdf"
- assert result.citations[0]["start_index"] == 0
- assert result.citations[0]["end_index"] == 15
-
- @pytest.mark.asyncio
- async def test_no_citations_when_field_is_none(self, mock_anthropic: AsyncAnthropic) -> None:
- """Text blocks without citations should not populate result.citations."""
- text_block = MagicMock()
- text_block.type = "text"
- text_block.text = "No citations here."
- text_block.citations = None
-
- mock_response = MagicMock()
- mock_response.content = [text_block]
-
- mock_stream = MockStreamContextManager(mock_response)
- mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream)
-
- agent = ClaudeAgent.create(
- model_client=mock_anthropic,
- validate_api_key=False,
- )
- agent.claude_tools = []
- agent.tool_mapping = {}
- agent.has_computer_tool = False
- agent._initialized = True
-
- result = await agent.get_response([])
- assert result.citations == []
-
-
-class TestDocumentBlockCitations:
- """Test that document_to_content_block threads enable_citations."""
-
- def test_citations_disabled_by_default(self) -> None:
- from hud.agents.claude import document_to_content_block
-
- block = document_to_content_block(base64_data="AAAA")
- assert "citations" not in block
-
- def test_citations_enabled(self) -> None:
- from hud.agents.claude import document_to_content_block
-
- block = document_to_content_block(base64_data="AAAA", enable_citations=True)
- assert block["citations"] == {"enabled": True} # type: ignore[typeddict-item]
-
- @pytest.mark.asyncio
- async def test_format_tool_results_threads_citations_to_documents(self) -> None:
- """When scenario_enable_citations is True, PDF document blocks become siblings with citations.""" # noqa: E501
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = True
-
- client = MagicMock(spec=AsyncAnthropic)
- client.beta = MagicMock()
- client.beta.messages = MagicMock()
- agent = ClaudeAgent.create(
- model_client=client,
- validate_api_key=False,
- )
- agent.ctx = ctx
- agent._initialized = True
- agent.claude_tools = []
- agent.tool_mapping = {}
-
- pdf_blob = "JVBERi0xLjQ="
- tool_calls = [MCPToolCall(id="call_1", name="get_doc", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[
- types.EmbeddedResource(
- type="resource",
- resource=types.BlobResourceContents(
- uri="file:///doc.pdf", # type: ignore[arg-type]
- mimeType="application/pdf",
- blob=pdf_blob,
- ),
- )
- ],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- content_blocks = cast("list[dict[str, Any]]", messages[0]["content"])
- tool_result_block = content_blocks[0]
- assert tool_result_block["type"] == "tool_result"
- assert tool_result_block["content"], "tool_result should contain the PDF block"
- assert tool_result_block["content"][0]["type"] == "document"
- doc_block = content_blocks[1]
- assert doc_block["type"] == "document"
- assert doc_block["citations"] == {"enabled": True}
-
- @pytest.mark.asyncio
- async def test_format_tool_results_wraps_text_as_document_when_citations_enabled(self) -> None:
- """Text tool results produce a sibling document block for citations."""
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = True
-
- client = MagicMock(spec=AsyncAnthropic)
- client.beta = MagicMock()
- client.beta.messages = MagicMock()
- agent = ClaudeAgent.create(
- model_client=client,
- validate_api_key=False,
- )
- agent.ctx = ctx
- agent._initialized = True
- agent.claude_tools = []
- agent.tool_mapping = {}
-
- tool_calls = [MCPToolCall(id="call_1", name="get_sales", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Revenue was $1M last quarter.")],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- content_blocks = cast("list[dict[str, Any]]", messages[0]["content"])
- tool_result_block = content_blocks[0]
- assert tool_result_block["type"] == "tool_result"
- text_block = tool_result_block["content"][0]
- assert text_block["type"] == "text"
- assert text_block["text"] == "Revenue was $1M last quarter."
- doc_block = content_blocks[1]
- assert doc_block["type"] == "document"
- assert doc_block["source"]["type"] == "text"
- assert doc_block["source"]["data"] == "Revenue was $1M last quarter."
- assert doc_block["citations"] == {"enabled": True}
- assert doc_block["title"] == "get_sales"
-
- @pytest.mark.asyncio
- async def test_remote_task_setup_preserves_citations_for_tool_results(
- self, monkeypatch: pytest.MonkeyPatch
- ) -> None:
- """Remote task setup should propagate enable_citations into Claude formatting."""
- env = Environment("test-env")
- task = Task(env=env, scenario="remote-env:solve-task", args={})
- ctx = EvalContext.from_task(task)
-
- async def successful_get_prompt(
- _name: str, _arguments: dict[str, str] | None = None
- ) -> Any:
- return SimpleNamespace(
- messages=[
- SimpleNamespace(
- role="user",
- content=SimpleNamespace(text="Prompt"),
- )
- ],
- meta={
- "enable_citations": True,
- "returns_schema": {
- "type": "object",
- "properties": {"summary": {"type": "string"}},
- },
- },
- )
-
- monkeypatch.setattr(ctx, "get_prompt", successful_get_prompt)
- monkeypatch.setattr(ctx._router, "get_prompt_connection", lambda _name: "remote")
-
- await ctx._run_task_scenario_setup()
-
- assert ctx.scenario_enable_citations is True
- session = ctx._get_session()
- assert session is not None
- assert session.enable_citations is True
-
- client = MagicMock(spec=AsyncAnthropic)
- client.beta = MagicMock()
- client.beta.messages = MagicMock()
- agent = ClaudeAgent.create(
- model_client=client,
- validate_api_key=False,
- )
- agent.ctx = ctx
- agent._initialized = True
- agent.claude_tools = []
- agent.tool_mapping = {}
-
- tool_calls = [MCPToolCall(id="call_1", name="get_sales", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Revenue was $1M last quarter.")],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- content_blocks = cast("list[dict[str, Any]]", messages[0]["content"])
- doc_block = content_blocks[1]
-
- assert doc_block["type"] == "document"
- assert doc_block["source"]["type"] == "text"
- assert doc_block["source"]["data"] == "Revenue was $1M last quarter."
- assert doc_block["citations"] == {"enabled": True}
-
- @pytest.mark.asyncio
- async def test_format_tool_results_keeps_text_block_when_citations_disabled(self) -> None:
- """Text tool results stay as plain text blocks when citations are off."""
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = False
-
- client = MagicMock(spec=AsyncAnthropic)
- client.beta = MagicMock()
- client.beta.messages = MagicMock()
- agent = ClaudeAgent.create(
- model_client=client,
- validate_api_key=False,
- )
- agent.ctx = ctx
- agent._initialized = True
- agent.claude_tools = []
- agent.tool_mapping = {}
-
- tool_calls = [MCPToolCall(id="call_1", name="get_sales", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Revenue was $1M.")],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- content_blocks = cast("list[dict[str, Any]]", messages[0]["content"])
- tool_result_block = content_blocks[0]
- text_block = tool_result_block["content"][0]
- assert text_block["type"] == "text"
- assert text_block["text"] == "Revenue was $1M."
diff --git a/hud/agents/tests/test_gateway_resolution.py b/hud/agents/tests/test_gateway_resolution.py
new file mode 100644
index 000000000..ab016d40a
--- /dev/null
+++ b/hud/agents/tests/test_gateway_resolution.py
@@ -0,0 +1,197 @@
+"""HUD gateway agent resolution tests."""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from hud.agents import OpenAIAgent, create_agent
+from hud.agents.claude import ClaudeAgent
+from hud.agents.gateway import GatewayModelsResponse, build_gateway_client
+from hud.agents.openai_compatible import OpenAIChatAgent
+
+MODELS = GatewayModelsResponse.model_validate(
+ {
+ "models": [
+ {
+ "id": "uuid-openai",
+ "name": "GPT 5.4",
+ "model_name": "gpt-5.4",
+ "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"},
+ },
+ {
+ "id": "uuid-claude",
+ "name": "Claude Sonnet 4.6",
+ "model_name": "claude-sonnet-4-6",
+ "provider": {"name": "Anthropic", "default_sdk_agent_type": "claude"},
+ },
+ {
+ "id": "uuid-grok",
+ "name": "Grok 4.1 Fast",
+ "model_name": "grok-4-1-fast",
+ "provider": {"name": "xAI", "default_sdk_agent_type": "openai_compatible"},
+ },
+ {
+ "id": "uuid-operator",
+ "name": "Operator",
+ "model_name": "computer-use-preview",
+ "sdk_agent_type": "operator",
+ "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"},
+ },
+ {
+ "id": "uuid-gemini-cua",
+ "name": "Gemini Computer Use",
+ "model_name": "gemini-2.5-computer-use-preview",
+ "sdk_agent_type": "gemini_cua",
+ "provider": {"name": "Gemini", "default_sdk_agent_type": "gemini"},
+ },
+ ]
+ }
+).models
+
+
+def test_create_agent_resolves_gateway_model_to_provider_agent() -> None:
+ expected = MagicMock()
+ client = MagicMock()
+ with (
+ patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS),
+ patch("hud.agents.gateway.build_gateway_client", return_value=client) as build_client,
+ patch.object(OpenAIAgent, "create", return_value=expected) as create,
+ ):
+ agent = create_agent("gpt-5.4", temperature=0.5)
+
+ assert agent is expected
+ build_client.assert_called_once_with("OpenAI")
+ create.assert_called_once()
+ assert create.call_args.kwargs["model"] == "gpt-5.4"
+ assert create.call_args.kwargs["model_client"] is client
+ assert create.call_args.kwargs["temperature"] == 0.5
+
+
+@pytest.mark.parametrize("model_alias", ["uuid-openai", "GPT 5.4", "gpt-5.4"])
+def test_create_agent_resolves_gateway_model_aliases(model_alias: str) -> None:
+ expected = MagicMock()
+ with (
+ patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS),
+ patch("hud.agents.gateway.build_gateway_client", return_value=MagicMock()),
+ patch.object(OpenAIAgent, "create", return_value=expected) as create,
+ ):
+ agent = create_agent(model_alias)
+
+ assert agent is expected
+ assert create.call_args.kwargs["model"] == "gpt-5.4"
+
+
+def test_create_agent_shortcut_uses_gateway_provider() -> None:
+ expected = MagicMock()
+ with (
+ patch("hud.agents.gateway.build_gateway_client", return_value=MagicMock()) as build_client,
+ patch.object(ClaudeAgent, "create", return_value=expected),
+ ):
+ agent = create_agent("claude")
+
+ assert agent is expected
+ build_client.assert_called_once_with("anthropic")
+
+
+def test_create_agent_openai_compatible_models_use_chat_agent_client() -> None:
+ expected = MagicMock()
+ client = MagicMock()
+ with (
+ patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS),
+ patch("hud.agents.gateway.build_gateway_client", return_value=client),
+ patch.object(OpenAIChatAgent, "create", return_value=expected) as create,
+ ):
+ agent = create_agent("grok-4-1-fast")
+
+ assert agent is expected
+ assert create.call_args.kwargs["openai_client"] is client
+ assert "model_client" not in create.call_args.kwargs
+
+
+@pytest.mark.parametrize(
+ ("model", "message"),
+ [
+ ("missing-model", "not found"),
+ ("computer-use-preview", "Operator agent is no longer supported"),
+ ("gemini-2.5-computer-use-preview", "Gemini CUA agent is no longer supported"),
+ ],
+)
+def test_create_agent_rejects_unknown_or_stale_gateway_models(model: str, message: str) -> None:
+ with (
+ patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS),
+ pytest.raises(ValueError, match=message),
+ ):
+ create_agent(model)
+
+
+def test_create_agent_rejects_gateway_model_with_invalid_agent_metadata() -> None:
+ models = GatewayModelsResponse.model_validate(
+ {
+ "models": [
+ {
+ "id": "bad-model",
+ "name": "Bad Model",
+ "model_name": "bad-model",
+ "provider": {"name": "OpenAI", "default_sdk_agent_type": None},
+ }
+ ]
+ }
+ ).models
+
+ with (
+ patch("hud.agents.gateway._fetch_gateway_models", return_value=models),
+ pytest.raises(ValueError, match="invalid agent type metadata"),
+ ):
+ create_agent("bad-model")
+
+
+def test_build_gateway_client_uses_openai_compatible_client_by_default() -> None:
+ with (
+ patch("hud.agents.gateway.settings") as settings,
+ patch("hud.agents.gateway.AsyncOpenAI") as client_cls,
+ ):
+ settings.api_key = "hud-key"
+ settings.hud_gateway_url = "https://gateway.example"
+
+ build_gateway_client("together")
+
+ client_cls.assert_called_once_with(
+ api_key="hud-key",
+ base_url="https://gateway.example",
+ )
+
+
+def test_build_gateway_client_uses_anthropic_client_for_anthropic_provider() -> None:
+ with (
+ patch("hud.agents.gateway.settings") as settings,
+ patch("anthropic.AsyncAnthropic") as client_cls,
+ ):
+ settings.api_key = "hud-key"
+ settings.hud_gateway_url = "https://gateway.example"
+
+ build_gateway_client("anthropic")
+
+ client_cls.assert_called_once_with(
+ api_key="hud-key",
+ base_url="https://gateway.example",
+ )
+
+
+def test_build_gateway_client_uses_genai_client_for_gemini_provider() -> None:
+ with (
+ patch("hud.agents.gateway.settings") as settings,
+ patch("google.genai.Client") as client_cls,
+ ):
+ settings.api_key = "hud-key"
+ settings.hud_gateway_url = "https://gateway.example"
+
+ build_gateway_client("gemini")
+
+ client_cls.assert_called_once()
+ assert client_cls.call_args.kwargs["api_key"] == "PLACEHOLDER"
+ http_options = client_cls.call_args.kwargs["http_options"]
+ assert http_options.api_version == "v1beta"
+ assert http_options.base_url == "https://gateway.example"
+ assert http_options.headers == {"Authorization": "Bearer hud-key"}
diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py
deleted file mode 100644
index dcaa7f309..000000000
--- a/hud/agents/tests/test_gemini.py
+++ /dev/null
@@ -1,1064 +0,0 @@
-"""Tests for Gemini MCP Agent implementation."""
-
-from __future__ import annotations
-
-import base64
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from google import genai
-from google.genai import types as genai_types
-from mcp import types
-
-from hud.agents.gemini import GeminiAgent
-from hud.agents.gemini.tools import GeminiComputerTool as AgentGeminiComputerTool
-from hud.environment.router import ToolRouter
-from hud.eval.context import EvalContext
-from hud.types import MCPToolCall, MCPToolResult
-
-
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing."""
-
- def __init__(self, tools: list[types.Tool] | None = None) -> None:
- # Core attributes
- self.prompt = "Test prompt"
- self._tools = tools or []
- self._submitted: str | dict[str, Any] | None = None
- self.reward: float | None = None
-
- # Environment attributes
- self._router = ToolRouter()
-
- # EvalContext attributes
- self._task = None
- self.trace_id = "test-trace-id"
- self.eval_name = "test-eval"
- self.job_id: str | None = None
- self.group_id: str | None = None
- self.index = 0
- self.variants: dict[str, Any] = {}
- self.answer: str | dict[str, Any] | None = None
- self.system_prompt: str | None = None
- self.error: BaseException | None = None
- self.scenario_enable_citations: bool = False
- self.scenario_returns_schema: dict[str, Any] | None = None
- self.metadata: dict[str, Any] = {}
- self.results: list[Any] = []
- self._is_summary = False
-
- def as_tools(self) -> list[types.Tool]:
- return self._tools
-
- @property
- def has_scenario(self) -> bool:
- return False
-
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
-
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- return MCPToolResult(
- content=[types.TextContent(type="text", text="ok")],
- isError=False,
- )
-
- async def submit(self, answer: str | dict[str, Any]) -> None:
- self._submitted = answer
-
-
-class TestGeminiAgent:
- """Test GeminiAgent base class."""
-
- @pytest.fixture
- def mock_gemini_client(self) -> MagicMock:
- """Create a stub Gemini client."""
- client = MagicMock(spec=genai.Client)
- client.api_key = "test_key"
- client.models = MagicMock()
- client.models.list = MagicMock(return_value=iter([]))
- client.models.generate_content = MagicMock()
- # Set up async interface (aio.models.generate_content)
- client.aio = MagicMock()
- client.aio.models = MagicMock()
- client.aio.models.generate_content = AsyncMock()
- return client
-
- @pytest.mark.asyncio
- async def test_init(self, mock_gemini_client: MagicMock) -> None:
- """Test agent initialization."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- model="gemini-2.5-flash",
- validate_api_key=False,
- )
-
- assert agent.model_name == "Gemini"
- assert agent.config.model == "gemini-2.5-flash"
- assert agent.gemini_client == mock_gemini_client
-
- @pytest.mark.asyncio
- async def test_init_without_model_client(self) -> None:
- """Test agent initialization without model client."""
- with (
- patch("hud.settings.settings.gemini_api_key", "test_key"),
- patch("hud.agents.gemini.agent.genai.Client") as mock_client_class,
- ):
- mock_client = MagicMock()
- mock_client.api_key = "test_key"
- mock_client.models = MagicMock()
- mock_client.models.list = MagicMock(return_value=iter([]))
- mock_client_class.return_value = mock_client
-
- agent = GeminiAgent.create(
- model="gemini-2.5-flash",
- validate_api_key=False,
- )
-
- assert agent.gemini_client is not None
-
- @pytest.mark.asyncio
- async def test_format_blocks_text_only(self, mock_gemini_client: MagicMock) -> None:
- """Test formatting text content blocks."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- blocks: list[types.ContentBlock] = [
- types.TextContent(type="text", text="Hello, world!"),
- types.TextContent(type="text", text="How are you?"),
- ]
-
- messages = await agent.format_blocks(blocks)
- assert len(messages) == 1
- assert messages[0].role == "user"
- assert messages[0].parts is not None
- assert len(messages[0].parts) == 2
-
- @pytest.mark.asyncio
- async def test_format_blocks_with_image(self, mock_gemini_client: MagicMock) -> None:
- """Test formatting image content blocks."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- # Create a tiny valid base64 PNG
- png_data = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode()
-
- blocks: list[types.ContentBlock] = [
- types.TextContent(type="text", text="Look at this:"),
- types.ImageContent(type="image", data=png_data, mimeType="image/png"),
- ]
-
- messages = await agent.format_blocks(blocks)
- assert len(messages) == 1
- assert messages[0].parts is not None
- assert len(messages[0].parts) == 2
-
- @pytest.mark.asyncio
- async def test_format_tool_results(self, mock_gemini_client: MagicMock) -> None:
- """Test formatting tool results."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Tool output")],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- assert len(messages) == 1
- assert messages[0].role == "user"
-
- @pytest.mark.asyncio
- async def test_get_system_messages(self, mock_gemini_client: MagicMock) -> None:
- """Test that system messages return empty (Gemini uses system_instruction)."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- system_prompt="You are a helpful assistant.",
- validate_api_key=False,
- )
-
- messages = await agent.get_system_messages()
- # Gemini doesn't use system messages in the message list
- assert messages == []
-
- @pytest.mark.asyncio
- async def test_get_response_text_only(self, mock_gemini_client: MagicMock) -> None:
- """Test getting text-only response."""
- # Disable telemetry for this test
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
- # Set up agent as initialized (no tools needed for this test)
- agent.gemini_tools = []
- agent._initialized = True
-
- # Mock the API response with text only
- mock_response = MagicMock()
- mock_candidate = MagicMock()
-
- text_part = MagicMock()
- text_part.text = "Task completed successfully"
- text_part.function_call = None
-
- mock_candidate.content = MagicMock()
- mock_candidate.content.parts = [text_part]
-
- mock_response.candidates = [mock_candidate]
-
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response)
-
- messages = [
- genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")])
- ]
- response = await agent.get_response(messages)
-
- assert response.content == "Task completed successfully"
- assert response.tool_calls == []
- assert response.done is True
-
- @pytest.mark.asyncio
- async def test_get_response_raises_on_no_candidates(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """A no-candidate Gemini response should fail loudly, not submit an empty answer."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- model="gemini-3-flash-preview",
- validate_api_key=False,
- )
- agent.gemini_tools = []
- agent._initialized = True
-
- mock_response = MagicMock()
- mock_response.candidates = []
- mock_response.prompt_feedback = "blocked"
- mock_response.usage_metadata = None
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response)
-
- messages = [
- genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")])
- ]
-
- with pytest.raises(RuntimeError, match="returned no candidates"):
- await agent.get_response(messages)
-
- @pytest.mark.asyncio
- async def test_get_response_with_thinking(self, mock_gemini_client: MagicMock) -> None:
- """Test getting response with thinking content."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
- # Set up agent as initialized (no tools needed for this test)
- agent.gemini_tools = []
- agent._initialized = True
-
- mock_response = MagicMock()
- mock_candidate = MagicMock()
-
- thinking_part = MagicMock()
- thinking_part.text = "Let me reason through this..."
- thinking_part.function_call = None
- thinking_part.thought = True
-
- text_part = MagicMock()
- text_part.text = "Here is my answer"
- text_part.function_call = None
- text_part.thought = False
-
- mock_candidate.content = MagicMock()
- mock_candidate.content.parts = [thinking_part, text_part]
-
- mock_response.candidates = [mock_candidate]
-
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response)
-
- messages = [
- genai_types.Content(
- role="user", parts=[genai_types.Part.from_text(text="Hard question")]
- )
- ]
- response = await agent.get_response(messages)
-
- assert response.content == "Here is my answer"
- assert response.reasoning == "Let me reason through this..."
-
- @pytest.mark.asyncio
- async def test_get_response_passes_thinking_config(self, mock_gemini_client: MagicMock) -> None:
- """Gemini 3 thinking options should be passed to GenerateContentConfig."""
- with patch("hud.settings.settings.telemetry_enabled", False):
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- model="gemini-3-flash-preview",
- validate_api_key=False,
- thinking_level="high",
- include_thoughts=True,
- )
- agent.gemini_tools = []
- agent._initialized = True
-
- mock_response = MagicMock()
- mock_candidate = MagicMock()
- text_part = MagicMock()
- text_part.text = "Answer"
- text_part.function_call = None
- text_part.thought = False
- mock_candidate.content = MagicMock()
- mock_candidate.content.parts = [text_part]
- mock_response.candidates = [mock_candidate]
-
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response)
-
- messages = [
- genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Hi")])
- ]
- await agent.get_response(messages)
-
- config = mock_gemini_client.aio.models.generate_content.call_args.kwargs["config"]
- assert config.thinking_config is not None
- assert config.thinking_config.include_thoughts is True
- assert config.thinking_config.thinking_level.value == "HIGH"
-
- @pytest.mark.asyncio
- async def test_convert_tools_for_gemini(self, mock_gemini_client: MagicMock) -> None:
- """Test converting MCP tools to Gemini format."""
- tools = [
- types.Tool(
- name="my_tool",
- description="A test tool",
- inputSchema={"type": "object", "properties": {"x": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- # Check that tools were converted
- assert len(agent.gemini_tools) == 1
- # Gemini tools have function_declarations - cast to genai Tool type
- gemini_tool = agent.gemini_tools[0]
- assert isinstance(gemini_tool, genai_types.Tool)
- assert gemini_tool.function_declarations is not None
- assert gemini_tool.function_declarations[0].name == "my_tool"
-
- @pytest.mark.asyncio
- async def test_regular_agent_uses_native_computer_use(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """GeminiAgent should register GeminiComputerTool as native Computer Use."""
- computer_tool = types.Tool(
- name="gemini_computer",
- description="Control computer with mouse, keyboard, and screenshots",
- inputSchema={"type": "object", "properties": {}},
- )
- computer_tool.meta = {
- "native_tools": {
- "gemini": {
- "api_type": "computer_use",
- "api_name": "gemini_computer",
- "role": "computer",
- "supported_models": ["gemini-3-flash-preview"],
- }
- }
- }
- tools = [
- computer_tool,
- ]
- ctx = MockEvalContext(tools=tools)
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- model="gemini-3-flash-preview",
- validate_api_key=False,
- excluded_predefined_functions=["drag_and_drop"],
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent._computer_tool_name == "computer_use"
- assert agent._gemini_native_tools["computer_use"].env_tool_name == "gemini_computer"
- assert "gemini_computer" not in agent._gemini_native_tools
- assert len(agent.gemini_tools) == 1
- computer_tool = agent.gemini_tools[0]
- assert isinstance(computer_tool, genai_types.Tool)
- assert computer_tool.computer_use is not None
- assert computer_tool.computer_use.excluded_predefined_functions == ["drag_and_drop"]
-
- @pytest.mark.asyncio
- async def test_computer_use_excludes_colliding_generic_tool_names(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """Generic tools named like predefined actions should not be hijacked."""
- computer_tool = types.Tool(
- name="gemini_computer",
- description="Control computer with mouse, keyboard, and screenshots",
- inputSchema={"type": "object", "properties": {}},
- )
- computer_tool.meta = {
- "native_tools": {
- "gemini": {
- "api_type": "computer_use",
- "api_name": "gemini_computer",
- "role": "computer",
- "supported_models": ["gemini-3-flash-preview"],
- }
- }
- }
- navigate_tool = types.Tool(
- name="navigate",
- description="A non-computer navigation helper",
- inputSchema={"type": "object", "properties": {"url": {"type": "string"}}},
- )
- ctx = MockEvalContext(tools=[computer_tool, navigate_tool])
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- model="gemini-3-flash-preview",
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- computer_use_tool = next(
- tool for tool in agent.gemini_tools if getattr(tool, "computer_use", None) is not None
- )
- computer_use = getattr(computer_use_tool, "computer_use", None)
- assert computer_use is not None
- assert "navigate" in (computer_use.excluded_predefined_functions or [])
- function_call = MagicMock()
- function_call.name = "navigate"
- function_call.args = {"url": "https://example.com"}
- tool_call = agent._extract_tool_call(MagicMock(function_call=function_call))
- assert tool_call is not None
- assert tool_call.name == "navigate"
- assert tool_call.arguments == {"url": "https://example.com"}
-
- @pytest.mark.asyncio
- async def test_agent_owns_gemini_cli_tool_surface(self, mock_gemini_client: MagicMock) -> None:
- """GeminiAgent exposes Gemini-shaped tools backed by generic env primitives."""
- tools = [
- types.Tool(name="bash", description="Run shell", inputSchema={"type": "object"}),
- types.Tool(name="edit", description="Edit files", inputSchema={"type": "object"}),
- types.Tool(name="read", description="Read files", inputSchema={"type": "object"}),
- types.Tool(name="grep", description="Search files", inputSchema={"type": "object"}),
- types.Tool(name="glob", description="Find files", inputSchema={"type": "object"}),
- types.Tool(name="list", description="List files", inputSchema={"type": "object"}),
- types.Tool(name="memory", description="Remember facts", inputSchema={"type": "object"}),
- ]
- ctx = MockEvalContext(tools=tools)
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
- agent.console.info = MagicMock()
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- declaration_names = {
- declaration.name
- for tool in agent.gemini_tools
- for declaration in (getattr(tool, "function_declarations", None) or [])
- }
- assert {
- "run_shell_command",
- "replace",
- "write_file",
- "read_file",
- "grep_search",
- "glob",
- "list_directory",
- "save_memory",
- } <= declaration_names
- assert agent._gemini_native_tools["run_shell_command"].env_tool_name == "bash"
- assert agent._gemini_native_tools["replace"].env_tool_name == "edit"
- assert agent._gemini_native_tools["write_file"].env_tool_name == "edit"
- assert agent._gemini_native_tools["read_file"].env_tool_name == "read"
- assert agent._gemini_native_tools["grep_search"].env_tool_name == "grep"
- assert agent._gemini_native_tools["glob"].env_tool_name == "glob"
- assert agent._gemini_native_tools["list_directory"].env_tool_name == "list"
- assert agent._gemini_native_tools["save_memory"].env_tool_name == "memory"
- declarations = {
- declaration.name: declaration
- for tool in agent.gemini_tools
- for declaration in (getattr(tool, "function_declarations", None) or [])
- }
- assert "allow_multiple" not in declarations["replace"].parameters_json_schema["properties"]
- assert (
- "exclude_pattern"
- not in declarations["grep_search"].parameters_json_schema["properties"]
- )
- assert "names_only" not in declarations["grep_search"].parameters_json_schema["properties"]
- assert "respect_git_ignore" not in declarations["glob"].parameters_json_schema["properties"]
- agent.console.info.assert_called_with(
- "Agent initialized with 8 tools: "
- "glob, grep_search, list_directory, read_file, replace, run_shell_command, "
- "save_memory, write_file"
- )
-
- @pytest.mark.asyncio
- async def test_gemini_legacy_env_tools_activate_harness_tools(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """Old Gemini env constructors register canonical names for harness activation."""
- from hud.tools import (
- GeminiGlobTool,
- GeminiListTool,
- GeminiMemoryTool,
- GeminiReadTool,
- GeminiSearchTool,
- )
-
- env_tools = [
- GeminiReadTool(),
- GeminiSearchTool(),
- GeminiGlobTool(),
- GeminiListTool(),
- GeminiMemoryTool(),
- ]
- tools = [
- types.Tool(name=tool.name, description=tool.description, inputSchema={"type": "object"})
- for tool in env_tools
- ]
- ctx = MockEvalContext(tools=tools)
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert agent._gemini_native_tools["read_file"].env_tool_name == "read"
- assert agent._gemini_native_tools["grep_search"].env_tool_name == "grep"
- assert agent._gemini_native_tools["glob"].env_tool_name == "glob"
- assert agent._gemini_native_tools["list_directory"].env_tool_name == "list"
- assert agent._gemini_native_tools["save_memory"].env_tool_name == "memory"
-
- def test_regular_agent_routes_computer_use_function_call(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """Gemini Computer Use calls should route to the MCP computer tool."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
- agent._computer_tool_name = "computer_use"
-
- function_call = MagicMock()
- function_call.name = "click_at"
- function_call.args = {"x": 500, "y": 250, "safety_decision": {"decision": "allowed"}}
- part = MagicMock(function_call=function_call)
-
- tool_call = agent._extract_tool_call(part)
-
- assert tool_call is not None
- assert tool_call.name == "computer_use"
- assert tool_call.arguments == {
- "action": "click_at",
- "safety_decision": {"decision": "allowed"},
- "x": 500,
- "y": 250,
- }
- assert getattr(tool_call, "gemini_name") == "click_at"
-
- def test_gemini_computer_drag_insets_edge_coordinates(self) -> None:
- """Gemini drag endpoints should be inset before calling the environment tool."""
- spec = AgentGeminiComputerTool.default_spec("gemini-3-flash-preview")
- assert spec is not None
- tool = AgentGeminiComputerTool(env_tool_name="computer", spec=spec)
-
- calls = tool._env_calls(
- "drag_and_drop",
- {"x": 0, "y": 500, "destination_x": 1000, "destination_y": 500},
- )
-
- assert calls == [
- {
- "action": "drag",
- "path": [
- {"x": 25, "y": 500},
- {"x": 975, "y": 500},
- ],
- }
- ]
-
- def test_gemini_computer_normalizes_keys_and_optional_type_coordinates(self) -> None:
- """Gemini key strings should map cleanly to the environment press contract."""
- spec = AgentGeminiComputerTool.default_spec("gemini-3-flash-preview")
- assert spec is not None
- tool = AgentGeminiComputerTool(env_tool_name="computer", spec=spec)
-
- assert tool._env_calls("key_combination", {"keys": "Control+A"}) == [
- {"action": "press", "keys": ["ctrl", "a"]}
- ]
- assert tool._env_calls("type_text_at", {"text": "hello", "clear_before_typing": False}) == [
- {"action": "write", "text": "hello", "enter_after": False}
- ]
-
- @pytest.mark.asyncio
- async def test_gemini_computer_blocks_confirmation_required_actions(self) -> None:
- """Gemini require_confirmation actions need HITL before execution."""
- spec = AgentGeminiComputerTool.default_spec("gemini-3-flash-preview")
- assert spec is not None
- tool = AgentGeminiComputerTool(env_tool_name="computer", spec=spec)
- calls: list[MCPToolCall] = []
-
- async def call_tool(call: MCPToolCall) -> MCPToolResult:
- calls.append(call)
- return MCPToolResult(
- content=[types.TextContent(type="text", text="executed")],
- isError=False,
- )
-
- result = await tool.execute(
- call_tool,
- {
- "action": "click_at",
- "x": 10,
- "y": 20,
- "safety_decision": {"decision": "require_confirmation"},
- },
- )
-
- assert result.isError is False
- assert isinstance(result.content[0], types.TextContent)
- assert result.content[0].text.startswith("__GEMINI_SAFETY_BLOCKED__:")
- assert calls == []
-
- @pytest.mark.asyncio
- async def test_regular_agent_formats_computer_use_results(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """GeminiAgent should return URL and screenshot parts for native computer use."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
- agent._computer_tool_name = "computer_use"
- screenshot = base64.b64encode(b"png bytes").decode()
- tool_calls = [
- MCPToolCall(
- name="computer_use",
- arguments={"action": "click_at", "safety_decision": {"decision": "allowed"}},
- gemini_name="click_at", # type: ignore[arg-type]
- )
- ]
- tool_results = [
- MCPToolResult(
- content=[
- types.TextContent(type="text", text="__URL__:https://example.com"),
- types.ImageContent(type="image", data=screenshot, mimeType="image/png"),
- ],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
-
- parts = messages[0].parts
- assert parts is not None
- function_response = parts[0].function_response
- assert function_response is not None
- assert function_response.name == "click_at"
- response = function_response.response
- assert response is not None
- assert response["url"] == "https://example.com"
- assert response["safety_acknowledgement"] is True
- assert function_response.parts is not None
- inline_data = function_response.parts[0].inline_data
- assert inline_data is not None
- assert inline_data.mime_type == "image/png"
-
- @pytest.mark.asyncio
- async def test_regular_agent_formats_blocked_computer_use_results(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """Blocked Gemini safety actions should not be reported as tool errors."""
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
- agent._computer_tool_name = "computer_use"
- tool_calls = [
- MCPToolCall(
- name="computer_use",
- arguments={
- "action": "click_at",
- "safety_decision": {"decision": "require_confirmation"},
- },
- gemini_name="click_at", # type: ignore[arg-type]
- )
- ]
- tool_results = [
- MCPToolResult(
- content=[
- types.TextContent(
- type="text",
- text=(
- "__GEMINI_SAFETY_BLOCKED__:Gemini Computer Use action requires "
- "user confirmation before execution."
- ),
- ),
- ],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
-
- parts = messages[0].parts
- assert parts is not None
- function_response = parts[0].function_response
- assert function_response is not None
- response = function_response.response
- assert response is not None
- assert response["blocked"] is True
- assert "success" not in response
- assert response["url"] == "about:blank"
- assert "safety_acknowledgement" not in response
-
-
-class TestGeminiToolConversion:
- """Tests for tool conversion to Gemini format."""
-
- @pytest.fixture
- def mock_gemini_client(self) -> MagicMock:
- """Create a stub Gemini client."""
- client = MagicMock(spec=genai.Client)
- client.api_key = "test_key"
- client.models = MagicMock()
- client.models.list = MagicMock(return_value=iter([]))
- # Set up async interface
- client.aio = MagicMock()
- client.aio.models = MagicMock()
- client.aio.models.generate_content = AsyncMock()
- return client
-
- @pytest.mark.asyncio
- async def test_tool_with_properties(self, mock_gemini_client: MagicMock) -> None:
- """Test tool with input properties."""
- tools = [
- types.Tool(
- name="search",
- description="Search the web",
- inputSchema={
- "type": "object",
- "properties": {
- "query": {"type": "string", "description": "Search query"},
- "limit": {"type": "integer", "description": "Max results"},
- },
- "required": ["query"],
- },
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert len(agent.gemini_tools) == 1
- gemini_tool = agent.gemini_tools[0]
- # Gemini tools have function_declarations - cast to genai Tool type
- assert isinstance(gemini_tool, genai_types.Tool)
- assert gemini_tool.function_declarations is not None
- assert gemini_tool.function_declarations[0].name == "search"
- assert gemini_tool.function_declarations[0].parameters_json_schema is not None
-
- @pytest.mark.asyncio
- async def test_tool_without_schema(self, mock_gemini_client: MagicMock) -> None:
- """Test tool without description raises error."""
- # Create a tool with inputSchema but no description
- tools = [
- types.Tool(
- name="incomplete",
- description=None,
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = GeminiAgent.create(
- model_client=mock_gemini_client,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- with pytest.raises(ValueError, match="requires both a description"):
- await agent._initialize_from_ctx(ctx)
-
-
-class TestGeminiCitations:
- """Tests for Gemini grounding citation extraction."""
-
- @pytest.fixture
- def mock_gemini_client(self) -> MagicMock:
- client = MagicMock(spec=genai.Client)
- client.aio = MagicMock()
- client.aio.models = MagicMock()
- client.aio.models.generate_content = AsyncMock()
- return client
-
- def _make_agent(self, client: MagicMock) -> GeminiAgent:
- agent = GeminiAgent.create(model_client=client, validate_api_key=False)
- agent.gemini_tools = []
- agent._initialized = True
- return agent
-
- def _text_candidate(self, text: str = "answer") -> MagicMock:
- candidate = MagicMock()
- part = MagicMock()
- part.text = text
- part.function_call = None
- part.thought = False
- candidate.content = MagicMock()
- candidate.content.parts = [part]
- return candidate
-
- @pytest.mark.asyncio
- async def test_no_grounding_metadata(self, mock_gemini_client: MagicMock) -> None:
- """No citations when groundingMetadata is absent."""
- agent = self._make_agent(mock_gemini_client)
- candidate = self._text_candidate()
- candidate.grounding_metadata = None
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- result = await agent.get_response([])
- assert result.citations == []
-
- @pytest.mark.asyncio
- async def test_grounding_chunks_only(self, mock_gemini_client: MagicMock) -> None:
- """Chunks without supports produce citations with source but no anchoring."""
- agent = self._make_agent(mock_gemini_client)
- candidate = self._text_candidate()
-
- chunk = MagicMock()
- chunk.web = MagicMock()
- chunk.web.uri = "https://example.com"
- chunk.web.title = "Example"
-
- grounding_meta = MagicMock()
- grounding_meta.grounding_chunks = [chunk]
- grounding_meta.grounding_supports = []
- candidate.grounding_metadata = grounding_meta
-
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- result = await agent.get_response([])
- assert len(result.citations) == 1
- assert result.citations[0]["source"] == "https://example.com"
- assert result.citations[0]["title"] == "Example"
- assert result.citations[0]["text"] == ""
-
- @pytest.mark.asyncio
- async def test_grounding_supports_with_anchoring(self, mock_gemini_client: MagicMock) -> None:
- """Supports produce citations with start_index/end_index from segments."""
- agent = self._make_agent(mock_gemini_client)
- candidate = self._text_candidate("The sky is blue because of Rayleigh scattering.")
-
- chunk = MagicMock()
- chunk.web = MagicMock()
- chunk.web.uri = "https://physics.org/scattering"
- chunk.web.title = "Scattering"
-
- support = MagicMock()
- support.segment = MagicMock()
- support.segment.text = "Rayleigh scattering"
- support.segment.start_index = 28
- support.segment.end_index = 47
- support.grounding_chunk_indices = [0]
-
- grounding_meta = MagicMock()
- grounding_meta.grounding_chunks = [chunk]
- grounding_meta.grounding_supports = [support]
- candidate.grounding_metadata = grounding_meta
-
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- result = await agent.get_response([])
- assert len(result.citations) == 1
- cit = result.citations[0]
- assert cit["type"] == "grounding"
- assert cit["text"] == "Rayleigh scattering"
- assert cit["source"] == "https://physics.org/scattering"
- assert cit["start_index"] == 28
- assert cit["end_index"] == 47
-
- @pytest.mark.asyncio
- async def test_multiple_supports_and_chunks(self, mock_gemini_client: MagicMock) -> None:
- """Multiple supports across multiple chunks produce the right citations."""
- agent = self._make_agent(mock_gemini_client)
- candidate = self._text_candidate()
-
- chunk_a = MagicMock()
- chunk_a.web = MagicMock()
- chunk_a.web.uri = "https://a.com"
- chunk_a.web.title = "A"
-
- chunk_b = MagicMock()
- chunk_b.web = MagicMock()
- chunk_b.web.uri = "https://b.com"
- chunk_b.web.title = "B"
-
- support1 = MagicMock()
- support1.segment = MagicMock()
- support1.segment.text = "fact one"
- support1.segment.start_index = 0
- support1.segment.end_index = 8
- support1.grounding_chunk_indices = [0]
-
- support2 = MagicMock()
- support2.segment = MagicMock()
- support2.segment.text = "fact two"
- support2.segment.start_index = 10
- support2.segment.end_index = 18
- support2.grounding_chunk_indices = [1]
-
- grounding_meta = MagicMock()
- grounding_meta.grounding_chunks = [chunk_a, chunk_b]
- grounding_meta.grounding_supports = [support1, support2]
- candidate.grounding_metadata = grounding_meta
-
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- result = await agent.get_response([])
- assert len(result.citations) == 2
- assert result.citations[0]["source"] == "https://a.com"
- assert result.citations[0]["text"] == "fact one"
- assert result.citations[1]["source"] == "https://b.com"
- assert result.citations[1]["text"] == "fact two"
-
-
-class TestGeminiCitationInjection:
- """Test that enable_citations injects google_search when missing."""
-
- @pytest.fixture
- def mock_gemini_client(self) -> MagicMock:
- client = MagicMock(spec=genai.Client)
- client.aio = MagicMock()
- client.aio.models = MagicMock()
- client.aio.models.generate_content = AsyncMock()
- return client
-
- def _make_agent(self, client: MagicMock) -> GeminiAgent:
- agent = GeminiAgent.create(model_client=client, validate_api_key=False)
- agent.gemini_tools = []
- agent._gemini_to_mcp_tool_map = {}
- agent._initialized = True
- return agent
-
- @pytest.mark.asyncio
- async def test_google_search_injected_when_citations_enabled(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """When scenario_enable_citations=True and no google_search tool, inject one."""
- agent = self._make_agent(mock_gemini_client)
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = True
- agent.ctx = ctx
-
- candidate = MagicMock()
- candidate.content = MagicMock()
- candidate.content.parts = [MagicMock(function_call=None, thought=False, text="Hi")]
- candidate.grounding_metadata = None
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- await agent.get_response([])
-
- call_kwargs = mock_gemini_client.aio.models.generate_content.call_args
- config = call_kwargs.kwargs["config"]
- tools_passed = config.tools
- assert any(
- isinstance(t, genai_types.Tool) and t.google_search is not None for t in tools_passed
- )
-
- @pytest.mark.asyncio
- async def test_no_duplicate_google_search_when_already_present(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """When google_search tool already exists, don't add a second one."""
- agent = self._make_agent(mock_gemini_client)
- existing_search_tool = genai_types.Tool(google_search=genai_types.GoogleSearch())
- agent.gemini_tools = [existing_search_tool]
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = True
- agent.ctx = ctx
-
- candidate = MagicMock()
- candidate.content = MagicMock()
- candidate.content.parts = [MagicMock(function_call=None, thought=False, text="Hi")]
- candidate.grounding_metadata = None
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- await agent.get_response([])
-
- call_kwargs = mock_gemini_client.aio.models.generate_content.call_args
- config = call_kwargs.kwargs["config"]
- tools_passed = config.tools
- search_count = sum(
- 1
- for t in tools_passed
- if isinstance(t, genai_types.Tool) and t.google_search is not None
- )
- assert search_count == 1
-
- @pytest.mark.asyncio
- async def test_no_injection_when_citations_disabled(
- self, mock_gemini_client: MagicMock
- ) -> None:
- """When scenario_enable_citations=False, no google_search is injected."""
- agent = self._make_agent(mock_gemini_client)
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = False
- agent.ctx = ctx
-
- candidate = MagicMock()
- candidate.content = MagicMock()
- candidate.content.parts = [MagicMock(function_call=None, thought=False, text="Hi")]
- candidate.grounding_metadata = None
- resp = MagicMock()
- resp.candidates = [candidate]
- mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp)
-
- await agent.get_response([])
-
- call_kwargs = mock_gemini_client.aio.models.generate_content.call_args
- config = call_kwargs.kwargs["config"]
- tools_passed = config.tools
- assert not any(
- isinstance(t, genai_types.Tool) and t.google_search is not None for t in tools_passed
- )
diff --git a/hud/agents/tests/test_hosted_tools.py b/hud/agents/tests/test_hosted_tools.py
index deee000f3..ce4d76aea 100644
--- a/hud/agents/tests/test_hosted_tools.py
+++ b/hud/agents/tests/test_hosted_tools.py
@@ -1,48 +1,137 @@
+"""Provider-hosted tool configuration tests."""
+
from __future__ import annotations
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
import pytest
+from google.genai import types as genai_types
+from openai.types.responses import ResponseOutputMessage, ResponseOutputText
-from hud.agents.base import CategorizedTools
+from hud.agents.base import AgentContext
from hud.agents.claude import (
ClaudeAgent,
ClaudeToolSearchTool,
ClaudeWebFetchTool,
ClaudeWebSearchTool,
)
-from hud.agents.gemini import (
- GeminiAgent,
- GeminiCodeExecutionTool,
- GeminiGoogleSearchTool,
- GeminiUrlContextTool,
-)
-from hud.agents.openai import (
- OpenAIAgent,
- OpenAICodeInterpreterTool,
- OpenAIToolSearchTool,
-)
+from hud.agents.gemini import GeminiAgent, GeminiCodeExecutionTool, GeminiGoogleSearchTool
+from hud.agents.openai import OpenAIAgent, OpenAICodeInterpreterTool, OpenAIToolSearchTool
+from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt
-def test_claude_agent_configured_hosted_tools() -> None:
- agent = ClaudeAgent.create(
- model_client=object(),
- hosted_tools=[
- ClaudeWebSearchTool(max_uses=3),
- ClaudeWebFetchTool(citations_enabled=True),
- ClaudeToolSearchTool(threshold=7),
+def _message_response(text: str) -> SimpleNamespace:
+ return SimpleNamespace(
+ id="resp",
+ output=[
+ ResponseOutputMessage(
+ id="msg",
+ type="message",
+ role="assistant",
+ status="completed",
+ content=[ResponseOutputText(type="output_text", text=text, annotations=[])],
+ )
],
)
- agent._available_tools = []
- agent._categorized_tools = CategorizedTools()
- agent._convert_tools_for_claude()
- assert {tool.get("type") for tool in agent.claude_tools if isinstance(tool, dict)} == {
- "web_search_20250305",
- "web_fetch_20250910",
- "tool_search_tool_bm25_20251119",
- }
- assert agent._required_betas == set()
- assert agent._tool_search_threshold == 7
+class Stream:
+ def __init__(self, text: str) -> None:
+ block = MagicMock()
+ block.type = "text"
+ block.text = text
+ block.citations = None
+ self.response = MagicMock()
+ self.response.content = [block]
+
+ async def __aenter__(self) -> Stream:
+ return self
+
+ async def __aexit__(self, *args: object) -> bool:
+ return False
+
+ def __aiter__(self) -> Stream:
+ return self
+
+ async def __anext__(self) -> None:
+ raise StopAsyncIteration
+
+ async def get_final_message(self) -> MagicMock:
+ return self.response
+
+
+def _gemini_response(text: str) -> genai_types.GenerateContentResponse:
+ return genai_types.GenerateContentResponse(
+ candidates=[
+ genai_types.Candidate(
+ content=genai_types.Content(role="model", parts=[genai_types.Part(text=text)])
+ )
+ ]
+ )
+
+
+def _gemini_client(response: genai_types.GenerateContentResponse) -> MagicMock:
+ client = MagicMock()
+ client.aio = MagicMock()
+ client.aio.models = MagicMock()
+ client.aio.models.generate_content = AsyncMock(return_value=response)
+ return client
+
+
+def test_openai_hosted_tools_are_model_gated() -> None:
+ tool = OpenAICodeInterpreterTool(container={"type": "auto"})
+
+ assert tool.supports_model("gpt-5.4")
+ assert not tool.supports_model("gpt-4.1")
+
+
+@pytest.mark.asyncio
+async def test_supported_openai_hosted_tool_is_sent_to_provider() -> None:
+ client = SimpleNamespace(
+ responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("done")))
+ )
+ agent = OpenAIAgent.create(
+ model="gpt-5.4",
+ model_client=client,
+ validate_api_key=False,
+ hosted_tools=[OpenAICodeInterpreterTool(container={"type": "auto"})],
+ )
+
+ result = await agent.run(
+ AgentContext(
+ messages=[text_prompt("use hosted code")],
+ tool_client=RecordingToolEnvironment().client,
+ )
+ )
+
+ assert result.content == "done"
+ tools = client.responses.create.await_args.kwargs["tools"]
+ assert any(tool["type"] == "code_interpreter" for tool in tools)
+
+
+@pytest.mark.asyncio
+async def test_unsupported_openai_hosted_tool_is_not_sent_to_provider() -> None:
+ client = SimpleNamespace(
+ responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("done")))
+ )
+ agent = OpenAIAgent.create(
+ model="gpt-4.1",
+ model_client=client,
+ validate_api_key=False,
+ hosted_tools=[OpenAICodeInterpreterTool(container={"type": "auto"})],
+ )
+
+ result = await agent.run(
+ AgentContext(
+ messages=[text_prompt("use hosted code")],
+ tool_client=RecordingToolEnvironment().client,
+ )
+ )
+
+ assert result.content == "done"
+ tools = client.responses.create.await_args.kwargs["tools"]
+ assert not isinstance(tools, list)
def test_claude_hosted_domain_filters_are_mutually_exclusive() -> None:
@@ -59,68 +148,126 @@ def test_claude_hosted_domain_filters_are_mutually_exclusive() -> None:
).to_params()
-def test_openai_agent_configured_hosted_tools() -> None:
+def test_gemini_google_search_rejects_unsupported_dynamic_threshold() -> None:
+ with pytest.raises(ValueError, match="dynamic_threshold"):
+ GeminiGoogleSearchTool(dynamic_threshold=0.2).to_params()
+
+
+@pytest.mark.asyncio
+async def test_openai_tool_search_threshold_defers_function_loading() -> None:
+ client = SimpleNamespace(
+ responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("done")))
+ )
agent = OpenAIAgent.create(
- model_client=object(),
- hosted_tools=[
- OpenAICodeInterpreterTool(container={"type": "auto"}),
- OpenAIToolSearchTool(threshold=4),
- ],
+ model="gpt-5.4",
+ model_client=client,
+ validate_api_key=False,
+ hosted_tools=[OpenAIToolSearchTool(threshold=1)],
)
- agent._available_tools = []
- agent._categorized_tools = CategorizedTools()
+ environment = RecordingToolEnvironment([mcp_tool("first"), mcp_tool("second")])
- agent._convert_tools_for_openai()
+ result = await agent.run(
+ AgentContext(
+ messages=[text_prompt("use tools")],
+ tool_client=environment.client,
+ )
+ )
- assert {"code_interpreter", "tool_search"} <= {
- tool.get("type") for tool in agent._openai_tools if isinstance(tool, dict)
- }
- assert agent._tool_search_threshold == 4
+ assert result.content == "done"
+ tools = client.responses.create.await_args.kwargs["tools"]
+ function_tools = [tool for tool in tools if tool["type"] == "function"]
+ assert len(function_tools) == 2
+ assert all(tool["defer_loading"] is True for tool in function_tools)
-def test_openai_hosted_tools_are_model_gated() -> None:
- agent = OpenAIAgent.create(
- model_client=object(),
- model="gpt-4.1",
+@pytest.mark.asyncio
+async def test_claude_hosted_web_fetch_payload_is_sent_to_provider() -> None:
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(stream=MagicMock(return_value=Stream("done")))
+ )
+ )
+ agent = ClaudeAgent.create(
+ model="claude-sonnet-4-6",
+ model_client=client,
+ validate_api_key=False,
hosted_tools=[
- OpenAICodeInterpreterTool(container={"type": "auto"}),
- OpenAIToolSearchTool(threshold=4),
+ ClaudeWebFetchTool(
+ max_uses=2,
+ allowed_domains=["example.com"],
+ max_content_tokens=500,
+ citations_enabled=True,
+ )
],
)
- agent._available_tools = []
- agent._categorized_tools = CategorizedTools()
- agent._convert_tools_for_openai()
+ result = await agent.run(
+ AgentContext(
+ messages=[text_prompt("fetch")],
+ tool_client=RecordingToolEnvironment().client,
+ )
+ )
- assert agent._openai_tools == []
- assert agent._tool_search_threshold is None
+ assert result.content == "done"
+ tools = client.beta.messages.stream.call_args.kwargs["tools"]
+ assert tools == [
+ {
+ "type": "web_fetch_20250910",
+ "name": "web_fetch",
+ "max_uses": 2,
+ "allowed_domains": ["example.com"],
+ "max_content_tokens": 500,
+ "citations": {"enabled": True},
+ }
+ ]
-def test_gemini_agent_configured_hosted_tools() -> None:
- agent = GeminiAgent.create(
- model_client=object(),
- hosted_tools=[
- GeminiGoogleSearchTool(),
- GeminiUrlContextTool(),
- GeminiCodeExecutionTool(),
- ],
+@pytest.mark.asyncio
+async def test_claude_tool_search_threshold_defers_generic_tools() -> None:
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(stream=MagicMock(return_value=Stream("done")))
+ )
+ )
+ agent = ClaudeAgent.create(
+ model="claude-sonnet-4-6",
+ model_client=client,
+ validate_api_key=False,
+ hosted_tools=[ClaudeToolSearchTool(threshold=1)],
)
- agent._available_tools = []
- agent._categorized_tools = CategorizedTools()
- agent._convert_tools_for_gemini()
+ result = await agent.run(
+ AgentContext(
+ messages=[text_prompt("use tools")],
+ tool_client=RecordingToolEnvironment([mcp_tool("first"), mcp_tool("second")]).client,
+ )
+ )
- assert any(getattr(tool, "google_search", None) is not None for tool in agent.gemini_tools)
- assert any(getattr(tool, "url_context", None) is not None for tool in agent.gemini_tools)
- assert any(getattr(tool, "code_execution", None) is not None for tool in agent.gemini_tools)
+ assert result.content == "done"
+ tools = client.beta.messages.stream.call_args.kwargs["tools"]
+ generic_tools = [tool for tool in tools if "input_schema" in tool]
+ assert len(generic_tools) == 2
+ assert all(tool["defer_loading"] is True for tool in generic_tools)
+ assert any(tool["type"] == "tool_search_tool_bm25_20251119" for tool in tools)
-def test_gemini_google_search_rejects_unsupported_dynamic_threshold() -> None:
- tool = GeminiGoogleSearchTool(dynamic_threshold=0.2)
-
- try:
- tool.to_params()
- except ValueError as exc:
- assert "dynamic_threshold" in str(exc)
- else:
- raise AssertionError("dynamic_threshold should be rejected")
+@pytest.mark.asyncio
+async def test_gemini_hosted_code_execution_payload_is_sent_to_provider() -> None:
+ client = _gemini_client(_gemini_response("done"))
+ agent = GeminiAgent.create(
+ model_client=client,
+ validate_api_key=False,
+ hosted_tools=[GeminiCodeExecutionTool()],
+ )
+
+ result = await agent.run(
+ AgentContext(
+ messages=[text_prompt("run code")],
+ tool_client=RecordingToolEnvironment().client,
+ )
+ )
+
+ assert result.content == "done"
+ config = client.aio.models.generate_content.await_args.kwargs["config"]
+ assert len(config.tools) == 1
+ assert config.tools[0].code_execution is not None
diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py
deleted file mode 100644
index cd8438628..000000000
--- a/hud/agents/tests/test_openai.py
+++ /dev/null
@@ -1,824 +0,0 @@
-"""Tests for OpenAI MCP Agent implementation."""
-
-from __future__ import annotations
-
-from types import SimpleNamespace
-from typing import TYPE_CHECKING, Any, cast
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from mcp import types
-from openai import AsyncOpenAI
-from openai.types.responses import (
- ResponseFunctionToolCall,
- ResponseOutputMessage,
- ResponseOutputText,
- ResponseReasoningItem,
-)
-from openai.types.responses.response_reasoning_item import Summary
-
-from hud.agents.openai import OpenAIAgent
-from hud.environment.router import ToolRouter
-from hud.eval.context import EvalContext
-from hud.types import MCPToolCall, MCPToolResult
-
-if TYPE_CHECKING:
- from collections.abc import Generator
-
-
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing."""
-
- def __init__(self, tools: list[types.Tool] | None = None) -> None:
- # Core attributes
- self.prompt = "Test prompt"
- self._tools = tools or []
- self._submitted: str | dict[str, Any] | None = None
- self.reward: float | None = None
-
- # Environment attributes
- self._router = ToolRouter()
-
- # EvalContext attributes
- self._task = None
- self.trace_id = "test-trace-id"
- self.eval_name = "test-eval"
- self.job_id: str | None = None
- self.group_id: str | None = None
- self.index = 0
- self.variants: dict[str, Any] = {}
- self.answer: str | dict[str, Any] | None = None
- self.system_prompt: str | None = None
- self.scenario_enable_citations: bool = False
- self.scenario_returns_schema: dict[str, Any] | None = None
- self.error: BaseException | None = None
- self.metadata: dict[str, Any] = {}
- self.results: list[Any] = []
- self.calls: list[Any] = []
- self._is_summary = False
-
- def as_tools(self) -> list[types.Tool]:
- return self._tools
-
- @property
- def has_scenario(self) -> bool:
- return False
-
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
-
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- self.calls.append(call)
- return MCPToolResult(
- content=[types.TextContent(type="text", text="ok")],
- isError=False,
- )
-
- async def submit(self, answer: str | dict[str, Any]) -> None:
- self._submitted = answer
-
-
-class TestOpenAIAgent:
- """Test OpenAIAgent class."""
-
- @pytest.fixture
- def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: # type: ignore[misc]
- """Create a stub OpenAI client."""
- with patch("hud.agents.openai.agent.AsyncOpenAI") as mock_class:
- client = AsyncOpenAI(api_key="test", base_url="http://localhost")
- client.chat.completions.create = AsyncMock()
- client.responses.create = AsyncMock()
- mock_class.return_value = client
- yield client # type: ignore[misc]
-
- @pytest.mark.asyncio
- async def test_init_with_client(self, mock_openai: AsyncOpenAI) -> None:
- """Test agent initialization with provided client."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- model="gpt-4o",
- validate_api_key=False,
- )
-
- assert agent.model_name == "OpenAI"
- assert agent.config.model == "gpt-4o"
- assert agent.model == "gpt-4o"
- assert agent.openai_client == mock_openai
- assert agent.max_output_tokens is None
- assert agent.temperature is None
-
- @pytest.mark.asyncio
- async def test_init_with_parameters(self, mock_openai: AsyncOpenAI) -> None:
- """Test agent initialization with various parameters."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- model="gpt-4o",
- max_output_tokens=2048,
- temperature=0.7,
- reasoning={"effort": "high"},
- tool_choice="auto",
- parallel_tool_calls=True,
- validate_api_key=False,
- )
-
- assert agent.max_output_tokens == 2048
- assert agent.temperature == 0.7
- assert agent.reasoning == {"effort": "high"}
- assert agent.tool_choice == "auto"
- assert agent.parallel_tool_calls is True
-
- @pytest.mark.asyncio
- async def test_init_without_client_no_api_key(self) -> None:
- """Test agent initialization fails without API key."""
- with patch("hud.agents.openai.agent.settings") as mock_settings:
- mock_settings.api_key = None
- mock_settings.openai_api_key = None
- with pytest.raises(ValueError, match="No API key found"):
- OpenAIAgent.create()
-
- @pytest.mark.asyncio
- async def test_format_blocks_text_only(self, mock_openai: AsyncOpenAI) -> None:
- """Test formatting text content blocks."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- blocks: list[types.ContentBlock] = [
- types.TextContent(type="text", text="Hello, world!"),
- types.TextContent(type="text", text="How are you?"),
- ]
-
- messages = await agent.format_blocks(blocks)
- assert len(messages) == 1
- assert messages[0]["role"] == "user"
- assert len(messages[0]["content"]) == 2
- assert messages[0]["content"][0]["type"] == "input_text"
- assert messages[0]["content"][0]["text"] == "Hello, world!"
-
- @pytest.mark.asyncio
- async def test_format_blocks_with_image(self, mock_openai: AsyncOpenAI) -> None:
- """Test formatting image content blocks."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- blocks: list[types.ContentBlock] = [
- types.TextContent(type="text", text="Look at this:"),
- types.ImageContent(type="image", data="base64data", mimeType="image/png"),
- ]
-
- messages = await agent.format_blocks(blocks)
- assert len(messages) == 1
- assert len(messages[0]["content"]) == 2
- assert messages[0]["content"][1]["type"] == "input_image"
- assert messages[0]["content"][1]["image_url"] == "data:image/png;base64,base64data" # type: ignore[typeddict-item]
-
- @pytest.mark.asyncio
- async def test_format_blocks_empty(self, mock_openai: AsyncOpenAI) -> None:
- """Test formatting empty content blocks."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- messages = await agent.format_blocks([])
- assert len(messages) == 1
- # Empty blocks produce a single empty text item
- assert len(messages[0]["content"]) == 1
- assert messages[0]["content"][0]["type"] == "input_text"
- assert messages[0]["content"][0]["text"] == ""
-
- @pytest.mark.asyncio
- async def test_format_tool_results_text(self, mock_openai: AsyncOpenAI) -> None:
- """Test formatting tool results with text content."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Tool output")],
- isError=False,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- assert len(messages) == 1
- assert messages[0]["type"] == "function_call_output"
- assert messages[0]["call_id"] == "call_123"
- # Output is a list of content items
- assert len(messages[0]["output"]) == 1
- assert messages[0]["output"][0]["text"] == "Tool output" # type: ignore[index]
-
- @pytest.mark.asyncio
- async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> None:
- """Test formatting tool results with error."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})]
- tool_results = [
- MCPToolResult(
- content=[types.TextContent(type="text", text="Error message")],
- isError=True,
- )
- ]
-
- messages = await agent.format_tool_results(tool_calls, tool_results)
- assert len(messages) == 1
- # Output is a list; first item is error indicator, second is the message
- msg = cast("dict[str, Any]", messages[0])
- output = cast("list[dict[str, Any]]", msg["output"])
- assert any(item.get("text") == "[tool_error] true" for item in output)
- assert any(item.get("text") == "Error message" for item in output)
-
- @pytest.mark.asyncio
- async def test_get_system_messages(self, mock_openai: AsyncOpenAI) -> None:
- """Test getting system messages - OpenAI uses instructions field instead."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- system_prompt="You are a helpful assistant.",
- validate_api_key=False,
- )
-
- # OpenAI agent returns empty list - system prompt is passed via instructions
- messages = await agent.get_system_messages()
- assert len(messages) == 0
-
- @pytest.mark.asyncio
- async def test_convert_tools_for_openai(self, mock_openai: AsyncOpenAI) -> None:
- """Test converting MCP tools to OpenAI format."""
- tools = [
- types.Tool(
- name="my_tool",
- description="A test tool",
- inputSchema={"type": "object", "properties": {"x": {"type": "string"}}},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- # Initialize with context to trigger tool conversion
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- # Check that tools were converted
- assert len(agent._openai_tools) >= 1
- # Find our tool
- tool = next((t for t in agent._openai_tools if t.get("name") == "my_tool"), None)
- assert tool is not None
- assert tool["type"] == "function"
-
- @pytest.mark.asyncio
- async def test_convert_tools_raises_on_incomplete(self, mock_openai: AsyncOpenAI) -> None:
- """Test that tools without description raise error."""
- tools = [
- types.Tool(
- name="incomplete_tool",
- description=None, # Missing description
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- with pytest.raises(ValueError, match="requires both a description"):
- await agent._initialize_from_ctx(ctx)
-
- @pytest.mark.asyncio
- async def test_get_response_with_text(self, mock_openai: AsyncOpenAI) -> None:
- """Test getting response with text output."""
- # Setup mock response
- mock_response = AsyncMock()
- mock_response.output = [
- ResponseOutputMessage(
- id="msg_123",
- type="message",
- role="assistant",
- status="completed",
- content=[ResponseOutputText(type="output_text", text="Hello!", annotations=[])],
- )
- ]
- mock_openai.responses.create = AsyncMock(return_value=mock_response)
-
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
- # Set empty tools to avoid needing initialization
- agent._openai_tools = []
- agent._initialized = True
-
- response = await agent.get_response([])
- assert response.content == "Hello!"
- assert response.done is True
- assert len(response.tool_calls) == 0
-
- @pytest.mark.asyncio
- async def test_get_response_with_tool_call(self, mock_openai: AsyncOpenAI) -> None:
- """Test getting response with tool call."""
- mock_response = AsyncMock()
- # Tool calls come as separate output items, not inside message content
- mock_response.output = [
- ResponseFunctionToolCall(
- id="call_123",
- type="function_call",
- call_id="call_123",
- name="my_tool",
- arguments='{"x": "value"}',
- )
- ]
- mock_openai.responses.create = AsyncMock(return_value=mock_response)
-
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
- agent._openai_tools = []
- agent._tool_name_map = {"my_tool": "my_tool"}
- agent._initialized = True
-
- response = await agent.get_response([])
- assert response.done is False
- assert len(response.tool_calls) == 1
- assert response.tool_calls[0].name == "my_tool"
- assert response.tool_calls[0].arguments == {"x": "value"}
-
- @pytest.mark.asyncio
- async def test_get_response_with_reasoning(self, mock_openai: AsyncOpenAI) -> None:
- """Test getting response with reasoning."""
- mock_response = AsyncMock()
- mock_response.output = [
- ResponseReasoningItem(
- id="reason_123",
- type="reasoning",
- summary=[Summary(type="summary_text", text="Thinking about it...")],
- ),
- ResponseOutputMessage(
- id="msg_123",
- type="message",
- role="assistant",
- status="completed",
- content=[ResponseOutputText(type="output_text", text="Answer!", annotations=[])],
- ),
- ]
- mock_openai.responses.create = AsyncMock(return_value=mock_response)
-
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
- agent._openai_tools = []
- agent._initialized = True
-
- response = await agent.get_response([])
- # Reasoning is stored separately from content
- assert response.reasoning == "Thinking about it..."
- assert response.content == "Answer!"
-
- @pytest.mark.asyncio
- async def test_get_response_requests_sources_when_citations_enabled(
- self, mock_openai: AsyncOpenAI
- ) -> None:
- """Scenario citation mode should request source payloads from Responses API."""
- mock_response = AsyncMock()
- mock_response.id = "resp_123"
- mock_response.output = [
- ResponseOutputMessage(
- id="msg_123",
- type="message",
- role="assistant",
- status="completed",
- content=[ResponseOutputText(type="output_text", text="Hello!", annotations=[])],
- )
- ]
- mock_openai.responses.create = AsyncMock(return_value=mock_response)
-
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
- agent._openai_tools = []
- agent._initialized = True
-
- ctx = MockEvalContext()
- ctx.scenario_enable_citations = True
- agent.ctx = ctx
-
- await agent.get_response([])
-
- call_kwargs = mock_openai.responses.create.await_args.kwargs # type: ignore[union-attr]
- assert call_kwargs.get("include") == ["web_search_call.action.sources"]
-
-
-class TestOpenAIToolConversion:
- """Tests for tool conversion to OpenAI format."""
-
- @pytest.fixture
- def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: # type: ignore[misc]
- """Create a stub OpenAI client."""
- with patch("hud.agents.openai.agent.AsyncOpenAI") as mock_class:
- client = AsyncOpenAI(api_key="test", base_url="http://localhost")
- client.responses.create = AsyncMock()
- mock_class.return_value = client
- yield client # type: ignore[misc]
-
- @pytest.mark.asyncio
- async def test_shell_tool_conversion(self, mock_openai: AsyncOpenAI) -> None:
- """Test that the agent converts shell capability to OpenAI native format."""
- tools = [
- types.Tool(
- name="bash",
- description="Execute shell commands",
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- # Check for native shell tool
- shell_tool = next((t for t in agent._openai_tools if t.get("type") == "shell"), None)
- assert shell_tool == {"type": "shell", "environment": {"type": "local"}}
- assert agent._tool_name_map["shell"] == "shell"
- assert agent._openai_native_tools["shell"].env_tool_name == "bash"
-
- @pytest.mark.asyncio
- async def test_editor_tool_stays_generic(self, mock_openai: AsyncOpenAI) -> None:
- """Editor capabilities are not advertised as OpenAI apply_patch."""
- tools = [
- types.Tool(
- name="edit",
- description="Apply V4A patches",
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert all(t.get("type") != "apply_patch" for t in agent._openai_tools)
- assert "apply_patch" not in agent._tool_name_map
- assert "apply_patch" not in agent._openai_native_tools
- assert [tool.get("type") for tool in agent._openai_tools] == ["function"]
- assert agent._openai_tools[0].get("name") == "edit"
-
- @pytest.mark.asyncio
- async def test_capability_metadata_routes_openai_tools(self, mock_openai: AsyncOpenAI) -> None:
- """Test env-level capabilities can bind OpenAI tools to non-public names."""
- tools = [
- types.Tool(
- name="run_shell",
- description="Execute shell commands",
- inputSchema={"type": "object"},
- ),
- types.Tool(
- name="patch_files",
- description="Apply V4A patches",
- inputSchema={"type": "object"},
- ),
- ]
- ctx = MockEvalContext(tools=tools)
- ctx.metadata["environment_capabilities"] = {
- "capabilities": {
- "shell": "run_shell",
- "editor": {"tool": "patch_files"},
- }
- }
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert {t.get("type") for t in agent._openai_tools} == {"shell", "function"}
- assert agent._tool_name_map["shell"] == "shell"
- assert agent._openai_native_tools["shell"].env_tool_name == "run_shell"
- assert "apply_patch" not in agent._tool_name_map
- assert "apply_patch" not in agent._openai_native_tools
- assert [tool.name for tool in agent._categorized_tools.generic] == [
- "run_shell",
- "patch_files",
- ]
-
- @pytest.mark.asyncio
- async def test_non_hosted_native_metadata_is_generic(self, mock_openai: AsyncOpenAI) -> None:
- """OpenAI ignores env-owned provider metadata."""
- tools = [
- types.Tool(
- name="custom_tool",
- description="Custom tool",
- inputSchema={"type": "object", "properties": {}},
- _meta={
- "native_tools": {
- "openai": {
- "api_type": "custom_native",
- "api_name": "custom_native",
- "role": "custom",
- }
- }
- },
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(model_client=mock_openai, validate_api_key=False)
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- assert [tool.name for tool in agent._categorized_tools.generic] == ["custom_tool"]
- assert {tool.get("type") for tool in agent._openai_tools} == {"function"}
-
- @pytest.mark.asyncio
- async def test_openai_shell_call_routes_directly_to_bash(
- self, mock_openai: AsyncOpenAI
- ) -> None:
- """Test OpenAI shell calls stay provider-owned until execution."""
- tools = [
- types.Tool(
- name="bash",
- description="Execute shell commands",
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(model_client=mock_openai, validate_api_key=False)
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- tool_call = agent._extract_tool_call(
- SimpleNamespace(
- type="shell_call",
- action=SimpleNamespace(
- to_dict=lambda: {"commands": ["pwd", "ls"], "timeout_ms": 5000}
- ),
- call_id="call_1",
- )
- )
-
- assert tool_call == MCPToolCall(
- name="shell",
- arguments={"commands": ["pwd", "ls"], "timeout_ms": 5000},
- id="call_1",
- )
-
- results = await agent.call_tools(tool_call)
- assert [(call.name, call.arguments) for call in ctx.calls] == [
- ("bash", {"command": "pwd", "timeout_seconds": 5.0}),
- ("bash", {"command": "ls", "timeout_seconds": 5.0}),
- ]
- assert results[0].structuredContent["provider_tool"] == "shell" # type: ignore[index]
-
- @pytest.mark.asyncio
- async def test_computer_tool_conversion(self, mock_openai: AsyncOpenAI) -> None:
- """Test that the agent converts computer capability to OpenAI native format."""
- tools = [
- types.Tool(
- name="computer",
- description="Control computer",
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- validate_api_key=False,
- )
-
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- computer_tool = next(
- (t for t in agent._openai_tools if t.get("type") == "computer"),
- None,
- )
- assert computer_tool is not None
- assert agent._tool_name_map["computer"] == "computer"
- assert agent._openai_native_tools["computer"].env_tool_name == "computer"
-
- @pytest.mark.asyncio
- async def test_openai_computer_call_routes_directly_to_generic_computer(
- self, mock_openai: AsyncOpenAI
- ) -> None:
- """Test OpenAI computer calls stay provider-owned until execution."""
- tools = [
- types.Tool(
- name="computer",
- description="Control computer",
- inputSchema={"type": "object"},
- )
- ]
- ctx = MockEvalContext(tools=tools)
-
- async def call_tool(call: Any, /, **kwargs: Any) -> MCPToolResult:
- del kwargs
- ctx.calls.append(call)
- if call.arguments["action"] == "screenshot":
- return MCPToolResult(
- content=[types.ImageContent(type="image", data="img", mimeType="image/png")],
- isError=False,
- )
- return MCPToolResult(
- content=[types.TextContent(type="text", text="clicked")],
- isError=False,
- )
-
- ctx.call_tool = call_tool # type: ignore[method-assign]
- agent = OpenAIAgent.create(model_client=mock_openai, validate_api_key=False)
- agent.ctx = ctx
- await agent._initialize_from_ctx(ctx)
-
- tool_call = agent._extract_tool_call(
- SimpleNamespace(
- type="computer_call",
- pending_safety_checks=[],
- action=SimpleNamespace(
- to_dict=lambda: {
- "type": "click",
- "x": 10,
- "y": 20,
- "button": "left",
- "keys": ["CTRL"],
- }
- ),
- call_id="call_1",
- )
- )
-
- assert tool_call is not None
- assert tool_call == MCPToolCall(
- name="computer",
- arguments={"type": "click", "x": 10, "y": 20, "button": "left", "keys": ["CTRL"]},
- id="call_1",
- )
-
- results = await agent.call_tools(tool_call)
- assert [(call.name, call.arguments) for call in ctx.calls] == [
- (
- "computer",
- {"action": "click", "x": 10, "y": 20, "button": "left", "hold_keys": ["ctrl"]},
- ),
- ("computer", {"action": "screenshot"}),
- ]
-
- messages = await agent.format_tool_results([tool_call], results)
- assert messages == [
- {
- "type": "computer_call_output",
- "call_id": "call_1",
- "output": {
- "type": "computer_screenshot",
- "image_url": "data:image/png;base64,img",
- "detail": "original",
- },
- }
- ]
-
-
-class TestOpenAICitations:
- """Tests for OpenAI annotation citation extraction."""
-
- @pytest.fixture
- def mock_openai(self) -> AsyncOpenAI:
- client = AsyncOpenAI(api_key="test", base_url="http://localhost")
- client.responses.create = AsyncMock()
- return client
-
- def _make_response(self, output: list[Any]) -> MagicMock:
- response = MagicMock()
- response.id = "resp_1"
- response.output = output
- return response
-
- @pytest.mark.asyncio
- async def test_url_citation_extracted(self, mock_openai: AsyncOpenAI) -> None:
- """url_citation annotations are extracted as citations."""
- from openai.types.responses.response_output_text import AnnotationURLCitation
-
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- model="gpt-4o",
- validate_api_key=False,
- )
- agent._openai_tools = []
- agent._initialized = True
-
- ann = AnnotationURLCitation(
- type="url_citation",
- url="https://example.com/article",
- title="Article",
- start_index=10,
- end_index=25,
- )
- text_block = ResponseOutputText(type="output_text", text="Hello world", annotations=[ann])
- msg_item = ResponseOutputMessage(
- id="msg_1",
- type="message",
- role="assistant",
- content=[text_block],
- status="completed",
- )
- mock_openai.responses.create = AsyncMock(return_value=self._make_response([msg_item]))
-
- result = await agent.get_response(
- [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}]
- )
-
- assert len(result.citations) == 1
- cit = result.citations[0]
- assert cit["type"] == "url_citation"
- assert cit["source"] == "https://example.com/article"
- assert cit["title"] == "Article"
- assert cit["start_index"] == 10
- assert cit["end_index"] == 25
-
- @pytest.mark.asyncio
- async def test_file_citation_extracted(self, mock_openai: AsyncOpenAI) -> None:
- """file_citation annotations are extracted as citations."""
- from openai.types.responses.response_output_text import AnnotationFileCitation
-
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- model="gpt-4o",
- validate_api_key=False,
- )
- agent._openai_tools = []
- agent._initialized = True
-
- ann = AnnotationFileCitation(
- type="file_citation",
- file_id="file-abc123",
- filename="report.pdf",
- index=0,
- )
- text_block = ResponseOutputText(type="output_text", text="Facts", annotations=[ann])
- msg_item = ResponseOutputMessage(
- id="msg_1",
- type="message",
- role="assistant",
- content=[text_block],
- status="completed",
- )
- mock_openai.responses.create = AsyncMock(return_value=self._make_response([msg_item]))
-
- result = await agent.get_response(
- [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}]
- )
-
- assert len(result.citations) == 1
- cit = result.citations[0]
- assert cit["type"] == "file_citation"
- assert cit["source"] == "file-abc123"
- assert cit["title"] == "report.pdf"
-
- @pytest.mark.asyncio
- async def test_no_annotations_no_citations(self, mock_openai: AsyncOpenAI) -> None:
- """No citations when annotations list is empty."""
- agent = OpenAIAgent.create(
- model_client=mock_openai,
- model="gpt-4o",
- validate_api_key=False,
- )
- agent._openai_tools = []
- agent._initialized = True
-
- text_block = ResponseOutputText(type="output_text", text="Plain answer", annotations=[])
- msg_item = ResponseOutputMessage(
- id="msg_1",
- type="message",
- role="assistant",
- content=[text_block],
- status="completed",
- )
- mock_openai.responses.create = AsyncMock(return_value=self._make_response([msg_item]))
-
- result = await agent.get_response(
- [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}]
- )
-
- assert result.citations == []
diff --git a/hud/agents/tests/test_openai_compatible.py b/hud/agents/tests/test_openai_compatible.py
deleted file mode 100644
index 77aaa2d04..000000000
--- a/hud/agents/tests/test_openai_compatible.py
+++ /dev/null
@@ -1,300 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, cast
-
-import mcp.types as types
-import pytest
-
-from hud.agents.openai_compatible import OpenAIChatAgent
-from hud.agents.openai_compatible.tools import openai_compatible_tools
-from hud.agents.openai_compatible.tools.computer import (
- GLMComputerTool,
- QwenComputerTool,
- _fix_glm_xml_args,
- _parse_glm_box,
-)
-from hud.agents.openai_compatible.tools.filesystem import ReadTool
-from hud.agents.tools import EnvironmentCapability
-from hud.types import MCPToolCall, MCPToolResult
-
-if TYPE_CHECKING:
- from openai.types.chat import ChatCompletionToolParam
-
-
-def computer_tool(name: str = "computer") -> types.Tool:
- return types.Tool(
- name=name,
- description="Control computer with mouse, keyboard, and screenshots",
- inputSchema={
- "type": "object",
- "properties": {
- "action": {"type": "string"},
- "x": {"type": "integer"},
- "y": {"type": "integer"},
- },
- "required": ["action"],
- },
- _meta={"resolution": {"width": 1024, "height": 768}},
- )
-
-
-def capability(tool: types.Tool) -> EnvironmentCapability:
- return EnvironmentCapability(name="computer", tool_name=tool.name, tool=tool)
-
-
-def filesystem_tool(name: str) -> types.Tool:
- return types.Tool(
- name=name,
- description=f"{name} environment tool",
- inputSchema={"type": "object", "properties": {}},
- )
-
-
-def filesystem_capability(tool_name: str = "read") -> EnvironmentCapability:
- tool = filesystem_tool(tool_name)
- return EnvironmentCapability(
- name="filesystem",
- tool_name=tool.name,
- tool=tool,
- metadata={"tools": {"read": "read", "grep": "grep", "glob": "glob", "list": "list"}},
- )
-
-
-def test_openai_compatible_agent_uses_glm_computer_tool() -> None:
- agent = OpenAIChatAgent.create(
- model="glm-4.6v",
- api_key="test-key",
- base_url="http://example.com/v1",
- )
- tool = computer_tool()
- agent._available_tools = [tool]
- agent._categorized_tools = agent.categorize_tools([tool])
- agent._initialized = True
- agent._on_tools_ready()
-
- schemas = agent.get_tool_schemas()
- schema = cast("dict[str, Any]", schemas[0])
-
- assert schema["type"] == "function"
- assert schema["function"]["name"] == "computer"
- assert len(schemas) == 1
- assert "computer" in agent._openai_compatible_native_tools
- actions = schema["function"]["parameters"]["properties"]["action"]["enum"]
- assert "DONE" not in actions
- assert "FAIL" not in actions
-
-
-def test_openai_compatible_agent_uses_qwen_computer_tool() -> None:
- agent = OpenAIChatAgent.create(
- model="qwen2.5-vl",
- api_key="test-key",
- base_url="http://example.com/v1",
- )
- tool = computer_tool()
- agent._available_tools = [tool]
- agent._categorized_tools = agent.categorize_tools([tool])
- agent._initialized = True
- agent._on_tools_ready()
-
- schemas = agent.get_tool_schemas()
- schema = cast("dict[str, Any]", schemas[0])
-
- assert schema["type"] == "computer_use"
- assert schema["name"] == "computer_use"
- assert len(schemas) == 1
- assert "computer_use" in agent._openai_compatible_native_tools
- actions = schema["parameters"]["properties"]["action"]["enum"]
- assert "terminate" not in actions
- assert "answer" not in actions
-
-
-def test_openai_compatible_registry_ignores_legacy_native_metadata() -> None:
- tool = types.Tool(
- name="glm_computer",
- description="legacy GLM computer",
- inputSchema={"type": "object", "properties": {}},
- _meta={
- "native_tools": {
- "openai_compatible": {
- "api_type": "gui_agent_glm45v",
- "api_name": "computer",
- "role": "computer",
- }
- }
- },
- )
- agent = OpenAIChatAgent.create(
- model="glm-4.6v",
- api_key="test-key",
- base_url="http://example.com/v1",
- )
-
- categorized = agent.categorize_tools([tool])
-
- assert categorized.generic == [tool]
- assert categorized.skipped == []
-
-
-def test_openai_compatible_agent_uses_filesystem_tool_shapes() -> None:
- agent = OpenAIChatAgent.create(
- model="gpt-4o",
- api_key="test-key",
- base_url="http://example.com/v1",
- )
- tools = [filesystem_tool(name) for name in ("read", "grep", "glob", "list")]
- agent._available_tools = tools
- agent._categorized_tools = agent.categorize_tools(tools)
- agent._initialized = True
- agent._on_tools_ready()
-
- schemas = agent.get_tool_schemas()
- function_schemas = [cast("ChatCompletionToolParam", schema) for schema in schemas]
-
- assert [schema["function"]["name"] for schema in function_schemas] == [
- "read",
- "grep",
- "glob",
- "list",
- ]
- assert len(schemas) == 4
- assert set(agent._openai_compatible_backing_tools) == {"read", "grep", "glob", "list"}
- filesystem = agent._environment_capabilities["filesystem"]
- assert filesystem.metadata["tools"] == {
- "read": "read",
- "grep": "grep",
- "glob": "glob",
- "list": "list",
- }
-
-
-def test_openai_compatible_registry_maps_filesystem_capability_to_read_tool() -> None:
- tool = openai_compatible_tools.tool_for_capability(
- filesystem_capability(),
- "gpt-4o",
- )
-
- assert isinstance(tool, ReadTool)
- assert tool.to_params()["function"]["name"] == "read"
-
-
-def test_parse_glm_box() -> None:
- assert _parse_glm_box("[513,438]") == (513, 438)
- assert _parse_glm_box("513, 438") == (513, 438)
- assert _parse_glm_box([513, 438]) == (513, 438)
- assert _parse_glm_box([[513, 438]]) == (513, 438)
- assert _parse_glm_box("bad") is None
-
-
-def test_fix_glm_xml_args() -> None:
- result = _fix_glm_xml_args(
- {"action": "left_click\nstart_box\n[114, 167]"}
- )
-
- assert result == {"action": "left_click", "start_box": "[114, 167]"}
-
-
-@pytest.mark.asyncio
-async def test_glm_computer_translates_to_environment_calls() -> None:
- tool = GLMComputerTool.from_capability(
- capability(computer_tool()),
- GLMComputerTool.default_spec("glm-4.6v"), # type: ignore[arg-type]
- "glm-4.6v",
- )
- calls: list[MCPToolCall] = []
-
- async def caller(call: MCPToolCall) -> MCPToolResult:
- calls.append(call)
- return MCPToolResult(content=[], isError=False)
-
- await tool.execute(caller, {"action": "left_click", "start_box": "[500,300]"})
-
- assert calls[0].name == "computer"
- assert calls[0].arguments == {
- "action": "click",
- "x": 512,
- "y": 230,
- "button": "left",
- }
- assert calls[1].arguments == {"action": "screenshot"}
-
-
-@pytest.mark.asyncio
-async def test_qwen_computer_translates_to_environment_calls() -> None:
- tool = QwenComputerTool.from_capability(
- capability(computer_tool()),
- QwenComputerTool.default_spec("qwen2.5-vl"), # type: ignore[arg-type]
- "qwen2.5-vl",
- )
- calls: list[MCPToolCall] = []
-
- async def caller(call: MCPToolCall) -> MCPToolResult:
- calls.append(call)
- return MCPToolResult(content=[], isError=False)
-
- await tool.execute(caller, {"action": "scroll", "coordinate": [100, 200], "pixels": 50})
-
- assert calls[0].name == "computer"
- assert calls[0].arguments == {
- "action": "scroll",
- "x": 100,
- "y": 200,
- "scroll_y": -50,
- }
- assert calls[1].arguments == {"action": "screenshot"}
-
-
-@pytest.mark.asyncio
-async def test_qwen_left_click_drag_uses_mouse_drag_sequence() -> None:
- tool = QwenComputerTool.from_capability(
- capability(computer_tool()),
- QwenComputerTool.default_spec("qwen2.5-vl"), # type: ignore[arg-type]
- "qwen2.5-vl",
- )
- calls: list[MCPToolCall] = []
-
- async def caller(call: MCPToolCall) -> MCPToolResult:
- calls.append(call)
- return MCPToolResult(content=[], isError=False)
-
- await tool.execute(caller, {"action": "left_click_drag", "coordinate": [300, 400]})
-
- assert [call.name for call in calls] == ["computer", "computer", "computer", "computer"]
- assert [call.arguments for call in calls] == [
- {"action": "mouse_down", "button": "left"},
- {"action": "move", "x": 300, "y": 400},
- {"action": "mouse_up", "button": "left"},
- {"action": "screenshot"},
- ]
-
-
-@pytest.mark.asyncio
-async def test_openai_compatible_filesystem_tool_forwards_to_environment_tool() -> None:
- tool = ReadTool.from_capability(
- filesystem_capability(),
- ReadTool.default_spec("gpt-4o"),
- "gpt-4o",
- )
- calls: list[MCPToolCall] = []
-
- async def caller(call: MCPToolCall) -> MCPToolResult:
- calls.append(call)
- return MCPToolResult(content=[], isError=False)
-
- await tool.execute(caller, {"filePath": "/workspace/app.py", "offset": 10, "limit": 5})
-
- assert len(calls) == 1
- assert calls[0].name == "read"
- assert calls[0].arguments == {"filePath": "/workspace/app.py", "offset": 10, "limit": 5}
-
-
-def test_openai_compatible_tool_registry_selects_model_specific_tool() -> None:
- tool = computer_tool()
- cap = capability(tool)
-
- glm_tool = openai_compatible_tools.tool_for_capability(cap, "glm-4.6v")
- qwen_tool = openai_compatible_tools.tool_for_capability(cap, "qwen2.5-vl")
- unsupported = openai_compatible_tools.tool_for_capability(cap, "llama")
-
- assert isinstance(glm_tool, GLMComputerTool)
- assert isinstance(qwen_tool, QwenComputerTool)
- assert unsupported is None
diff --git a/hud/agents/tests/test_provider_claude_messages.py b/hud/agents/tests/test_provider_claude_messages.py
new file mode 100644
index 000000000..be7fb162b
--- /dev/null
+++ b/hud/agents/tests/test_provider_claude_messages.py
@@ -0,0 +1,257 @@
+"""Claude agent tests."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import MagicMock
+
+import pytest
+
+from hud.agents.base import AgentContext
+from hud.agents.claude import ClaudeAgent
+from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result
+
+
+class Stream:
+ def __init__(self, response: MagicMock) -> None:
+ self.response = response
+
+ async def __aenter__(self) -> Stream:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: Any,
+ ) -> bool:
+ return False
+
+ def __aiter__(self) -> Stream:
+ return self
+
+ async def __anext__(self) -> None:
+ raise StopAsyncIteration
+
+ async def get_final_message(self) -> MagicMock:
+ return self.response
+
+
+class ErrorStream:
+ def __init__(self, error: Exception) -> None:
+ self.error = error
+
+ async def __aenter__(self) -> ErrorStream:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: Any,
+ ) -> bool:
+ return False
+
+ def __aiter__(self) -> ErrorStream:
+ return self
+
+ async def __anext__(self) -> None:
+ raise self.error
+
+
+def _tool_use(name: str, arguments: dict[str, object]) -> MagicMock:
+ block = MagicMock()
+ block.type = "tool_use"
+ block.id = "call_1"
+ block.name = name
+ block.input = arguments
+ return block
+
+
+def _text_block(text: str, *, thinking: bool = False) -> MagicMock:
+ block = MagicMock()
+ block.type = "thinking" if thinking else "text"
+ block.text = text
+ block.thinking = text
+ block.citations = None
+ return block
+
+
+def _message(*blocks: MagicMock) -> MagicMock:
+ response = MagicMock()
+ response.content = list(blocks)
+ return response
+
+
+@pytest.mark.asyncio
+async def test_claude_run_executes_model_tool_call_and_returns_final_answer() -> None:
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(
+ stream=MagicMock(
+ side_effect=[
+ Stream(_message(_tool_use("lookup", {"query": "hud"}))),
+ Stream(_message(_text_block("final answer"))),
+ ]
+ )
+ )
+ )
+ )
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("tool result")},
+ )
+ agent = ClaudeAgent.create(model_client=client, validate_api_key=False)
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client)
+ )
+
+ assert result.content == "final answer"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("lookup", {"query": "hud"})
+ ]
+ assert client.beta.messages.stream.call_count == 2
+ second_messages = client.beta.messages.stream.call_args_list[1].kwargs["messages"]
+ assert second_messages[-1]["role"] == "user"
+ assert second_messages[-1]["content"][0]["type"] == "tool_result"
+
+
+@pytest.mark.asyncio
+async def test_claude_retries_streamed_invalid_tool_json_once() -> None:
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(
+ stream=MagicMock(
+ side_effect=[
+ ErrorStream(
+ ValueError("Unable to parse tool parameter JSON from model. JSON: {bad")
+ ),
+ Stream(_message(_text_block("ok"))),
+ ]
+ )
+ )
+ )
+ )
+ agent = ClaudeAgent.create(model_client=client, validate_api_key=False)
+
+ response = await agent.get_response(
+ [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
+ )
+
+ assert response.content == "ok"
+ assert response.done is True
+ assert client.beta.messages.stream.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_claude_second_invalid_json_retry_adds_guidance_message() -> None:
+ invalid_json_error = ValueError("Unable to parse tool parameter JSON from model. JSON: {bad")
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(
+ stream=MagicMock(
+ side_effect=[
+ ErrorStream(invalid_json_error),
+ ErrorStream(invalid_json_error),
+ Stream(_message(_text_block("ok"))),
+ ]
+ )
+ )
+ )
+ )
+ agent = ClaudeAgent.create(model_client=client, validate_api_key=False)
+ messages = [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
+
+ response = await agent.get_response(cast("Any", messages))
+
+ assert response.content == "ok"
+ assert client.beta.messages.stream.call_count == 3
+ retry_messages = client.beta.messages.stream.call_args_list[2].kwargs["messages"]
+ retry_text = retry_messages[-1]["content"][0]["text"]
+ assert "INVALID_JSON" in retry_text
+ assert "Retry the same intended tool call" in retry_text
+
+
+@pytest.mark.asyncio
+async def test_claude_response_preserves_thinking_as_reasoning() -> None:
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(
+ stream=MagicMock(
+ return_value=Stream(
+ _message(_text_block("answer"), _text_block("plan", thinking=True))
+ )
+ )
+ )
+ )
+ )
+ agent = ClaudeAgent.create(model_client=client, validate_api_key=False)
+
+ response = await agent.get_response(
+ [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
+ )
+
+ assert response.content == "answer"
+ assert response.reasoning == "plan"
+
+
+@pytest.mark.asyncio
+async def test_claude_extracts_document_citations_from_text_blocks() -> None:
+ citation = MagicMock()
+ citation.type = "char_location"
+ citation.cited_text = "Revenue"
+ citation.document_index = 0
+ citation.document_title = "financials.pdf"
+ citation.start_char_index = 0
+ citation.end_char_index = 7
+ text_block = _text_block("Revenue")
+ text_block.citations = [citation]
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(stream=MagicMock(return_value=Stream(_message(text_block))))
+ )
+ )
+ agent = ClaudeAgent.create(model_client=client, validate_api_key=False)
+
+ response = await agent.get_response(
+ [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
+ )
+
+ assert response.citations == [
+ {
+ "type": "document_citation",
+ "text": "Revenue",
+ "source": "0",
+ "title": "financials.pdf",
+ "start_index": 0,
+ "end_index": 7,
+ }
+ ]
+
+
+@pytest.mark.asyncio
+async def test_claude_native_computer_requests_required_beta_header() -> None:
+ client = SimpleNamespace(
+ beta=SimpleNamespace(
+ messages=SimpleNamespace(
+ stream=MagicMock(return_value=Stream(_message(_text_block("answer"))))
+ )
+ )
+ )
+ agent = ClaudeAgent.create(
+ model="claude-sonnet-4-6",
+ model_client=client,
+ validate_api_key=False,
+ )
+ agent.tools.prepare(model=agent.config.model, tools=[mcp_tool("computer")])
+
+ response = await agent.get_response(
+ [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
+ )
+
+ assert response.content == "answer"
+ kwargs = client.beta.messages.stream.call_args.kwargs
+ assert "computer-use-2025-11-24" in kwargs["betas"]
+ assert kwargs["tool_choice"] == {"type": "auto", "disable_parallel_tool_use": True}
diff --git a/hud/agents/tests/test_provider_computer_tools.py b/hud/agents/tests/test_provider_computer_tools.py
new file mode 100644
index 000000000..5504382e6
--- /dev/null
+++ b/hud/agents/tests/test_provider_computer_tools.py
@@ -0,0 +1,226 @@
+"""Computer tool contracts shared across provider adapters."""
+
+from __future__ import annotations
+
+from typing import Any, cast
+
+import pytest
+from mcp import types
+
+from hud.agents.gemini.tools.computer import (
+ GEMINI_COMPUTER_SPEC,
+ GEMINI_SAFETY_BLOCKED_PREFIX,
+ GEMINI_URL_PREFIX,
+ GeminiComputerTool,
+)
+from hud.agents.openai.tools.computer import OpenAIComputerTool
+from hud.agents.openai_compatible.tools.glm_computer import GLM_COMPUTER_SPEC, GLMComputerTool
+from hud.agents.openai_compatible.tools.qwen_computer import (
+ QWEN_COMPUTER_SPEC,
+ QwenComputerTool,
+)
+from hud.agents.tests.conftest import RecordingToolEnvironment, text_result
+from hud.agents.tools.computer import execute_computer_calls
+from hud.types import MCPToolCall, MCPToolResult
+
+
+def _image_result(data: str = "screenshot") -> MCPToolResult:
+ return MCPToolResult(
+ content=[types.ImageContent(type="image", data=data, mimeType="image/png")],
+ isError=False,
+ )
+
+
+@pytest.mark.asyncio
+async def test_shared_computer_execution_appends_screenshot_when_required() -> None:
+ calls: list[MCPToolCall] = []
+
+ async def call_tool(call: MCPToolCall) -> MCPToolResult:
+ calls.append(call)
+ if (call.arguments or {}).get("action") == "screenshot":
+ return _image_result("after")
+ return text_result("clicked")
+
+ result = await execute_computer_calls(
+ call_tool,
+ env_tool_name="computer",
+ calls=[{"action": "click", "x": 1, "y": 2}],
+ ensure_screenshot=True,
+ )
+
+ assert [(call.name, call.arguments) for call in calls] == [
+ ("computer", {"action": "click", "x": 1, "y": 2}),
+ ("computer", {"action": "screenshot"}),
+ ]
+ assert [type(block).__name__ for block in result.content] == ["TextContent", "ImageContent"]
+
+
+@pytest.mark.asyncio
+async def test_openai_computer_translates_actions_and_requires_final_screenshot() -> None:
+ spec = OpenAIComputerTool.default_spec("gpt-5.4")
+ assert spec is not None
+ tool = OpenAIComputerTool(env_tool_name="computer", spec=spec)
+ calls: list[MCPToolCall] = []
+
+ async def call_tool(call: MCPToolCall) -> MCPToolResult:
+ calls.append(call)
+ if (call.arguments or {}).get("action") == "screenshot":
+ return _image_result("after")
+ return text_result("acted")
+
+ result = await tool.execute(
+ call_tool,
+ {"type": "click", "x": 10, "y": 20, "button": "wheel", "keys": ["ctrl"]},
+ )
+
+ assert result.content == [
+ types.TextContent(type="text", text="acted"),
+ types.ImageContent(type="image", data="after", mimeType="image/png"),
+ ]
+ assert [(call.name, call.arguments) for call in calls] == [
+ (
+ "computer",
+ {
+ "action": "click",
+ "x": 10,
+ "y": 20,
+ "button": "middle",
+ "hold_keys": ["ctrl"],
+ },
+ ),
+ ("computer", {"action": "screenshot"}),
+ ]
+
+
+def test_openai_computer_formats_screenshot_for_provider_continuation() -> None:
+ spec = OpenAIComputerTool.default_spec("gpt-5.4")
+ assert spec is not None
+ tool = OpenAIComputerTool(env_tool_name="computer", spec=spec)
+
+ formatted = tool.format_result(
+ MCPToolCall(name="computer", id="call_1", arguments={}),
+ _image_result("after"),
+ )
+
+ output = cast("dict[str, Any]", formatted)
+ assert output["type"] == "computer_call_output"
+ assert output["call_id"] == "call_1"
+ assert output["output"] == {
+ "type": "computer_screenshot",
+ "image_url": "data:image/png;base64,after",
+ "detail": "original",
+ }
+
+
+def test_openai_computer_rejects_provider_continuation_without_screenshot() -> None:
+ spec = OpenAIComputerTool.default_spec("gpt-5.4")
+ assert spec is not None
+ tool = OpenAIComputerTool(env_tool_name="computer", spec=spec)
+
+ with pytest.raises(ValueError, match="missing screenshot"):
+ tool.format_result(
+ MCPToolCall(name="computer", id="call_1", arguments={}),
+ text_result("no screenshot"),
+ )
+
+
+@pytest.mark.asyncio
+async def test_gemini_computer_blocks_unconfirmed_safety_decision_without_environment_call() -> (
+ None
+):
+ tool = GeminiComputerTool(env_tool_name="computer", spec=GEMINI_COMPUTER_SPEC)
+ environment = RecordingToolEnvironment()
+
+ result = await tool.execute(
+ environment.call_tool,
+ {
+ "action": "click_at",
+ "safety_decision": {"decision": "require_confirmation"},
+ },
+ )
+
+ assert environment.calls == []
+ assert result.isError is False
+ assert result.content == [
+ types.TextContent(
+ type="text",
+ text=(
+ f"{GEMINI_SAFETY_BLOCKED_PREFIX}"
+ "Gemini Computer Use action requires user confirmation before execution."
+ ),
+ )
+ ]
+
+
+def test_gemini_computer_formats_url_safety_and_inline_screenshot_parts() -> None:
+ tool = GeminiComputerTool(env_tool_name="computer", spec=GEMINI_COMPUTER_SPEC)
+
+ content = tool.format_result(
+ MCPToolCall(
+ name="computer_use",
+ provider_name="click_at",
+ arguments={"safety_decision": {"decision": "allow"}},
+ ),
+ MCPToolResult(
+ content=[
+ types.TextContent(type="text", text="clicked"),
+ types.TextContent(type="text", text=f"{GEMINI_URL_PREFIX}https://example.com"),
+ types.ImageContent(type="image", data="YWJj", mimeType="image/png"),
+ ],
+ isError=False,
+ ),
+ )
+
+ parts = content.parts or []
+ response = parts[0].function_response
+ assert response is not None
+ assert response.name == "click_at"
+ assert response.response == {
+ "success": True,
+ "output": "clicked",
+ "url": "https://example.com",
+ "safety_acknowledgement": True,
+ }
+ response_parts = response.parts or []
+ assert response_parts[0].inline_data is not None
+ assert response_parts[0].inline_data.data == b"abc"
+
+
+@pytest.mark.asyncio
+async def test_glm_computer_scales_normalized_click_coordinates() -> None:
+ tool = GLMComputerTool(
+ env_tool_name="computer",
+ spec=GLM_COMPUTER_SPEC,
+ display_width=1000,
+ display_height=500,
+ coordinate_space=None,
+ )
+ environment = RecordingToolEnvironment(results={"computer": text_result("ok")})
+
+ await tool.execute(
+ environment.call_tool,
+ {"action": "left_click", "start_box": "[999,999]"},
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("computer", {"action": "click", "x": 999, "y": 499, "button": "left"}),
+ ("computer", {"action": "screenshot"}),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_qwen_computer_translates_wait_seconds_to_milliseconds() -> None:
+ tool = QwenComputerTool(
+ env_tool_name="computer",
+ spec=QWEN_COMPUTER_SPEC,
+ display_width=1000,
+ display_height=500,
+ description="computer",
+ )
+ environment = RecordingToolEnvironment(results={"computer": text_result("waited")})
+
+ await tool.execute(environment.call_tool, {"action": "wait", "time": 1.5})
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("computer", {"action": "wait", "time": 1500})
+ ]
diff --git a/hud/agents/tests/test_provider_gemini_generate_content.py b/hud/agents/tests/test_provider_gemini_generate_content.py
new file mode 100644
index 000000000..524072625
--- /dev/null
+++ b/hud/agents/tests/test_provider_gemini_generate_content.py
@@ -0,0 +1,154 @@
+"""Gemini agent tests."""
+
+from __future__ import annotations
+
+from typing import cast
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from google.genai import types as genai_types
+
+from hud.agents.base import AgentContext
+from hud.agents.gemini import GeminiAgent
+from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result
+
+
+def _gemini_response(*parts: genai_types.Part) -> genai_types.GenerateContentResponse:
+ return genai_types.GenerateContentResponse(
+ candidates=[
+ genai_types.Candidate(
+ content=genai_types.Content(
+ role="model",
+ parts=list(parts),
+ )
+ )
+ ]
+ )
+
+
+def _gemini_client(*responses: genai_types.GenerateContentResponse) -> MagicMock:
+ client = MagicMock()
+ client.aio = MagicMock()
+ client.aio.models = MagicMock()
+ client.aio.models.generate_content = AsyncMock(side_effect=list(responses))
+ return client
+
+
+@pytest.mark.asyncio
+async def test_gemini_run_executes_model_tool_call_and_returns_final_answer() -> None:
+ client = _gemini_client(
+ _gemini_response(
+ genai_types.Part(
+ function_call=genai_types.FunctionCall(
+ name="lookup",
+ args={"query": "hud"},
+ )
+ )
+ ),
+ _gemini_response(genai_types.Part(text="final answer")),
+ )
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("tool result")},
+ )
+ agent = GeminiAgent.create(model_client=client, validate_api_key=False)
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client)
+ )
+
+ assert result.content == "final answer"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("lookup", {"query": "hud"})
+ ]
+ assert client.aio.models.generate_content.await_count == 2
+ second_contents = cast(
+ "list[genai_types.Content]",
+ client.aio.models.generate_content.await_args_list[1].kwargs["contents"],
+ )
+ function_response_names: list[str] = []
+ for content in second_contents:
+ for part in content.parts or []:
+ function_response = part.function_response
+ if function_response is not None:
+ function_response_names.append(function_response.name or "")
+ assert "lookup" in function_response_names
+
+
+@pytest.mark.asyncio
+async def test_gemini_no_candidates_is_a_user_visible_error() -> None:
+ client = _gemini_client(genai_types.GenerateContentResponse(candidates=[]))
+ agent = GeminiAgent.create(model_client=client, validate_api_key=False)
+
+ with pytest.raises(RuntimeError, match="returned no candidates"):
+ await agent.get_response([])
+
+
+@pytest.mark.asyncio
+async def test_gemini_citations_enable_google_search_at_provider_boundary() -> None:
+ client = _gemini_client(_gemini_response(genai_types.Part(text="answer")))
+ agent = GeminiAgent.create(model_client=client, validate_api_key=False)
+ agent.enable_citations = True
+
+ response = await agent.get_response([])
+
+ assert response.content == "answer"
+ config = client.aio.models.generate_content.await_args.kwargs["config"]
+ assert any(tool.google_search is not None for tool in config.tools)
+
+
+@pytest.mark.asyncio
+async def test_gemini_preserves_thought_parts_as_reasoning() -> None:
+ client = _gemini_client(
+ _gemini_response(
+ genai_types.Part(text="private reasoning", thought=True),
+ genai_types.Part(text="answer"),
+ )
+ )
+ agent = GeminiAgent.create(model_client=client, validate_api_key=False)
+
+ response = await agent.get_response([])
+
+ assert response.content == "answer"
+ assert response.reasoning == "private reasoning"
+
+
+@pytest.mark.asyncio
+async def test_gemini_prunes_older_computer_screenshots_before_request() -> None:
+ def computer_response(name: str) -> genai_types.FunctionResponse:
+ return genai_types.FunctionResponse(
+ name=name,
+ response={"success": True},
+ parts=[
+ genai_types.FunctionResponsePart(
+ inline_data=genai_types.FunctionResponseBlob(
+ mime_type="image/png",
+ data=b"image-bytes",
+ )
+ )
+ ],
+ )
+
+ old_response = computer_response("click_at")
+ recent_response = computer_response("navigate")
+ messages = [
+ genai_types.Content(
+ role="user",
+ parts=[genai_types.Part(function_response=old_response)],
+ ),
+ genai_types.Content(
+ role="user",
+ parts=[genai_types.Part(function_response=recent_response)],
+ ),
+ ]
+ client = _gemini_client(_gemini_response(genai_types.Part(text="answer")))
+ agent = GeminiAgent.create(model_client=client, validate_api_key=False)
+ agent.max_recent_turn_with_screenshots = 1
+
+ response = await agent.get_response(messages)
+
+ assert response.content == "answer"
+ assert old_response.parts is None
+ assert recent_response.parts is not None
+ requested_contents = client.aio.models.generate_content.await_args.kwargs["contents"]
+ assert requested_contents is messages
diff --git a/hud/agents/tests/test_provider_native_tools.py b/hud/agents/tests/test_provider_native_tools.py
new file mode 100644
index 000000000..866b66851
--- /dev/null
+++ b/hud/agents/tests/test_provider_native_tools.py
@@ -0,0 +1,147 @@
+"""Native provider tool contracts for translation and model gating."""
+
+from __future__ import annotations
+
+import hashlib
+from typing import Any, cast
+
+import pytest
+
+from hud.agents.claude.tools.coding import ClaudeBashTool, ClaudeTextEditorTool
+from hud.agents.gemini.tools.coding import GeminiShellTool
+from hud.agents.gemini.tools.filesystem import GeminiReadTool
+from hud.agents.gemini.tools.memory import GeminiMemoryTool
+from hud.agents.openai.tools.coding import OpenAIShellTool
+from hud.agents.tests.conftest import RecordingToolEnvironment, text_result
+from hud.types import MCPToolCall
+
+
+@pytest.mark.asyncio
+async def test_openai_shell_translates_commands_timeout_and_structured_output() -> None:
+ spec = OpenAIShellTool.default_spec("gpt-5.4")
+ assert spec is not None
+ tool = OpenAIShellTool(env_tool_name="bash", spec=spec)
+ environment = RecordingToolEnvironment(
+ results={
+ "bash": text_result("pwd output"),
+ },
+ )
+
+ result = await tool.execute(
+ environment.call_tool,
+ {"commands": ["pwd"], "timeout_ms": 2500, "max_output_length": 80},
+ )
+ formatted = tool.format_result(MCPToolCall(name="shell", id="call_1", arguments={}), result)
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("bash", {"command": "pwd", "timeout_seconds": 2.5})
+ ]
+ assert result.structuredContent == {
+ "provider_tool": "shell",
+ "output": [
+ {"stdout": "pwd output", "stderr": "", "outcome": {"type": "exit", "exit_code": 0}}
+ ],
+ "max_output_length": 80,
+ }
+ formatted_dict = cast("dict[str, Any]", formatted)
+ assert formatted_dict["type"] == "shell_call_output"
+ assert formatted_dict["call_id"] == "call_1"
+ assert formatted_dict["max_output_length"] == 80
+
+
+@pytest.mark.asyncio
+async def test_openai_shell_rejects_invalid_commands_without_environment_call() -> None:
+ spec = OpenAIShellTool.default_spec("gpt-5.4")
+ assert spec is not None
+ tool = OpenAIShellTool(env_tool_name="bash", spec=spec)
+ environment = RecordingToolEnvironment()
+
+ result = await tool.execute(environment.call_tool, {"commands": 123})
+
+ assert result.isError is True
+ assert environment.calls == []
+
+
+@pytest.mark.asyncio
+async def test_claude_text_editor_translates_str_replace_arguments() -> None:
+ spec = ClaudeTextEditorTool.default_spec("claude-sonnet-4-6")
+ assert spec is not None
+ tool = ClaudeTextEditorTool(env_tool_name="edit", spec=spec)
+ environment = RecordingToolEnvironment(results={"edit": text_result("edited")})
+
+ result = await tool.execute(
+ environment.call_tool,
+ {
+ "command": "str_replace",
+ "path": "/tmp/file.txt",
+ "old_str": "old",
+ "new_str": "new",
+ },
+ )
+
+ assert result.isError is False
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ (
+ "edit",
+ {
+ "command": "replace",
+ "path": "/tmp/file.txt",
+ "old_text": "old",
+ "new_text": "new",
+ },
+ )
+ ]
+
+
+@pytest.mark.asyncio
+async def test_gemini_shell_scopes_command_to_directory() -> None:
+ tool = GeminiShellTool(env_tool_name="bash", spec=GeminiShellTool.default_spec("gemini"))
+ environment = RecordingToolEnvironment(results={"bash": text_result("ok")})
+
+ await tool.execute(environment.call_tool, {"command": "ls -la", "dir_path": "/tmp/my dir"})
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("bash", {"command": "cd '/tmp/my dir' && ls -la"})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_gemini_read_translates_line_range_to_offset_and_limit() -> None:
+ tool = GeminiReadTool(env_tool_name="read", spec=GeminiReadTool.default_spec("gemini"))
+ environment = RecordingToolEnvironment(results={"read": text_result("lines")})
+
+ await tool.execute(
+ environment.call_tool,
+ {"file_path": "/repo/file.py", "start_line": 3, "end_line": 7},
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("read", {"filePath": "/repo/file.py", "offset": 2, "limit": 5})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_gemini_memory_persists_trimmed_fact_under_stable_path() -> None:
+ tool = GeminiMemoryTool(env_tool_name="edit", spec=GeminiMemoryTool.default_spec("gemini"))
+ environment = RecordingToolEnvironment(results={"edit": text_result("saved")})
+
+ await tool.execute(environment.call_tool, {"fact": " user likes concise tests "})
+
+ digest = hashlib.sha256(b"user likes concise tests").hexdigest()[:12]
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ (
+ "edit",
+ {
+ "command": "create",
+ "path": f"/memories/gemini-{digest}.md",
+ "file_text": "user likes concise tests\n",
+ },
+ )
+ ]
+
+
+def test_native_tool_model_gating_uses_provider_supported_model_contracts() -> None:
+ assert OpenAIShellTool.default_spec("gpt-5.4") is not None
+ assert OpenAIShellTool.default_spec("gpt-4.1") is None
+ assert ClaudeBashTool.default_spec("claude-sonnet-4-6") is not None
+ assert ClaudeBashTool.default_spec("claude-3-5-sonnet") is None
diff --git a/hud/agents/tests/test_provider_openai_compatible_chat.py b/hud/agents/tests/test_provider_openai_compatible_chat.py
new file mode 100644
index 000000000..373c7b4db
--- /dev/null
+++ b/hud/agents/tests/test_provider_openai_compatible_chat.py
@@ -0,0 +1,215 @@
+"""OpenAI-compatible chat agent tests."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import AsyncMock
+
+import pytest
+from openai.types.chat.chat_completion import ChatCompletion
+
+from hud.agents.base import AgentContext
+from hud.agents.openai_compatible import OpenAIChatAgent
+from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result
+
+
+def _chat_completion(message: dict[str, Any], *, finish_reason: str = "stop") -> ChatCompletion:
+ return ChatCompletion.model_validate(
+ {
+ "id": "chatcmpl-test",
+ "object": "chat.completion",
+ "created": 0,
+ "model": "test-model",
+ "choices": [
+ {
+ "index": 0,
+ "finish_reason": finish_reason,
+ "message": message,
+ }
+ ],
+ }
+ )
+
+
+def _client(*responses: ChatCompletion) -> SimpleNamespace:
+ return SimpleNamespace(
+ chat=SimpleNamespace(
+ completions=SimpleNamespace(create=AsyncMock(side_effect=list(responses)))
+ )
+ )
+
+
+def _chat_completion_with_token_ids(
+ message: dict[str, Any],
+ *,
+ prompt_token_ids: list[int],
+ token_ids: list[int],
+) -> ChatCompletion:
+ completion = _chat_completion(message)
+ choice = completion.choices[0]
+ object.__setattr__(choice, "prompt_token_ids", prompt_token_ids)
+ object.__setattr__(choice, "token_ids", token_ids)
+ return completion
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_run_executes_model_tool_call_and_returns_final_answer() -> None:
+ client = _client(
+ _chat_completion(
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {
+ "name": "lookup",
+ "arguments": '{"query":"hud"}',
+ },
+ }
+ ],
+ },
+ finish_reason="tool_calls",
+ ),
+ _chat_completion({"role": "assistant", "content": "final answer"}),
+ )
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("tool result")},
+ )
+ agent = OpenAIChatAgent.create(model="test-model", openai_client=client)
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client)
+ )
+
+ assert result.content == "final answer"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("lookup", {"query": "hud"})
+ ]
+ assert client.chat.completions.create.await_count == 2
+ second_messages = client.chat.completions.create.await_args_list[1].kwargs["messages"]
+ assert {
+ "role": "tool",
+ "tool_call_id": "call_1",
+ "content": "tool result",
+ } in second_messages
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_auto_respond_followup_does_not_repeat_system_prompt(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ async def continue_once(content: str | None, *, enabled: bool) -> object:
+ assert enabled is True
+ if content == "need input":
+ return text_prompt("continue")
+ return None
+
+ monkeypatch.setattr("hud.agents.base.auto_respond", continue_once)
+ client = _client(
+ _chat_completion({"role": "assistant", "content": "need input"}),
+ _chat_completion({"role": "assistant", "content": "final answer"}),
+ )
+ agent = OpenAIChatAgent.create(
+ model="test-model",
+ openai_client=client,
+ system_prompt="system rules",
+ auto_respond=True,
+ )
+
+ result = await agent.run(AgentContext(messages=[text_prompt("start")]))
+
+ assert result.content == "final answer"
+ second_messages = client.chat.completions.create.await_args_list[1].kwargs["messages"]
+ system_messages = [message for message in second_messages if message["role"] == "system"]
+ assert system_messages == [{"role": "system", "content": "system rules"}]
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_preserves_reasoning_fields_on_assistant_message() -> None:
+ reasoning_details = [{"type": "reasoning.text", "text": "step"}]
+ client = _client(
+ _chat_completion(
+ {
+ "role": "assistant",
+ "content": "answer",
+ "reasoning": "private reasoning",
+ "reasoning_details": reasoning_details,
+ }
+ )
+ )
+ agent = OpenAIChatAgent.create(model="reasoning-model", openai_client=client)
+ messages: list[dict[str, Any]] = [{"role": "user", "content": "question"}]
+
+ result = await agent.get_response(cast("Any", messages))
+
+ assert result.content == "answer"
+ assert result.reasoning == "private reasoning"
+ assert messages[-1]["reasoning"] == "private reasoning"
+ assert messages[-1]["reasoning_details"] == reasoning_details
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_api_error_returns_error_response() -> None:
+ client = SimpleNamespace(
+ chat=SimpleNamespace(
+ completions=SimpleNamespace(create=AsyncMock(side_effect=RuntimeError("boom")))
+ )
+ )
+ agent = OpenAIChatAgent.create(model="test-model", openai_client=client)
+
+ response = await agent.get_response(cast("Any", [{"role": "user", "content": "question"}]))
+
+ assert response.done is True
+ assert response.isError is True
+ assert response.content == "Error getting response boom"
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_checkpoint_is_sent_in_provider_body() -> None:
+ client = _client(_chat_completion({"role": "assistant", "content": "answer"}))
+ agent = OpenAIChatAgent.create(
+ model="test-model",
+ openai_client=client,
+ checkpoint="checkpoint-123",
+ )
+
+ response = await agent.get_response(cast("Any", [{"role": "user", "content": "question"}]))
+
+ assert response.content == "answer"
+ assert client.chat.completions.create.await_args.kwargs["extra_body"] == {
+ "checkpoint": "checkpoint-123"
+ }
+
+
+@pytest.mark.asyncio
+async def test_openai_compatible_token_continuation_is_sent_after_first_response() -> None:
+ client = _client(
+ _chat_completion_with_token_ids(
+ {"role": "assistant", "content": "first"},
+ prompt_token_ids=[1, 2],
+ token_ids=[3],
+ ),
+ _chat_completion({"role": "assistant", "content": "second"}),
+ )
+ agent = OpenAIChatAgent.create(
+ model="test-model",
+ openai_client=client,
+ completion_kwargs={"extra_body": {"return_token_ids": True}},
+ )
+ messages = cast("Any", [{"role": "user", "content": "question"}])
+
+ first = await agent.get_response(messages)
+ second = await agent.get_response(messages)
+
+ assert first.content == "first"
+ assert second.content == "second"
+ second_body = client.chat.completions.create.await_args_list[1].kwargs["extra_body"]
+ assert second_body == {
+ "return_token_ids": True,
+ "prompt_token_ids": [1, 2, 3],
+ "continuation_from": 2,
+ }
diff --git a/hud/agents/tests/test_provider_openai_responses.py b/hud/agents/tests/test_provider_openai_responses.py
new file mode 100644
index 000000000..5cd82108f
--- /dev/null
+++ b/hud/agents/tests/test_provider_openai_responses.py
@@ -0,0 +1,206 @@
+"""OpenAI Responses agent tests."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import AsyncMock
+
+import pytest
+from openai.types.responses import (
+ ResponseFunctionToolCall,
+ ResponseOutputMessage,
+ ResponseOutputText,
+ ResponseReasoningItem,
+)
+from openai.types.responses.response_reasoning_item import Summary
+
+from hud.agents.base import AgentContext
+from hud.agents.openai import OpenAIAgent
+from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result
+
+
+def _message_response(text: str, *, response_id: str = "resp_final") -> SimpleNamespace:
+ return SimpleNamespace(
+ id=response_id,
+ output=[
+ ResponseOutputMessage(
+ id=f"msg_{response_id}",
+ type="message",
+ role="assistant",
+ status="completed",
+ content=[ResponseOutputText(type="output_text", text=text, annotations=[])],
+ )
+ ],
+ )
+
+
+@pytest.mark.asyncio
+async def test_openai_run_executes_model_tool_call_and_returns_final_answer() -> None:
+ client = SimpleNamespace(
+ responses=SimpleNamespace(
+ create=AsyncMock(
+ side_effect=[
+ SimpleNamespace(
+ id="resp_tool",
+ output=[
+ ResponseFunctionToolCall(
+ id="item_1",
+ type="function_call",
+ call_id="call_1",
+ name="lookup",
+ arguments='{"query":"hud"}',
+ )
+ ],
+ ),
+ _message_response("final answer"),
+ ]
+ )
+ )
+ )
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("tool result")},
+ )
+ agent = OpenAIAgent.create(model_client=client, validate_api_key=False)
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client)
+ )
+
+ assert result.content == "final answer"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("lookup", {"query": "hud"})
+ ]
+ assert client.responses.create.await_count == 2
+ second_input = client.responses.create.await_args_list[1].kwargs["input"]
+ assert client.responses.create.await_args_list[1].kwargs["previous_response_id"] == "resp_tool"
+ assert second_input[-1]["type"] == "function_call_output"
+ assert second_input[-1]["call_id"] == "call_1"
+
+
+@pytest.mark.asyncio
+async def test_openai_get_response_preserves_reasoning_and_citations() -> None:
+ text = ResponseOutputText.model_validate(
+ {
+ "type": "output_text",
+ "text": "Example",
+ "annotations": [
+ {
+ "type": "url_citation",
+ "url": "https://example.com",
+ "title": "Example",
+ "start_index": 0,
+ "end_index": 7,
+ }
+ ],
+ }
+ )
+ client = SimpleNamespace(
+ responses=SimpleNamespace(
+ create=AsyncMock(
+ return_value=SimpleNamespace(
+ id="resp",
+ output=[
+ ResponseReasoningItem(
+ id="reason",
+ type="reasoning",
+ summary=[Summary(type="summary_text", text="thought")],
+ ),
+ ResponseOutputMessage(
+ id="msg",
+ type="message",
+ role="assistant",
+ status="completed",
+ content=[text],
+ ),
+ ],
+ )
+ )
+ )
+ )
+ agent = OpenAIAgent.create(model_client=client, validate_api_key=False)
+
+ response = await agent.get_response([])
+
+ assert response.content == "Example"
+ assert response.reasoning == "thought"
+ assert response.citations == [
+ {
+ "type": "url_citation",
+ "text": "Example",
+ "source": "https://example.com",
+ "title": "Example",
+ "start_index": 0,
+ "end_index": 7,
+ }
+ ]
+
+
+@pytest.mark.asyncio
+async def test_openai_citation_mode_requests_provider_source_metadata() -> None:
+ client = SimpleNamespace(
+ responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("answer")))
+ )
+ agent = OpenAIAgent.create(model_client=client, validate_api_key=False)
+ agent.enable_citations = True
+
+ response = await agent.get_response([])
+
+ assert response.content == "answer"
+ assert client.responses.create.await_args.kwargs["include"] == [
+ "web_search_call.action.sources"
+ ]
+
+
+@pytest.mark.asyncio
+async def test_openai_get_response_parses_native_computer_and_shell_calls() -> None:
+ def _action(payload: dict[str, Any]) -> SimpleNamespace:
+ return SimpleNamespace(to_dict=lambda: payload)
+
+ client = SimpleNamespace(
+ responses=SimpleNamespace(
+ create=AsyncMock(
+ return_value=SimpleNamespace(
+ id="resp",
+ output=[
+ SimpleNamespace(
+ type="computer_call",
+ call_id="computer_call_1",
+ actions=[_action({"type": "click", "x": 1, "y": 2})],
+ action=None,
+ pending_safety_checks=[],
+ ),
+ SimpleNamespace(
+ type="shell_call",
+ call_id="shell_call_1",
+ action=_action({"commands": ["pwd"]}),
+ ),
+ ],
+ )
+ )
+ )
+ )
+ agent = OpenAIAgent.create(model_client=client, validate_api_key=False)
+
+ response = await agent.get_response([])
+
+ assert response.done is False
+ assert [(call.name, call.arguments, call.id) for call in response.tool_calls] == [
+ ("computer", {"actions": [{"type": "click", "x": 1, "y": 2}]}, "computer_call_1"),
+ ("shell", {"commands": ["pwd"]}, "shell_call_1"),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_openai_run_returns_error_trace_for_provider_failure() -> None:
+ client = SimpleNamespace(
+ responses=SimpleNamespace(create=AsyncMock(side_effect=RuntimeError("provider down")))
+ )
+ agent = OpenAIAgent.create(model_client=client, validate_api_key=False)
+
+ result = await agent.run(AgentContext(messages=[text_prompt("hello")]))
+
+ assert result.isError is True
+ assert result.content == "provider down"
+ assert result.info["error"] == "provider down"
diff --git a/hud/agents/tests/test_provider_tool_results.py b/hud/agents/tests/test_provider_tool_results.py
new file mode 100644
index 000000000..8ae5f1974
--- /dev/null
+++ b/hud/agents/tests/test_provider_tool_results.py
@@ -0,0 +1,174 @@
+"""Provider continuation contracts for environment tool results."""
+
+from __future__ import annotations
+
+from typing import Any, cast
+
+from mcp import types
+
+from hud.agents.claude.tools.base import ClaudeFunctionTool
+from hud.agents.gemini.tools.base import GeminiFunctionTool
+from hud.agents.openai.tools.base import OpenAIFunctionTool
+from hud.agents.openai_compatible.tools.base import OpenAICompatibleFunctionTool
+from hud.agents.tests.conftest import mcp_tool
+from hud.types import MCPToolCall, MCPToolResult
+
+
+def _text_image_result() -> MCPToolResult:
+ return MCPToolResult(
+ content=[
+ types.TextContent(type="text", text="text output"),
+ types.ImageContent(type="image", data="image-bytes", mimeType="image/png"),
+ ],
+ isError=False,
+ )
+
+
+def test_openai_formats_text_image_structured_and_error_results() -> None:
+ tool = OpenAIFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things"))
+ assert tool is not None
+
+ output = tool.format_result(
+ MCPToolCall(name="lookup", id="call_1", arguments={}),
+ MCPToolResult(
+ content=[
+ types.TextContent(type="text", text="failed"),
+ types.ImageContent(type="image", data="image-bytes", mimeType="image/png"),
+ ],
+ isError=True,
+ structuredContent={"code": 500},
+ ),
+ )
+
+ assert output is not None
+ output_dict = cast("dict[str, Any]", output)
+ assert output_dict["type"] == "function_call_output"
+ assert output_dict["call_id"] == "call_1"
+ blocks = cast("list[dict[str, Any]]", output_dict["output"])
+ assert {"type": "input_text", "text": "[tool_error] true"} in blocks
+ assert {"type": "input_text", "text": '{"code": 500}'} in blocks
+ assert {"type": "input_text", "text": "failed"} in blocks
+ assert {
+ "type": "input_image",
+ "image_url": "data:image/png;base64,image-bytes",
+ } in blocks
+
+
+def test_openai_formats_empty_result_as_empty_function_output() -> None:
+ tool = OpenAIFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things"))
+ assert tool is not None
+
+ output = tool.format_result(
+ MCPToolCall(name="lookup", id="call_1", arguments={}),
+ MCPToolResult(content=[], isError=False),
+ )
+
+ assert output is not None
+ blocks = cast("list[dict[str, Any]]", cast("dict[str, Any]", output)["output"])
+ assert blocks == [{"type": "input_text", "text": ""}]
+
+
+def test_claude_formats_result_blocks_and_citation_documents() -> None:
+ tool = ClaudeFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things"))
+
+ message = tool.format_result(
+ MCPToolCall(
+ name="lookup",
+ id="call_1",
+ arguments={},
+ _meta=types.RequestParams.Meta.model_validate({"enable_citations": True}),
+ ),
+ _text_image_result(),
+ )
+
+ assert message is not None
+ assert message["role"] == "user"
+ content = cast("list[dict[str, Any]]", message["content"])
+ tool_result = content[0]
+ assert tool_result["type"] == "tool_result"
+ assert tool_result["tool_use_id"] == "call_1"
+ assert cast("list[dict[str, Any]]", tool_result["content"]) == [
+ {"type": "text", "text": "text output"},
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/png",
+ "data": "image-bytes",
+ },
+ },
+ ]
+ assert content[1]["type"] == "document"
+ assert content[1]["citations"] == {"enabled": True}
+
+
+def test_claude_formats_errors_as_tool_result_text() -> None:
+ tool = ClaudeFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things"))
+
+ message = tool.format_result(
+ MCPToolCall(name="lookup", id="call_1", arguments={}),
+ MCPToolResult(
+ content=[types.TextContent(type="text", text="boom")],
+ isError=True,
+ ),
+ )
+
+ assert message is not None
+ tool_result = cast("list[dict[str, Any]]", message["content"])[0]
+ assert tool_result["content"] == [{"type": "text", "text": "Error: boom"}]
+
+
+def test_gemini_formats_success_and_error_function_responses() -> None:
+ tool = GeminiFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things"))
+
+ success = tool.format_result(
+ MCPToolCall(name="lookup", provider_name="provider_lookup", arguments={}),
+ MCPToolResult(
+ content=[types.TextContent(type="text", text="found")],
+ isError=False,
+ ),
+ )
+ error = tool.format_result(
+ MCPToolCall(name="lookup", arguments={}),
+ MCPToolResult(
+ content=[types.TextContent(type="text", text="failed")],
+ isError=True,
+ ),
+ )
+
+ success_parts = success.parts or []
+ error_parts = error.parts or []
+ success_response = success_parts[0].function_response
+ error_response = error_parts[0].function_response
+ assert success_response is not None
+ assert success_response.name == "provider_lookup"
+ assert success_response.response == {"success": True, "output": "found"}
+ assert error_response is not None
+ assert error_response.response == {"error": "failed"}
+
+
+def test_openai_compatible_formats_text_image_and_structured_results() -> None:
+ tool = OpenAICompatibleFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things"))
+
+ image_output = tool.format_result(
+ MCPToolCall(name="lookup", id="call_1", arguments={}),
+ _text_image_result(),
+ )
+ structured_output = tool.format_result(
+ MCPToolCall(name="lookup", id="call_2", arguments={}),
+ MCPToolResult(
+ content=[], isError=False, structuredContent={"result": {"type": "text", "text": "ok"}}
+ ),
+ )
+
+ assert image_output == [
+ {"role": "tool", "tool_call_id": "call_1", "content": "text output"},
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Tool returned the following:"},
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,image-bytes"}},
+ ],
+ },
+ ]
+ assert structured_output == {"role": "tool", "tool_call_id": "call_2", "content": "ok"}
diff --git a/hud/agents/tests/test_resolver.py b/hud/agents/tests/test_resolver.py
deleted file mode 100644
index 05f06b6b7..000000000
--- a/hud/agents/tests/test_resolver.py
+++ /dev/null
@@ -1,276 +0,0 @@
-"""Tests for model resolution and create_agent."""
-
-from __future__ import annotations
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from hud.agents import create_agent
-from hud.agents.resolver import resolve_cls
-
-
-@pytest.fixture(autouse=True)
-def clear_cache() -> None:
- """Clear the models cache before each test."""
- import hud.agents.resolver as resolver_module
-
- resolver_module._models_cache = None
-
-
-# Mock API response data matching the platform backend format
-MOCK_MODELS = [
- {
- "id": "uuid-1",
- "name": "Claude Sonnet 4.6",
- "model_name": "claude-sonnet-4-6",
- "sdk_agent_type": None,
- "provider": {"name": "Anthropic", "default_sdk_agent_type": "claude"},
- },
- {
- "id": "uuid-2",
- "name": "GPT 5.4",
- "model_name": "gpt-5.4",
- "sdk_agent_type": None,
- "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"},
- },
- {
- "id": "uuid-3",
- "name": "Operator",
- "model_name": "computer-use-preview",
- "sdk_agent_type": "operator",
- "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"},
- },
- {
- "id": "uuid-4",
- "name": "Gemini 3 Pro",
- "model_name": "gemini-3-pro-preview",
- "sdk_agent_type": None,
- "provider": {"name": "Gemini", "default_sdk_agent_type": "gemini"},
- },
- {
- "id": "uuid-5",
- "name": "Gemini 2.5 Computer Use Preview",
- "model_name": "gemini-2.5-computer-use-preview",
- "sdk_agent_type": "gemini_cua",
- "provider": {"name": "Gemini", "default_sdk_agent_type": "gemini"},
- },
- {
- "id": "uuid-6",
- "name": "Grok 4.1 Fast",
- "model_name": "grok-4-1-fast",
- "sdk_agent_type": None,
- "provider": {"name": "xAI", "default_sdk_agent_type": "openai_compatible"},
- },
-]
-
-
-class TestResolveCls:
- """Tests for resolve_cls function."""
-
- def test_resolves_known_agent_type(self) -> None:
- """Known AgentType strings resolve to their class."""
- from hud.agents.claude import ClaudeAgent
-
- cls, gateway_info = resolve_cls("claude")
- assert cls == ClaudeAgent
- assert gateway_info is None
-
- def test_resolves_openai(self) -> None:
- """Resolves 'openai' to OpenAIAgent."""
- from hud.agents import OpenAIAgent
-
- cls, _gateway_info = resolve_cls("openai")
- assert cls == OpenAIAgent
-
- def test_resolves_gemini(self) -> None:
- """Resolves 'gemini' to GeminiAgent."""
- from hud.agents.gemini import GeminiAgent
-
- cls, _gateway_info = resolve_cls("gemini")
- assert cls == GeminiAgent
-
- def test_unknown_model_raises(self) -> None:
- """Unknown model raises ValueError."""
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- pytest.raises(ValueError, match="not found"),
- ):
- resolve_cls("unknown-model-xyz-123")
-
- def test_resolves_claude_model(self) -> None:
- """Resolves Claude model to ClaudeAgent via sdk_agent_type."""
- from hud.agents.claude import ClaudeAgent
-
- with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS):
- cls, info = resolve_cls("claude-sonnet-4-6")
- assert cls == ClaudeAgent
- assert info is not None
- assert info["model_name"] == "claude-sonnet-4-6"
-
- def test_resolves_openai_model(self) -> None:
- """Resolves OpenAI model to OpenAIAgent via sdk_agent_type."""
- from hud.agents import OpenAIAgent
-
- with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS):
- cls, info = resolve_cls("gpt-5.4")
- assert cls == OpenAIAgent
- assert info is not None
-
- def test_operator_model_is_not_supported(self) -> None:
- """Stale gateway Operator models fail with a clear message."""
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- pytest.raises(ValueError, match="Operator agent is no longer supported"),
- ):
- resolve_cls("computer-use-preview")
-
- def test_resolves_gemini_model(self) -> None:
- """Resolves Gemini model to GeminiAgent via provider default."""
- from hud.agents.gemini import GeminiAgent
-
- with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS):
- cls, info = resolve_cls("gemini-3-pro-preview")
- assert cls == GeminiAgent
- assert info is not None
-
- def test_gemini_cua_model_is_not_supported(self) -> None:
- """Stale gateway Gemini CUA models fail with a clear message."""
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- pytest.raises(ValueError, match="Gemini CUA agent is no longer supported"),
- ):
- resolve_cls("gemini-2.5-computer-use-preview")
-
- def test_resolves_openai_compatible_model(self) -> None:
- """Resolves OpenAI-compatible model to OpenAIChatAgent via provider default."""
- from hud.agents.openai_compatible import OpenAIChatAgent
-
- with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS):
- cls, info = resolve_cls("grok-4-1-fast")
- assert cls == OpenAIChatAgent
- assert info is not None
-
- def test_unsupported_sdk_agent_type_is_rejected(self) -> None:
- """Unsupported sdk_agent_type values are not silently remapped."""
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- pytest.raises(ValueError, match="Operator agent is no longer supported"),
- ):
- resolve_cls("computer-use-preview")
-
-
-class TestCreateAgent:
- """Tests for create_agent function - gateway-only."""
-
- def test_creates_with_gateway_client(self) -> None:
- """create_agent always uses gateway routing."""
- from hud.agents import OpenAIAgent
-
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- patch.object(OpenAIAgent, "create") as mock_create,
- patch("hud.agents.gateway.build_gateway_client") as mock_build_client,
- ):
- mock_client = MagicMock()
- mock_build_client.return_value = mock_client
- mock_agent = MagicMock()
- mock_create.return_value = mock_agent
-
- agent = create_agent("gpt-5.4")
-
- call_kwargs = mock_create.call_args.kwargs
- assert call_kwargs["model"] == "gpt-5.4"
- assert "model_client" in call_kwargs
- assert agent == mock_agent
-
- def test_passes_kwargs_to_create(self) -> None:
- """Extra kwargs are passed to agent.create()."""
- from hud.agents import OpenAIAgent
-
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- patch.object(OpenAIAgent, "create") as mock_create,
- patch("hud.agents.gateway.build_gateway_client"),
- ):
- mock_create.return_value = MagicMock()
-
- create_agent("gpt-5.4", temperature=0.5, max_tokens=1000)
-
- call_kwargs = mock_create.call_args.kwargs
- assert call_kwargs["temperature"] == 0.5
- assert call_kwargs["max_tokens"] == 1000
-
- def test_known_agent_type_also_uses_gateway(self) -> None:
- """Even 'claude' string uses gateway (it's a gateway shortcut)."""
- from hud.agents.claude import ClaudeAgent
-
- with (
- patch.object(ClaudeAgent, "create") as mock_create,
- patch("hud.agents.gateway.build_gateway_client") as mock_build_client,
- ):
- mock_client = MagicMock()
- mock_build_client.return_value = mock_client
- mock_create.return_value = MagicMock()
-
- create_agent("claude")
-
- mock_build_client.assert_called_once()
- call_kwargs = mock_create.call_args.kwargs
- assert "model_client" in call_kwargs
-
- def test_uses_correct_provider_from_gateway_info(self) -> None:
- """Provider name is extracted from gateway info."""
- from hud.agents.claude import ClaudeAgent
-
- with (
- patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS),
- patch.object(ClaudeAgent, "create") as mock_create,
- patch("hud.agents.gateway.build_gateway_client") as mock_build_client,
- ):
- mock_build_client.return_value = MagicMock()
- mock_create.return_value = MagicMock()
-
- create_agent("claude-sonnet-4-6")
-
- mock_build_client.assert_called_once_with("Anthropic")
-
-
-class TestBuildGatewayClient:
- """Tests for build_gateway_client function."""
-
- def test_builds_anthropic_client(self) -> None:
- """Builds AsyncAnthropic for anthropic provider."""
- from hud.agents.gateway import build_gateway_client
-
- with patch("hud.settings.settings") as mock_settings:
- mock_settings.api_key = "test-key"
- mock_settings.hud_gateway_url = "https://gateway.hud.ai"
-
- with patch("anthropic.AsyncAnthropic") as mock_client_cls:
- build_gateway_client("anthropic")
- mock_client_cls.assert_called_once()
-
- def test_builds_openai_client_for_openai(self) -> None:
- """Builds AsyncOpenAI for openai provider."""
- from hud.agents.gateway import build_gateway_client
-
- with patch("hud.settings.settings") as mock_settings:
- mock_settings.api_key = "test-key"
- mock_settings.hud_gateway_url = "https://gateway.hud.ai"
-
- with patch("openai.AsyncOpenAI") as mock_client_cls:
- build_gateway_client("openai")
- mock_client_cls.assert_called_once()
-
- def test_builds_openai_client_for_unknown(self) -> None:
- """Builds AsyncOpenAI for unknown providers (openai-compatible)."""
- from hud.agents.gateway import build_gateway_client
-
- with patch("hud.settings.settings") as mock_settings:
- mock_settings.api_key = "test-key"
- mock_settings.hud_gateway_url = "https://gateway.hud.ai"
-
- with patch("openai.AsyncOpenAI") as mock_client_cls:
- build_gateway_client("together")
- mock_client_cls.assert_called_once()
diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py
deleted file mode 100644
index c818e3a7b..000000000
--- a/hud/agents/tests/test_run_eval.py
+++ /dev/null
@@ -1,269 +0,0 @@
-"""Tests for MCPAgent.run() with EvalContext."""
-
-from __future__ import annotations
-
-from typing import Any, ClassVar
-
-import pytest
-from mcp import types
-
-from hud.agents import MCPAgent
-from hud.agents.base import BaseCreateParams
-from hud.environment.router import ToolRouter
-from hud.eval.context import EvalContext
-from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult
-
-
-class MockConfig(BaseAgentConfig):
- model_name: str = "MockAgent"
- model: str = "mock-model"
-
-
-class MockCreateParams(BaseCreateParams, MockConfig):
- pass
-
-
-class MockMCPAgent(MCPAgent):
- """Mock agent for testing run()."""
-
- metadata: ClassVar[dict[str, Any] | None] = {}
- config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig
-
- @classmethod
- def agent_type(cls) -> AgentType:
- """Return the AgentType for the mock agent."""
- return AgentType.OPENAI
-
- def __init__(self, **kwargs: Any) -> None:
- params = MockCreateParams(**kwargs)
- super().__init__(params)
- self._response = InferenceResult(content="Test response", tool_calls=[], done=True)
-
- def set_response(self, response: InferenceResult) -> None:
- self._response = response
-
- async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
- return self._response
-
- async def format_tool_results(
- self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
- ) -> list[dict[str, Any]]:
- return [{"role": "tool", "content": str(r)} for r in tool_results]
-
- async def get_system_messages(self) -> list[Any]:
- return []
-
- async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]:
- return [{"type": "text", "text": getattr(b, "text")} for b in blocks if hasattr(b, "text")]
-
-
-class MockEvalContext(EvalContext):
- """Mock EvalContext for testing - inherits from real EvalContext."""
-
- def __init__(self, prompt: str = "Test prompt", tools: list[types.Tool] | None = None) -> None:
- # Core attributes
- self.prompt = prompt
- self._tools = tools or [types.Tool(name="test_tool", description="Test", inputSchema={})]
- self._submitted: str | dict[str, Any] | None = None
- self.reward: float | None = None
- self._initialized = True
-
- # Environment attributes
- self._router = ToolRouter()
-
- # EvalContext attributes
- self._task = None
- self.trace_id = "test-trace-id"
- self.eval_name = "test-eval"
- self.job_id: str | None = None
- self.group_id: str | None = None
- self.index = 0
- self.variants: dict[str, Any] = {}
- self.answer: str | dict[str, Any] | None = None
- self.system_prompt: str | None = None
- self.error: BaseException | None = None
- self.metadata: dict[str, Any] = {}
- self.results: list[Any] = []
- self._is_summary = False
-
- def as_tools(self) -> list[types.Tool]:
- return self._tools
-
- @property
- def has_scenario(self) -> bool:
- return True
-
- async def list_tools(self) -> list[types.Tool]:
- return self._tools
-
- async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult:
- # Handle tuple format (name, args)
- if isinstance(call, tuple):
- name = call[0]
- elif hasattr(call, "name"):
- name = call.name
- else:
- name = str(call)
- return MCPToolResult(
- content=[types.TextContent(type="text", text=f"Result from {name}")],
- isError=False,
- )
-
- async def submit(self, answer: str | dict[str, Any]) -> None:
- self._submitted = answer
-
-
-class TestRun:
- """Tests for MCPAgent.run() with EvalContext."""
-
- @pytest.mark.asyncio
- async def test_run_basic(self) -> None:
- """Test basic run() flow."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
-
- result = await agent.run(ctx)
-
- assert result.done
- assert result.content == "Test response"
- assert ctx._submitted == "Test response"
-
- @pytest.mark.asyncio
- async def test_run_no_prompt_raises(self) -> None:
- """Test run() raises when prompt is not set."""
- ctx = MockEvalContext(prompt="")
- agent = MockMCPAgent()
-
- with pytest.raises(ValueError, match="prompt is not set"):
- await agent.run(ctx)
-
- @pytest.mark.asyncio
- async def test_run_wrong_type_raises(self) -> None:
- """Test run() raises TypeError for non-EvalContext."""
- agent = MockMCPAgent()
-
- with pytest.raises(TypeError, match="must be EvalContext"):
- await agent.run("not an eval context") # type: ignore[arg-type]
-
- @pytest.mark.asyncio
- async def test_run_clears_ctx(self) -> None:
- """Test run() clears ctx after completion."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
-
- await agent.run(ctx)
- assert agent.ctx is None
-
- @pytest.mark.asyncio
- async def test_run_no_submit_on_empty_content(self) -> None:
- """Test run() doesn't submit when content is empty."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
- agent.set_response(InferenceResult(content="", tool_calls=[], done=True))
-
- await agent.run(ctx)
- assert ctx._submitted is None
-
- @pytest.mark.asyncio
- async def test_run_initializes_tools(self) -> None:
- """Test run() initializes tools from context."""
- ctx = MockEvalContext(
- prompt="Do the task",
- tools=[
- types.Tool(name="tool1", description="Tool 1", inputSchema={}),
- types.Tool(name="tool2", description="Tool 2", inputSchema={}),
- ],
- )
- agent = MockMCPAgent()
-
- await agent.run(ctx)
-
- assert agent._initialized
- # After cleanup, ctx is None but tools were discovered
-
-
-class TestRunCitations:
- """Tests for citation flow through run() -> Trace -> submit()."""
-
- @pytest.mark.asyncio
- async def test_run_submits_plain_string_without_citations(self) -> None:
- """When no citations, submit() receives a plain string."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
- agent.set_response(InferenceResult(content="answer", done=True))
-
- await agent.run(ctx)
-
- assert ctx._submitted == "answer"
-
- @pytest.mark.asyncio
- async def test_run_submits_dict_with_citations(self) -> None:
- """When citations are present, submit() receives a dict."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
- agent.set_response(
- InferenceResult(
- content="answer with sources",
- done=True,
- citations=[
- {"type": "url_citation", "source": "https://example.com", "title": "Ex"},
- ],
- )
- )
-
- await agent.run(ctx)
-
- assert isinstance(ctx._submitted, dict)
- assert ctx._submitted["content"] == "answer with sources"
- assert len(ctx._submitted["citations"]) == 1
- assert ctx._submitted["citations"][0]["source"] == "https://example.com"
-
- @pytest.mark.asyncio
- async def test_trace_carries_citations_from_inference(self) -> None:
- """Trace.citations is populated from the final InferenceResult."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
- citations = [
- {"type": "grounding", "source": "https://a.com", "text": "fact"},
- {"type": "url_citation", "source": "https://b.com", "title": "B"},
- ]
- agent.set_response(
- InferenceResult(
- content="sourced answer",
- done=True,
- citations=citations,
- )
- )
-
- trace = await agent.run(ctx)
-
- assert len(trace.citations) == 2
- assert trace.citations[0]["source"] == "https://a.com"
- assert trace.citations[1]["source"] == "https://b.com"
-
- @pytest.mark.asyncio
- async def test_trace_empty_citations_on_no_citations(self) -> None:
- """Trace.citations is empty when InferenceResult has no citations."""
- ctx = MockEvalContext(prompt="Do the task")
- agent = MockMCPAgent()
- agent.set_response(InferenceResult(content="plain answer", done=True))
-
- trace = await agent.run(ctx)
-
- assert trace.citations == []
-
- @pytest.mark.asyncio
- async def test_trace_empty_citations_on_error(self) -> None:
- """Trace.citations is empty when agent errors out."""
-
- class FailingAgent(MockMCPAgent):
- async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult:
- raise RuntimeError("boom")
-
- ctx = MockEvalContext(prompt="Do the task")
- agent = FailingAgent()
-
- trace = await agent.run(ctx)
-
- assert trace.isError is True
- assert trace.citations == []
diff --git a/hud/agents/tests/test_shared_eval_boundary.py b/hud/agents/tests/test_shared_eval_boundary.py
new file mode 100644
index 000000000..9c2c98f21
--- /dev/null
+++ b/hud/agents/tests/test_shared_eval_boundary.py
@@ -0,0 +1,239 @@
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+from mcp import types
+
+from hud.agents.tests.conftest import (
+ HarnessEvalContext,
+ RoutingHarnessTools,
+ ScriptedAgent,
+ mcp_tool,
+ text_prompt,
+ text_result,
+)
+from hud.types import AgentResponse, MCPToolCall, Trace
+
+
+@pytest.mark.asyncio
+async def test_eval_run_submits_final_content() -> None:
+ ctx = HarnessEvalContext(prompt="Do the task")
+ agent = ScriptedAgent([AgentResponse(content="answer", done=True)])
+
+ result = await ctx.run_agent(agent)
+
+ assert result.content == "answer"
+ assert ctx.submitted == "answer"
+
+
+@pytest.mark.asyncio
+async def test_eval_run_submits_citations_with_content() -> None:
+ citations = [{"type": "url", "source": "https://example.com"}]
+ ctx = HarnessEvalContext(prompt="Do the task")
+ agent = ScriptedAgent(
+ [AgentResponse(content="answer with sources", citations=citations, done=True)]
+ )
+
+ result = await ctx.run_agent(agent)
+
+ assert result.citations == citations
+ assert ctx.submitted == {"content": "answer with sources", "citations": citations}
+
+
+@pytest.mark.asyncio
+async def test_eval_run_does_not_submit_empty_content() -> None:
+ ctx = HarnessEvalContext(prompt="Do the task")
+ agent = ScriptedAgent([AgentResponse(content="", done=True)])
+
+ result = await ctx.run_agent(agent)
+
+ assert result.content == ""
+ assert ctx.submitted is None
+
+
+@pytest.mark.asyncio
+async def test_eval_run_records_error_without_submission() -> None:
+ ctx = HarnessEvalContext(prompt="Do the task")
+ agent = ScriptedAgent([AgentResponse(content="bad", isError=True, done=True)])
+
+ result = await ctx.run_agent(agent)
+
+ assert result.isError is True
+ assert isinstance(ctx.error, Exception)
+ assert str(ctx.error) == "bad"
+ assert ctx.submitted is None
+
+
+@pytest.mark.asyncio
+async def test_eval_run_requires_prompt_when_no_conversation_or_scenario_messages() -> None:
+ ctx = HarnessEvalContext(prompt="")
+ agent = ScriptedAgent([AgentResponse(content="unused", done=True)])
+
+ with pytest.raises(ValueError, match=r"ctx\.prompt is not set"):
+ await ctx.run_agent(agent)
+
+
+@pytest.mark.asyncio
+async def test_prompt_messages_prefer_scenario_messages_over_conversation_and_prompt() -> None:
+ scenario_message = text_prompt("scenario message", role="assistant")
+ ctx = HarnessEvalContext(prompt="fallback prompt")
+ ctx.conversation = [{"role": "user", "content": "conversation message"}]
+ ctx.set_scenario_messages([scenario_message])
+ agent = ScriptedAgent([AgentResponse(content="answer", done=True)])
+
+ await ctx.run_agent(agent)
+
+ assert agent.seen_messages[0] == [{"role": "assistant", "content": "scenario message"}]
+
+
+@pytest.mark.asyncio
+async def test_prompt_messages_use_conversation_before_prompt() -> None:
+ ctx = HarnessEvalContext(prompt="fallback prompt")
+ ctx.conversation = [
+ {"role": "assistant", "content": "previous"},
+ {"role": "user", "content": "next"},
+ ]
+ agent = ScriptedAgent([AgentResponse(content="answer", done=True)])
+
+ await ctx.run_agent(agent)
+
+ assert agent.seen_messages[0] == [
+ {"role": "assistant", "content": "previous"},
+ {"role": "user", "content": "next"},
+ ]
+
+
+@pytest.mark.asyncio
+async def test_eval_run_passes_citation_flag_to_agent() -> None:
+ ctx = HarnessEvalContext(prompt="Do the task")
+ ctx.enable_citations = True
+ agent = ScriptedAgent([AgentResponse(content="answer", done=True)])
+
+ await ctx.run_agent(agent)
+
+ assert agent.enable_citations is True
+
+
+@pytest.mark.asyncio
+async def test_eval_run_executes_environment_tool_and_submits_final_answer() -> None:
+ ctx = HarnessEvalContext(
+ prompt="Use a tool",
+ tools=[mcp_tool("lookup")],
+ tool_results={"lookup": text_result("looked up")},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={"q": "hud"})]),
+ AgentResponse(content="answer", done=True),
+ ]
+ )
+
+ result = await ctx.run_agent(agent)
+
+ assert result.content == "answer"
+ assert ctx.submitted == "answer"
+ assert [(call.name, call.arguments) for call in ctx.environment.calls] == [
+ ("lookup", {"q": "hud"})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_eval_tool_metadata_routes_native_provider_tool_to_environment_tool() -> None:
+ ctx = HarnessEvalContext(
+ prompt="Use shell",
+ tools=[mcp_tool("run_shell")],
+ metadata={"capabilities": {"shell": "run_shell"}},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="shell", arguments={"command": "pwd"})]),
+ AgentResponse(content="done", done=True),
+ ],
+ tools_factory=RoutingHarnessTools,
+ )
+
+ result = await ctx.run_agent(agent)
+
+ assert result.content == "done"
+ assert [(call.name, call.arguments) for call in ctx.environment.calls] == [
+ ("run_shell", {"command": "pwd"})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_eval_run_passes_max_steps_to_agent_run() -> None:
+ ctx = HarnessEvalContext(prompt="Use a tool", tools=[mcp_tool("lookup")])
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]),
+ AgentResponse(content="too late", done=True),
+ ]
+ )
+
+ result = await ctx.run_agent(agent, max_steps=1)
+
+ assert result.content is None
+ assert ctx.submitted is None
+ assert [(call.name, call.arguments) for call in ctx.environment.calls] == [("lookup", {})]
+
+
+@pytest.mark.asyncio
+async def test_eval_run_records_agent_step_error_on_context() -> None:
+ ctx = HarnessEvalContext(prompt="Do the task")
+ agent = ScriptedAgent([RuntimeError("agent failed")])
+
+ result = await ctx.run_agent(agent)
+
+ assert result.isError is True
+ assert isinstance(ctx.error, Exception)
+ assert str(ctx.error) == "agent failed"
+ assert ctx.submitted is None
+
+
+@pytest.mark.asyncio
+async def test_submit_result_error_prefers_info_error_message() -> None:
+ ctx = HarnessEvalContext(prompt="Do the task")
+
+ result = Trace(isError=True, content="fallback", info={"error": "specific"})
+
+ await ctx.submit_result(result)
+
+ assert isinstance(ctx.error, Exception)
+ assert str(ctx.error) == "specific"
+
+
+def test_tool_metadata_accepts_legacy_capabilities_shape() -> None:
+ ctx = HarnessEvalContext(
+ prompt="Do the task",
+ metadata={"capabilities": {"computer": "computer"}},
+ )
+
+ metadata = ctx.tool_metadata_for_run()
+
+ assert metadata == {"capabilities": {"computer": "computer"}}
+
+
+def test_tool_metadata_prefers_environment_capabilities_shape() -> None:
+ environment_capabilities: dict[str, Any] = {"capabilities": {"computer": {"tool": "computer"}}}
+ ctx = HarnessEvalContext(
+ prompt="Do the task",
+ metadata={"environment_capabilities": environment_capabilities},
+ )
+
+ metadata = ctx.tool_metadata_for_run()
+
+ assert metadata is environment_capabilities
+
+
+def test_prompt_falls_back_to_plain_user_message() -> None:
+ ctx = HarnessEvalContext(prompt="hello")
+
+ messages = ctx.prompt_messages()
+
+ assert messages == [
+ types.PromptMessage(
+ role="user",
+ content=types.TextContent(type="text", text="hello"),
+ )
+ ]
diff --git a/hud/agents/tests/test_shared_run_loop.py b/hud/agents/tests/test_shared_run_loop.py
new file mode 100644
index 000000000..d64bb4e62
--- /dev/null
+++ b/hud/agents/tests/test_shared_run_loop.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+
+from hud.agents.base import AgentContext
+from hud.agents.tests.conftest import (
+ HarnessConfig,
+ RecordingToolEnvironment,
+ ScriptedAgent,
+ mcp_tool,
+ text_prompt,
+ text_result,
+)
+from hud.types import AgentResponse, MCPToolCall
+
+
+@pytest.mark.asyncio
+async def test_run_returns_final_response_without_tools() -> None:
+ agent = ScriptedAgent([AgentResponse(content="done", done=True)])
+
+ result = await agent.run(AgentContext(messages=[text_prompt("do it")]))
+
+ assert result.done is True
+ assert result.isError is False
+ assert result.content == "done"
+ assert agent.seen_messages == [[{"role": "user", "content": "do it"}]]
+
+
+@pytest.mark.asyncio
+async def test_run_executes_tool_call_and_continues_with_tool_result() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("found it")},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(
+ tool_calls=[MCPToolCall(name="lookup", arguments={"query": "thing"})],
+ done=False,
+ ),
+ AgentResponse(content="answer", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("find thing")], tool_client=environment.client)
+ )
+
+ assert result.content == "answer"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("lookup", {"query": "thing"})
+ ]
+ assert agent.seen_messages[1][-1] == {
+ "role": "tool",
+ "name": "lookup",
+ "content": "found it",
+ "is_error": False,
+ }
+
+
+@pytest.mark.asyncio
+async def test_run_supports_multiple_tool_steps_before_final_answer() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("first"), mcp_tool("second")],
+ results={"first": text_result("one"), "second": text_result("two")},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="first", arguments={})]),
+ AgentResponse(tool_calls=[MCPToolCall(name="second", arguments={"n": 2})]),
+ AgentResponse(content="finished", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("go")], tool_client=environment.client)
+ )
+
+ assert result.content == "finished"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("first", {}),
+ ("second", {"n": 2}),
+ ]
+ assert len(agent.seen_messages) == 3
+
+
+@pytest.mark.asyncio
+async def test_run_preserves_same_turn_tool_call_order() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("first"), mcp_tool("second")],
+ results={"first": text_result("one"), "second": text_result("two")},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(
+ tool_calls=[
+ MCPToolCall(name="first", arguments={"order": 1}),
+ MCPToolCall(name="second", arguments={"order": 2}),
+ ]
+ ),
+ AgentResponse(content="finished", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("call both")], tool_client=environment.client)
+ )
+
+ assert result.content == "finished"
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("first", {"order": 1}),
+ ("second", {"order": 2}),
+ ]
+ assert agent.seen_messages[1][-2:] == [
+ {"role": "tool", "name": "first", "content": "one", "is_error": False},
+ {"role": "tool", "name": "second", "content": "two", "is_error": False},
+ ]
+
+
+@pytest.mark.asyncio
+async def test_unlimited_max_steps_runs_until_final_answer() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("loop")])
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="loop", arguments={"step": 1})]),
+ AgentResponse(tool_calls=[MCPToolCall(name="loop", arguments={"step": 2})]),
+ AgentResponse(content="done", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("loop")], tool_client=environment.client),
+ max_steps=-1,
+ )
+
+ assert result.content == "done"
+ assert [call.arguments for call in environment.calls] == [{"step": 1}, {"step": 2}]
+
+
+@pytest.mark.asyncio
+async def test_tool_timeout_stops_run_with_error_trace() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("slow")],
+ results={"slow": TimeoutError("too slow")},
+ )
+ agent = ScriptedAgent([AgentResponse(tool_calls=[MCPToolCall(name="slow", arguments={})])])
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("try slow")], tool_client=environment.client)
+ )
+
+ assert result.isError is True
+ assert result.info["error"] == "too slow"
+ assert [(call.name, call.arguments) for call in environment.calls] == [("slow", {})]
+
+
+@pytest.mark.asyncio
+async def test_tool_errors_are_returned_to_the_model_as_error_results() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": RuntimeError("backend exploded")},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]),
+ AgentResponse(content="recovered", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("try")], tool_client=environment.client)
+ )
+
+ assert result.content == "recovered"
+ assert agent.seen_messages[1][-1]["is_error"] is True
+ assert agent.seen_messages[1][-1]["content"] == "backend exploded"
+
+
+@pytest.mark.asyncio
+async def test_missing_tool_client_turns_tool_call_into_error_trace() -> None:
+ agent = ScriptedAgent([AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})])])
+
+ result = await agent.run(AgentContext(messages=[text_prompt("call lookup")]))
+
+ assert result.isError is True
+ assert result.info["error"] == "call_tool callback is required to execute tool calls"
+
+
+@pytest.mark.asyncio
+async def test_max_steps_caps_tool_loop() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("lookup")])
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]),
+ AgentResponse(content="should not be reached", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("loop")], tool_client=environment.client),
+ max_steps=1,
+ )
+
+ assert result.done is True
+ assert result.content is None
+ assert len(environment.calls) == 1
+ assert len(agent.seen_messages) == 1
+
+
+@pytest.mark.asyncio
+async def test_auto_respond_can_continue_after_a_done_response(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ calls: list[str | None] = []
+
+ async def continue_once(content: str | None, *, enabled: bool) -> object:
+ calls.append(content)
+ assert enabled is True
+ if len(calls) > 1:
+ return None
+ return text_prompt("continue")
+
+ monkeypatch.setattr("hud.agents.base.auto_respond", continue_once)
+ agent = ScriptedAgent(
+ [
+ AgentResponse(content="need input", done=True),
+ AgentResponse(content="final", done=True),
+ ],
+ config=HarnessConfig(auto_respond=True),
+ )
+
+ result = await agent.run(AgentContext(messages=[text_prompt("start")]))
+
+ assert result.content == "final"
+ assert calls == ["need input", "final"]
+ assert agent.seen_messages[1][-1] == {"role": "user", "content": "continue"}
+
+
+@pytest.mark.asyncio
+async def test_model_step_exception_returns_error_trace() -> None:
+ agent = ScriptedAgent([RuntimeError("model failed")])
+
+ result = await agent.run(AgentContext(messages=[text_prompt("start")]))
+
+ assert result.done is True
+ assert result.isError is True
+ assert result.content == "model failed"
+
+
+@pytest.mark.asyncio
+async def test_keyboard_interrupt_returns_interrupted_trace() -> None:
+ agent = ScriptedAgent([KeyboardInterrupt()])
+
+ result = await agent.run(AgentContext(messages=[text_prompt("start")]))
+
+ assert result.isError is True
+ assert result.content == "Interrupted by user"
+ assert result.info["error"] == "Interrupted by user"
+
+
+@pytest.mark.asyncio
+async def test_cancelled_run_returns_cancelled_trace() -> None:
+ agent = ScriptedAgent([asyncio.CancelledError()])
+
+ result = await agent.run(AgentContext(messages=[text_prompt("start")]))
+
+ assert result.isError is True
+ assert result.content == "Cancelled"
+ assert result.info["error"] == "Cancelled"
+
+
+@pytest.mark.asyncio
+async def test_trace_messages_include_provider_history_before_stop() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("found")},
+ )
+ agent = ScriptedAgent(
+ [
+ AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]),
+ AgentResponse(content="done", done=True),
+ ]
+ )
+
+ result = await agent.run(
+ AgentContext(messages=[text_prompt("start")], tool_client=environment.client)
+ )
+
+ assert result.content == "done"
+ assert result.messages == [
+ {"role": "user", "content": "start"},
+ {"role": "tool", "name": "lookup", "content": "found", "is_error": False},
+ ]
diff --git a/hud/agents/tests/test_shared_tool_registry.py b/hud/agents/tests/test_shared_tool_registry.py
new file mode 100644
index 000000000..760af5e7b
--- /dev/null
+++ b/hud/agents/tests/test_shared_tool_registry.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+import pytest
+
+from hud.agents.tests.conftest import (
+ RecordingToolEnvironment,
+ RoutingHarnessTools,
+ mcp_tool,
+ text_result,
+)
+from hud.agents.tools.capabilities import discover_environment_capabilities
+from hud.types import MCPToolCall
+
+if TYPE_CHECKING:
+ from hud.agents.tools import ToolMetadata
+
+
+@pytest.mark.asyncio
+async def test_generic_tool_call_routes_to_matching_environment_tool() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": text_result("found")},
+ )
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(model="test-model", tools=environment.tools)
+
+ outputs = await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="lookup", arguments={"query": "hud"}),
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("lookup", {"query": "hud"})
+ ]
+ assert outputs == [{"role": "tool", "name": "lookup", "content": "found", "is_error": False}]
+
+
+@pytest.mark.asyncio
+async def test_capability_metadata_routes_provider_tool_to_environment_tool() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("run_shell")])
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(
+ model="test-model",
+ tools=environment.tools,
+ tool_metadata={"capabilities": {"shell": "run_shell"}},
+ )
+
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="shell", arguments={"command": "pwd"}),
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("run_shell", {"command": "pwd"})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_name_fallback_routes_native_tool_when_metadata_is_absent() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("bash")])
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(model="test-model", tools=environment.tools)
+
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="shell", arguments={"command": "echo hi"}),
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("bash", {"command": "echo hi"})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_grouped_capability_metadata_routes_to_the_selected_environment_tool() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("read"), mcp_tool("grep")])
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(
+ model="test-model",
+ tools=environment.tools,
+ tool_metadata={"capabilities": {"filesystem": {"tools": {"read": "read", "grep": "grep"}}}},
+ )
+
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="read_file", arguments={"path": "README.md"}),
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("read", {"path": "README.md"})
+ ]
+
+
+@pytest.mark.asyncio
+async def test_native_tool_takes_precedence_over_generic_tool_with_same_environment_name() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("bash"), mcp_tool("lookup")])
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(model="test-model", tools=environment.tools)
+
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="shell", arguments={"command": "whoami"}),
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("bash", {"command": "whoami"})
+ ]
+ with pytest.raises(KeyError):
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="bash", arguments={"command": "whoami"}),
+ )
+
+
+@pytest.mark.asyncio
+async def test_unknown_provider_tool_fails_before_environment_execution() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("lookup")])
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(model="test-model", tools=environment.tools)
+
+ with pytest.raises(KeyError):
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="missing", arguments={}),
+ )
+
+ assert environment.calls == []
+
+
+@pytest.mark.asyncio
+async def test_timeout_error_propagates_to_run_loop_boundary() -> None:
+ environment = RecordingToolEnvironment(
+ [mcp_tool("lookup")],
+ results={"lookup": TimeoutError("tool timed out")},
+ )
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(model="test-model", tools=environment.tools)
+
+ with pytest.raises(TimeoutError, match="tool timed out"):
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="lookup", arguments={}),
+ )
+
+
+def test_invalid_capability_metadata_fails_at_the_boundary() -> None:
+ with pytest.raises(ValueError, match="Invalid capability metadata"):
+ discover_environment_capabilities(
+ [mcp_tool("lookup")],
+ tool_metadata=cast(
+ "ToolMetadata",
+ {"capabilities": {"lookup": {"unexpected": "shape"}}},
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_stale_capability_metadata_falls_back_to_available_tool_names() -> None:
+ environment = RecordingToolEnvironment([mcp_tool("bash")])
+ agent_tools = RoutingHarnessTools()
+ agent_tools.prepare(
+ model="test-model",
+ tools=environment.tools,
+ tool_metadata={"capabilities": {"shell": "missing_shell"}},
+ )
+
+ await agent_tools.execute(
+ environment.call_tool,
+ MCPToolCall(name="shell", arguments={"command": "pwd"}),
+ )
+
+ assert [(call.name, call.arguments) for call in environment.calls] == [
+ ("bash", {"command": "pwd"})
+ ]
diff --git a/hud/agents/tools/__init__.py b/hud/agents/tools/__init__.py
index 97f4d670d..116387e86 100644
--- a/hud/agents/tools/__init__.py
+++ b/hud/agents/tools/__init__.py
@@ -2,30 +2,28 @@
from __future__ import annotations
-from .base import AgentTool, AgentToolSpec, CallTool, call_agent_tools, call_tool
+from .base import (
+ AgentTool,
+ AgentTools,
+ AgentToolSpec,
+)
from .capabilities import (
+ CapabilityEntry,
EnvironmentCapability,
GroupedCapabilityMixin,
- capabilities_metadata_from_context,
+ ToolMetadata,
discover_environment_capabilities,
)
-from .hosted import (
- HostedTool,
- select_hosted_tools,
-)
-from .registry import AgentToolRegistry
+from .hosted import HostedTool
__all__ = [
"AgentTool",
- "AgentToolRegistry",
"AgentToolSpec",
- "CallTool",
+ "AgentTools",
+ "CapabilityEntry",
"EnvironmentCapability",
"GroupedCapabilityMixin",
"HostedTool",
- "call_agent_tools",
- "call_tool",
- "capabilities_metadata_from_context",
+ "ToolMetadata",
"discover_environment_capabilities",
- "select_hosted_tools",
]
diff --git a/hud/agents/tools/base.py b/hud/agents/tools/base.py
index 2ba5ea806..435027c23 100644
--- a/hud/agents/tools/base.py
+++ b/hud/agents/tools/base.py
@@ -3,19 +3,37 @@
from __future__ import annotations
import fnmatch
+import logging
from abc import ABC, abstractmethod
-from collections.abc import Awaitable, Callable, Mapping
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, TypeVar
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, TypeVar, cast
+import mcp.types as types
+
+from hud.agents.tools.capabilities import discover_environment_capabilities
from hud.types import MCPToolCall, MCPToolResult
if TYPE_CHECKING:
- from hud.agents.base import MCPAgent
- from hud.agents.tools.capabilities import EnvironmentCapability
+ from collections.abc import Mapping
+
+ from hud.agents.tools.capabilities import EnvironmentCapability, ToolMetadata
+ from hud.agents.tools.hosted import HostedTool
+AgentToolParamT_co = TypeVar("AgentToolParamT_co", covariant=True)
ToolParamT = TypeVar("ToolParamT")
+AgentToolT = TypeVar("AgentToolT", bound="AgentTool[object]")
CallTool = Callable[[MCPToolCall], Awaitable[MCPToolResult]]
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class ToolClient:
+ """MCP tools and execution hook available for one agent run."""
+
+ tools: list[types.Tool] = field(default_factory=list[types.Tool])
+ tool_handler: CallTool | None = None
+ tool_metadata: ToolMetadata | None = None
@dataclass(frozen=True)
@@ -24,20 +42,21 @@ class AgentToolSpec:
api_type: str
api_name: str
- beta: str | None = None
supported_models: tuple[str, ...] | None = None
def supports_model(self, model: str | None) -> bool:
- if not self.supported_models or not model or model == "unknown":
+ if not self.supported_models:
return True
+ if not model or model == "unknown":
+ return False
model_lower = model.lower()
return any(
fnmatch.fnmatch(model_lower, pattern.lower()) for pattern in self.supported_models
)
-class AgentTool(ABC, Generic[ToolParamT]):
- """Provider-facing tool backed by one environment tool."""
+class AgentTool(ABC, Generic[AgentToolParamT_co]):
+ """Provider-facing tool owned by an agent harness."""
name: ClassVar[str]
capability: ClassVar[str]
@@ -46,79 +65,182 @@ def __init__(self, *, env_tool_name: str, spec: AgentToolSpec) -> None:
self.env_tool_name = env_tool_name
self.spec = spec
+ @property
+ def provider_name(self) -> str:
+ return self.name
+
+ @classmethod
+ def env_tool_name_for_capability(cls, capability: EnvironmentCapability) -> str | None:
+ return capability.tool_name
+
@classmethod
def from_capability(
cls,
capability: EnvironmentCapability,
- spec: AgentToolSpec,
model: str,
- ) -> Self:
- del model
- return cls(env_tool_name=capability.tool_name, spec=spec)
+ ) -> Self | None:
+ spec = cls.default_spec(model)
+ env_tool_name = cls.env_tool_name_for_capability(capability)
+ if spec is None or env_tool_name is None:
+ return None
+ return cls(env_tool_name=env_tool_name, spec=spec)
@classmethod
def default_spec(cls, model: str) -> AgentToolSpec | None:
"""Return the provider spec this agent should use for this capability."""
- del model
return None
- @property
- def required_beta(self) -> str | None:
- return self.spec.beta
-
- async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
- """Execute by forwarding to the backing environment tool."""
- return await call_tool(caller, self.env_tool_name, arguments)
-
- @abstractmethod
- def to_params(self) -> ToolParamT: ...
+ @classmethod
+ def from_tool(cls, tool: types.Tool) -> Self | None:
+ """Build a provider tool for a generic environment tool."""
+ del tool
+ return None
+ async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult:
+ """Execute an environment-backed tool by forwarding to its MCP tool."""
+ return await call_tool(MCPToolCall(name=self.env_tool_name, arguments=arguments))
-async def call_tool(
- caller: CallTool,
- env_tool_name: str,
- arguments: dict[str, Any],
-) -> MCPToolResult:
- result = await caller(MCPToolCall(name=env_tool_name, arguments=arguments))
- return MCPToolResult(content=result.content, isError=result.isError)
+ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> Any | None:
+ """Format a single tool result for the provider continuation turn."""
+ del result
+ logger.warning("Tool '%s' does not implement result formatting.", call.name)
+ return None
+ @abstractmethod
+ def to_params(self) -> AgentToolParamT_co: ...
-async def call_agent_tools(
- agent: MCPAgent,
- agent_tools: Mapping[str, AgentTool[Any]],
- tool_call: MCPToolCall | list[MCPToolCall] | None = None,
-) -> list[MCPToolResult]:
- """Route provider-owned tool calls through adapters, otherwise through MCP."""
- import mcp.types as types
- from hud.agents.base import MCPAgent
+class AgentTools(dict[str, AgentToolT], Generic[AgentToolT, ToolParamT]):
+ """Prepared tool state owned by a single agent run."""
- if tool_call is None:
- return []
- tool_calls = [tool_call] if isinstance(tool_call, MCPToolCall) else tool_call
+ native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = ()
+ function_tool_class: ClassVar[type[AgentTool[object]] | None] = None
+ name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {}
- async def call_env_tool(call: MCPToolCall) -> MCPToolResult:
- return (await MCPAgent.call_tools(agent, call))[0]
+ def __init__(self) -> None:
+ super().__init__()
+ self.params: list[ToolParamT] = []
+ self.name_map: dict[str, str] = {}
+ self.hosted_tools: list[HostedTool[object]] = []
- results: list[MCPToolResult] = []
- for tc in tool_calls:
- agent_tool = agent_tools.get(tc.name)
- if agent_tool is None:
- results.extend(await MCPAgent.call_tools(agent, tc))
- continue
+ def select_tools(
+ self,
+ tools: list[types.Tool],
+ model: str,
+ *,
+ tool_metadata: ToolMetadata | None = None,
+ ) -> tuple[list[AgentToolT], list[types.Tool]]:
+ """Split MCP tools into provider-owned and user-defined tools."""
+ logger.info("Discovered %s tools: %s", len(tools), ", ".join(tool.name for tool in tools))
+
+ capabilities = discover_environment_capabilities(
+ tools,
+ tool_metadata=tool_metadata,
+ name_fallbacks=self.name_fallbacks,
+ )
+ agent_tools: list[AgentToolT] = []
+ for capability in capabilities.values():
+ for raw_tool_cls in self.native_tool_classes:
+ tool_cls = cast("type[AgentToolT]", raw_tool_cls)
+ if tool_cls.capability != capability.name:
+ continue
+ tool = tool_cls.from_capability(capability, model)
+ if tool is not None:
+ agent_tools.append(tool)
+ agent_tool_names = {tool.env_tool_name for tool in agent_tools}
+ user_tools = [tool for tool in tools if tool.name not in agent_tool_names]
+ return agent_tools, user_tools
+
+ def generic_tool(
+ self,
+ tool: types.Tool,
+ ) -> ToolParamT | None:
+ """Convert an environment MCP tool into provider params."""
+ del tool
+ return None
- try:
+ def prepare(
+ self,
+ *,
+ model: str,
+ tools: list[types.Tool],
+ hosted_tools: list[HostedTool[object]] | None = None,
+ tool_metadata: ToolMetadata | None = None,
+ ) -> None:
+ """Prepare a generic provider tool map for an agent run."""
+ provider_tools, user_tools = self.select_tools(
+ tools,
+ model,
+ tool_metadata=tool_metadata,
+ )
+ tools_by_name = {tool.provider_name: tool for tool in provider_tools}
+ installed_names = set(tools_by_name)
+ self.update(tools_by_name)
+ self.params.extend(cast("ToolParamT", tool.to_params()) for tool in provider_tools)
+ self.name_map.update({name: name for name in tools_by_name})
+
+ selected_hosted_tools: list[HostedTool[object]] = []
+ for tool in hosted_tools or []:
+ if not tool.supports_model(model):
+ continue
+ selected_hosted_tools.append(tool)
+ self.params.append(cast("ToolParamT", tool.to_params()))
+ self.hosted_tools = selected_hosted_tools
+
+ for tool in user_tools:
+ if self.function_tool_class is not None:
+ function_tool_cls = cast("type[AgentToolT]", self.function_tool_class)
+ agent_tool = function_tool_cls.from_tool(tool)
+ if agent_tool is None:
+ continue
+ self[agent_tool.provider_name] = agent_tool
+ installed_names.add(agent_tool.provider_name)
+ self.name_map[tool.name] = agent_tool.provider_name
+ self.params.append(cast("ToolParamT", agent_tool.to_params()))
+ continue
+ generic_tool = self.generic_tool(tool)
+ if generic_tool is None:
+ continue
+ installed_names.add(tool.name)
+ self.name_map[tool.name] = tool.name
+ self.params.append(generic_tool)
+
+ tool_names = sorted(installed_names)
+ logger.info("Agent initialized with %s tools: %s", len(tool_names), ", ".join(tool_names))
+
+ async def execute(
+ self,
+ call_tool: CallTool | None,
+ tool_call: MCPToolCall | list[MCPToolCall] | None = None,
+ ) -> list[Any]:
+ if tool_call is None:
+ return []
+
+ if call_tool is None:
+ raise ValueError("call_tool callback is required to execute tool calls")
+
+ outputs: list[Any] = []
+ tool_calls = [tool_call] if isinstance(tool_call, MCPToolCall) else tool_call
+ for tc in tool_calls:
+ agent_tool = self[tc.name]
arguments = tc.arguments if isinstance(tc.arguments, dict) else {}
- results.append(await agent_tool.execute(call_env_tool, arguments))
- except Exception as exc:
- agent.console.error_log(f"Agent tool execution failed: {exc}")
- results.append(
- MCPToolResult(
+ try:
+ result = await agent_tool.execute(call_tool, arguments)
+ except TimeoutError:
+ raise
+ except Exception as exc:
+ logger.exception("Tool execution failed")
+ result = MCPToolResult(
content=[types.TextContent(type="text", text=str(exc))],
isError=True,
)
- )
- return results
+ output = agent_tool.format_result(tc, result)
+ if output is None:
+ continue
+ if isinstance(output, list):
+ outputs.extend(cast("list[Any]", output))
+ else:
+ outputs.append(output)
-__all__ = ["AgentTool", "AgentToolSpec", "CallTool", "call_agent_tools", "call_tool"]
+ return outputs
diff --git a/hud/agents/tools/capabilities.py b/hud/agents/tools/capabilities.py
index 2dc24d8fc..5c8282e7f 100644
--- a/hud/agents/tools/capabilities.py
+++ b/hud/agents/tools/capabilities.py
@@ -2,131 +2,108 @@
from __future__ import annotations
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, ClassVar, Self
+from typing import TYPE_CHECKING, ClassVar, TypedDict, cast
if TYPE_CHECKING:
+ from collections.abc import Mapping
+
from mcp import types as mcp_types
- from hud.agents.tools.base import AgentToolSpec
+ from hud.types import JsonObject, JsonValue
+else:
+ JsonObject = dict[str, object]
+ JsonValue = object
-@dataclass(frozen=True)
-class EnvironmentCapability:
- """A normalized environment capability bound to one or more MCP tools."""
- name: str
+class CapabilityEntry(TypedDict, total=False):
+ tool: str
tool_name: str
- tool: mcp_types.Tool
- metadata: dict[str, Any] = field(default_factory=dict)
+ tools: dict[str, str]
-def capabilities_metadata_from_context(ctx: Any) -> dict[str, Any] | None:
- """Extract an optional env-level capability descriptor from a context."""
- if ctx is None:
- return None
+class ToolMetadata(TypedDict, total=False):
+ capabilities: dict[str, str | CapabilityEntry]
- direct = getattr(ctx, "environment_capabilities", None)
- if isinstance(direct, dict):
- return direct
- direct = getattr(ctx, "capabilities", None)
- if isinstance(direct, dict):
- return {"capabilities": direct}
-
- metadata = getattr(ctx, "metadata", None)
- if isinstance(metadata, dict):
- for key in ("environment_capabilities", "capabilities"):
- value = metadata.get(key)
- if isinstance(value, dict):
- return value if key == "environment_capabilities" else {"capabilities": value}
+class EnvironmentCapability:
+ """A normalized environment capability bound to one or more MCP tools."""
- return None
+ def __init__(
+ self,
+ *,
+ name: str,
+ tool_name: str,
+ tool: mcp_types.Tool,
+ metadata: JsonObject | None = None,
+ ) -> None:
+ self.name = name
+ self.tool_name = tool_name
+ self.tool = tool
+ self.metadata: JsonObject = metadata or {}
def discover_environment_capabilities(
tools: list[mcp_types.Tool],
*,
- env_metadata: dict[str, Any] | None = None,
- name_fallbacks: dict[str, tuple[str, ...]] | None = None,
+ tool_metadata: ToolMetadata | None = None,
+ name_fallbacks: Mapping[str, tuple[str, ...]] | None = None,
) -> dict[str, EnvironmentCapability]:
"""Build a normalized capability map from env metadata and tool inventory."""
tool_by_name = {tool.name: tool for tool in tools}
capabilities: dict[str, EnvironmentCapability] = {}
- _add_env_capabilities(capabilities, tool_by_name, env_metadata)
- _add_name_fallback_capabilities(capabilities, tool_by_name, name_fallbacks or {})
-
- return capabilities
-
-
-def _add_env_capabilities(
- capabilities: dict[str, EnvironmentCapability],
- tool_by_name: dict[str, mcp_types.Tool],
- env_metadata: dict[str, Any] | None,
-) -> None:
- if not env_metadata:
- return
-
- raw = env_metadata.get("capabilities", env_metadata)
- if not isinstance(raw, dict):
- return
-
- for name, config in raw.items():
- if not isinstance(name, str) or name in capabilities:
- continue
- tool_name: str | None = None
- metadata: dict[str, Any] = {}
- if isinstance(config, str):
- tool_name = config
- elif isinstance(config, dict):
- raw_tool = config.get("tool") or config.get("tool_name")
- if isinstance(raw_tool, str):
- tool_name = raw_tool
- metadata = dict(config)
- else:
- raw_tools = config.get("tools")
- if isinstance(raw_tools, dict):
- tool_names = {
- str(key): value
- for key, value in raw_tools.items()
- if isinstance(value, str) and value in tool_by_name
- }
- if tool_names:
- tool_name = next(iter(tool_names.values()))
- metadata = {**config, "tools": tool_names}
- if tool_name is None:
- continue
- tool = tool_by_name.get(tool_name)
- if tool is None:
+ metadata = tool_metadata or {}
+ raw_capabilities = cast(
+ "dict[str, str | CapabilityEntry]",
+ metadata.get("capabilities", metadata),
+ )
+ for name, config in raw_capabilities.items():
+ match config:
+ case str() as tool_name:
+ capability_metadata: JsonObject = {}
+ case {"tool": str() as tool_name}:
+ capability_metadata = {"tool": tool_name}
+ case {"tool_name": str() as tool_name}:
+ capability_metadata = {"tool_name": tool_name}
+ case {"tools": grouped_tools}:
+ tool_names: dict[str, JsonValue] = {
+ str(alias): env_tool_name
+ for alias, env_tool_name in grouped_tools.items()
+ if env_tool_name in tool_by_name
+ }
+ if not tool_names:
+ continue
+ tool_name = str(next(iter(tool_names.values())))
+ capability_metadata = {"tools": tool_names}
+ case _:
+ raise ValueError(f"Invalid capability metadata for {name!r}: {config!r}")
+
+ if tool_name not in tool_by_name:
continue
+
capabilities[name] = EnvironmentCapability(
name=name,
- tool_name=tool.name,
- tool=tool,
- metadata=metadata,
+ tool_name=tool_name,
+ tool=tool_by_name[tool_name],
+ metadata=capability_metadata,
)
-
-def _add_name_fallback_capabilities(
- capabilities: dict[str, EnvironmentCapability],
- tool_by_name: dict[str, mcp_types.Tool],
- name_fallbacks: dict[str, tuple[str, ...]],
-) -> None:
- for capability, names in name_fallbacks.items():
+ for capability, names in (name_fallbacks or {}).items():
if capability in capabilities:
continue
matched_tool_names = [name for name in names if name in tool_by_name]
- tool_name = matched_tool_names[0] if matched_tool_names else None
- if tool_name is None:
+ if not matched_tool_names:
continue
- tool = tool_by_name[tool_name]
+
+ tool = tool_by_name[matched_tool_names[0]]
capabilities[capability] = EnvironmentCapability(
name=capability,
tool_name=tool.name,
tool=tool,
metadata={"tools": {name: name for name in matched_tool_names}},
)
+ return capabilities
class GroupedCapabilityMixin:
@@ -134,37 +111,14 @@ class GroupedCapabilityMixin:
env_tool_names: ClassVar[tuple[str, ...]]
- if TYPE_CHECKING:
-
- def __init__(self, *, env_tool_name: str, spec: AgentToolSpec) -> None: ...
-
@classmethod
def env_tool_name_for_capability(cls, capability: EnvironmentCapability) -> str | None:
- tools = capability.metadata.get("tools")
- if isinstance(tools, dict):
- return next(
- (tools[name] for name in cls.env_tool_names if isinstance(tools.get(name), str)),
- None,
- )
+ tools_obj = capability.metadata.get("tools")
+ if isinstance(tools_obj, dict):
+ tools_map = cast("dict[str, object]", tools_obj)
+ for name in cls.env_tool_names:
+ if env_tool_name := tools_map.get(name):
+ return str(env_tool_name)
if capability.tool_name in cls.env_tool_names:
return capability.tool_name
return None
-
- @classmethod
- def from_capability(
- cls,
- capability: EnvironmentCapability,
- spec: AgentToolSpec,
- model: str,
- ) -> Self:
- del model
- env_tool_name = cls.env_tool_name_for_capability(capability) or capability.tool_name
- return cls(env_tool_name=env_tool_name, spec=spec)
-
-
-__all__ = [
- "EnvironmentCapability",
- "GroupedCapabilityMixin",
- "capabilities_metadata_from_context",
- "discover_environment_capabilities",
-]
diff --git a/hud/agents/tools/computer.py b/hud/agents/tools/computer.py
new file mode 100644
index 000000000..b8e94c6c6
--- /dev/null
+++ b/hud/agents/tools/computer.py
@@ -0,0 +1,104 @@
+"""Shared helpers for agent-side computer tools."""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable, Mapping
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, cast
+
+from mcp.types import ImageContent, TextContent
+
+from hud.types import MCPToolCall, MCPToolResult
+
+if TYPE_CHECKING:
+ from mcp import types as mcp_types
+
+CallTool = Callable[[MCPToolCall], Awaitable[MCPToolResult]]
+
+
+@dataclass(frozen=True)
+class ComputerToolInfo:
+ """Computer MCP tool metadata needed by provider adapters."""
+
+ display_width: int
+ display_height: int
+ coordinate_space: int | None
+
+
+def computer_tool_info(
+ tool: mcp_types.Tool,
+ *,
+ default_width: int,
+ default_height: int,
+) -> ComputerToolInfo:
+ """Resolve the computer contract advertised by the MCP tool."""
+ meta = cast("Mapping[str, object]", tool.meta or {})
+ resolution = meta.get("resolution")
+ display_width = default_width
+ display_height = default_height
+
+ if isinstance(resolution, Mapping):
+ resolution = cast("Mapping[str, object]", resolution)
+ width = resolution.get("width")
+ height = resolution.get("height")
+ if type(width) is int:
+ display_width = width
+ if type(height) is int:
+ display_height = height
+
+ coordinate_space_raw = meta.get("coordinate_space")
+ coordinate_space = coordinate_space_raw if type(coordinate_space_raw) is int else None
+
+ return ComputerToolInfo(
+ display_width=display_width,
+ display_height=display_height,
+ coordinate_space=coordinate_space,
+ )
+
+
+def computer_error_result(message: str) -> MCPToolResult:
+ return MCPToolResult(content=[TextContent(type="text", text=message)], isError=True)
+
+
+def result_has_image(result: MCPToolResult) -> bool:
+ return any(isinstance(block, ImageContent) for block in result.content)
+
+
+def first_image_data(result: MCPToolResult) -> str | None:
+ for block in result.content:
+ if isinstance(block, ImageContent):
+ return block.data
+ return None
+
+
+def last_image_data(result: MCPToolResult) -> str | None:
+ for block in reversed(result.content):
+ if isinstance(block, ImageContent):
+ return block.data
+ return None
+
+
+async def execute_computer_calls(
+ call_tool: CallTool,
+ *,
+ env_tool_name: str,
+ calls: list[dict[str, Any]],
+ ensure_screenshot: bool,
+) -> MCPToolResult:
+ result = MCPToolResult(content=[], isError=False)
+ for arguments in calls:
+ result = await call_tool(MCPToolCall(name=env_tool_name, arguments=arguments))
+ if result.isError:
+ return result
+
+ if ensure_screenshot and not result_has_image(result):
+ screenshot = await call_tool(
+ MCPToolCall(name=env_tool_name, arguments={"action": "screenshot"})
+ )
+ if not screenshot.isError and screenshot.content:
+ return MCPToolResult(
+ content=[*result.content, *screenshot.content],
+ isError=result.isError,
+ )
+
+ return result
diff --git a/hud/agents/tools/hosted.py b/hud/agents/tools/hosted.py
index 160bcab98..e86c3934d 100644
--- a/hud/agents/tools/hosted.py
+++ b/hud/agents/tools/hosted.py
@@ -2,49 +2,30 @@
from __future__ import annotations
+import fnmatch
+from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Generic, TypeVar
+from typing import Generic, TypeVar
-from .base import AgentToolSpec
-
-HostedToolParamT = TypeVar("HostedToolParamT")
-HostedToolT = TypeVar("HostedToolT", bound="HostedTool[Any]")
+HostedToolParamT_co = TypeVar("HostedToolParamT_co", covariant=True)
@dataclass(frozen=True, kw_only=True)
-class HostedTool(Generic[HostedToolParamT]):
+class HostedTool(ABC, Generic[HostedToolParamT_co]):
"""Provider-side tool activated only through explicit agent config."""
supported_models: tuple[str, ...] | None = None
def supports_model(self, model: str | None) -> bool:
- spec = AgentToolSpec(
- api_type="hosted",
- api_name=self.__class__.__name__,
- supported_models=self.supported_models,
+ if not self.supported_models:
+ return True
+ if not model or model == "unknown":
+ return False
+ model_lower = model.lower()
+ return any(
+ fnmatch.fnmatch(model_lower, pattern.lower()) for pattern in self.supported_models
)
- return spec.supports_model(model)
- def to_params(self) -> HostedToolParamT:
+ @abstractmethod
+ def to_params(self) -> HostedToolParamT_co:
raise NotImplementedError
-
-
-def select_hosted_tools(
- hosted_tools: list[Any],
- *,
- tool_type: type[HostedToolT],
- model: str,
-) -> list[HostedToolT]:
- """Select explicitly configured hosted tools for one provider/model."""
- selected: list[HostedToolT] = []
- for hosted_tool in hosted_tools:
- if not isinstance(hosted_tool, tool_type) or not hosted_tool.supports_model(model):
- continue
- selected.append(hosted_tool)
- return selected
-
-
-__all__ = [
- "HostedTool",
- "select_hosted_tools",
-]
diff --git a/hud/agents/tools/registry.py b/hud/agents/tools/registry.py
deleted file mode 100644
index 2de27c52c..000000000
--- a/hud/agents/tools/registry.py
+++ /dev/null
@@ -1,57 +0,0 @@
-"""Registry support for agent-owned tools."""
-
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Generic, TypeVar
-
-from .base import AgentTool
-
-if TYPE_CHECKING:
- from hud.agents.tools.capabilities import EnvironmentCapability
-
-ToolT = TypeVar("ToolT", bound=AgentTool[Any])
-
-
-@dataclass(frozen=True)
-class AgentToolRegistry(Generic[ToolT]):
- """Declarative registry for a provider or harness tool family."""
-
- tool_classes: tuple[type[ToolT], ...]
- name_fallbacks: dict[str, tuple[str, ...]] = field(default_factory=dict)
-
- @property
- def capabilities(self) -> frozenset[str]:
- return frozenset(cls.capability for cls in self.tool_classes)
-
- def tool_for_capability(
- self,
- capability: EnvironmentCapability,
- model: str,
- ) -> ToolT | None:
- tools = self.tools_for_capability(capability, model)
- return tools[0] if tools else None
-
- def tools_for_capability(
- self,
- capability: EnvironmentCapability,
- model: str,
- ) -> list[ToolT]:
- tools: list[ToolT] = []
- for tool_cls in self.tool_classes:
- if tool_cls.capability != capability.name:
- continue
- spec = tool_cls.default_spec(model)
- if spec is None:
- continue
- env_tool_name_for_capability = getattr(tool_cls, "env_tool_name_for_capability", None)
- if (
- callable(env_tool_name_for_capability)
- and env_tool_name_for_capability(capability) is None
- ):
- continue
- tools.append(tool_cls.from_capability(capability, spec, model))
- return tools
-
-
-__all__ = ["AgentToolRegistry"]
diff --git a/hud/agents/types.py b/hud/agents/types.py
index cb48ed5d9..718fc7ab0 100644
--- a/hud/agents/types.py
+++ b/hud/agents/types.py
@@ -10,20 +10,22 @@
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
-from hud.types import BaseAgentConfig
+from hud.agents.tools.hosted import HostedTool
# Alias to accept both 'model' and 'checkpoint_name' (backwards compat)
_model_alias = AliasChoices("model", "checkpoint_name")
-class BaseCreateParams(BaseModel):
- """Runtime parameters for agent creation."""
-
+class AgentConfig(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
ctx: Any = None # EvalContext or Environment
auto_respond: bool = False
- verbose: bool = False
+ system_prompt: str | None = None
+ hosted_tools: list[HostedTool[object]] = Field(default_factory=list[HostedTool[object]])
+
+ model_name: str = "Agent"
+ model: str = Field(default="unknown", validation_alias=_model_alias)
# -----------------------------------------------------------------------------
@@ -31,9 +33,7 @@ class BaseCreateParams(BaseModel):
# -----------------------------------------------------------------------------
-class ClaudeConfig(BaseAgentConfig):
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
+class ClaudeConfig(AgentConfig):
model_name: str = "Claude"
model: str = Field(default="claude-sonnet-4-6", validation_alias=_model_alias)
model_client: Any = None # AsyncAnthropic | AsyncAnthropicBedrock
@@ -42,23 +42,17 @@ class ClaudeConfig(BaseAgentConfig):
validate_api_key: bool = True
-class ClaudeCreateParams(BaseCreateParams, ClaudeConfig):
- pass
-
-
# -----------------------------------------------------------------------------
# Gemini
# -----------------------------------------------------------------------------
-class GeminiConfig(BaseAgentConfig):
+class GeminiConfig(AgentConfig):
"""Configuration for GeminiAgent."""
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
model_name: str = "Gemini"
model: str = Field(default="gemini-3-pro-preview", validation_alias=_model_alias)
- model_client: Any = None # genai.Client
+ model_client: Any = None # AsyncAnthropic | AsyncAnthropicBedrock
temperature: float = 1.0
top_p: float = 0.95
top_k: int = 40
@@ -69,23 +63,17 @@ class GeminiConfig(BaseAgentConfig):
include_thoughts: bool = True
-class GeminiCreateParams(BaseCreateParams, GeminiConfig):
- pass
-
-
# -----------------------------------------------------------------------------
# OpenAI
# -----------------------------------------------------------------------------
-class OpenAIConfig(BaseAgentConfig):
+class OpenAIConfig(AgentConfig):
"""Configuration for OpenAIAgent."""
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
model_name: str = "OpenAI"
model: str = Field(default="gpt-5.4", validation_alias=_model_alias)
- model_client: Any = None # AsyncOpenAI
+ model_client: Any = None # AsyncAnthropic | AsyncAnthropicBedrock
max_output_tokens: int | None = None
temperature: float | None = None
reasoning: Any = None # openai Reasoning
@@ -96,15 +84,9 @@ class OpenAIConfig(BaseAgentConfig):
validate_api_key: bool = True
-class OpenAICreateParams(BaseCreateParams, OpenAIConfig):
- pass
-
-
-class OpenAIChatConfig(BaseAgentConfig):
+class OpenAIChatConfig(AgentConfig):
"""Configuration for OpenAIChatAgent."""
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
model_name: str = "OpenAI Chat"
model: str = Field(default="gpt-5-mini", validation_alias=_model_alias)
checkpoint: str | None = Field(
@@ -118,7 +100,3 @@ class OpenAIChatConfig(BaseAgentConfig):
api_key: str | None = None
base_url: str | None = None
completion_kwargs: dict[str, Any] = Field(default_factory=dict)
-
-
-class OpenAIChatCreateParams(BaseCreateParams, OpenAIChatConfig):
- pass
diff --git a/hud/cli/rl.py b/hud/cli/rl.py
index a3831e0a5..538d5a65b 100644
--- a/hud/cli/rl.py
+++ b/hud/cli/rl.py
@@ -24,7 +24,7 @@
# =============================================================================
-async def _fetch_env_metadata(env_name: str, headers: dict[str, str]) -> dict[str, Any] | None:
+async def _fetch_tool_metadata(env_name: str, headers: dict[str, str]) -> dict[str, Any] | None:
"""Fetch env metadata from mcp-config endpoint. Returns response dict or None."""
url = f"{settings.hud_api_url}/environments/{env_name}/mcp-config"
async with httpx.AsyncClient(timeout=15.0) as client:
@@ -116,20 +116,20 @@ async def _preflight_validate(tasks: list[Any]) -> None:
hud_console.info(f"Preflight: checking {len(env_names)} environment(s)…")
- env_metadata: dict[str, dict[str, Any]] = {}
+ tool_metadata: dict[str, dict[str, Any]] = {}
for name in sorted(env_names):
- data = await _fetch_env_metadata(name, headers)
+ data = await _fetch_tool_metadata(name, headers)
if data is None:
hud_console.error(f"Environment '{name}' not found on platform")
hud_console.hint("Deploy it first with: hud deploy")
raise typer.Exit(1)
- env_metadata[name] = data
+ tool_metadata[name] = data
hud_console.info(f" ✓ {name}")
env_scenarios = _extract_scenarios(tasks)
for env_name, scenarios in sorted(env_scenarios.items()):
- if env_name in env_metadata:
- _check_scenarios(env_name, scenarios, env_metadata[env_name])
+ if env_name in tool_metadata:
+ _check_scenarios(env_name, scenarios, tool_metadata[env_name])
hud_console.success("Preflight passed")
diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py
index e46f5c9ca..4d9320d33 100644
--- a/hud/cli/tests/test_eval.py
+++ b/hud/cli/tests/test_eval.py
@@ -44,6 +44,7 @@ def __init__(
self.error: BaseException | None = None
self.metadata: dict[str, Any] = {}
self._is_summary = False
+ self._scenario_sessions = {}
def as_tools(self) -> list[types.Tool]:
return self._tools
diff --git a/hud/cli/utils/version_check.py b/hud/cli/utils/version_check.py
index 301053ac6..5ae9d07df 100644
--- a/hud/cli/utils/version_check.py
+++ b/hud/cli/utils/version_check.py
@@ -232,7 +232,7 @@ def display_update_prompt(console: HUDConsole | None = None) -> None:
console: HUDConsole instance for output. If None, creates a new one.
"""
if console is None:
- console = HUDConsole(logger=logger)
+ console = HUDConsole()
try:
info = check_for_updates()
diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py
index 89b4dc704..49a23e448 100644
--- a/hud/datasets/runner.py
+++ b/hud/datasets/runner.py
@@ -145,7 +145,7 @@ async def run_dataset(
# Create agent using AgentType.cls.create()
agent = agent_type.cls.create(**final_agent_params)
- await agent.run(ctx, max_steps=max_steps)
+ await ctx._run(agent, max_steps=max_steps)
# Reward is computed by EvalContext.__aexit__ from the scenario evaluate phase.
# For parallel execution, results are collected via ctx.results
@@ -252,7 +252,7 @@ async def run_single_task(
if metadata:
ctx.metadata.update(metadata)
- result = await agent.run(ctx, max_steps=max_steps)
+ result = await ctx._run(agent, max_steps=max_steps)
# Reward is computed by EvalContext.__aexit__ from the scenario evaluate phase.
# Propagate reward from EvalContext (set in __aexit__) to returned Trace
diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py
index b7f064d20..25c75869b 100644
--- a/hud/datasets/utils.py
+++ b/hud/datasets/utils.py
@@ -38,8 +38,8 @@ class SingleTaskRequest(BaseModel):
agent_params: dict[str, Any] = Field(
default_factory=dict,
description="Agent constructor parameters passed to agent.create(). "
- "Should include fields from BaseCreateParams (auto_trace, auto_respond, verbose) "
- "plus agent-specific config fields (e.g., checkpoint_name for ClaudeConfig).",
+ "Should include runtime fields (ctx, auto_respond) plus agent-specific "
+ "config fields (e.g., checkpoint_name for ClaudeConfig).",
)
max_steps: int = Field(default=10, description="Maximum steps allowed for the agent.")
job_id: str = Field(description="HUD job identifier for telemetry association.")
diff --git a/hud/environment/environment.py b/hud/environment/environment.py
index abeff5d8f..3a566475e 100644
--- a/hud/environment/environment.py
+++ b/hud/environment/environment.py
@@ -986,7 +986,7 @@ async def checkout(user_id: str):
# Single task via hud.eval
async with hud.eval(env("checkout", user_id="alice")) as ctx:
- await agent.run(ctx.prompt)
+ await ctx._run(agent)
# Multiple tasks with variants
tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")]
diff --git a/hud/environment/scenarios.py b/hud/environment/scenarios.py
index 17ed1e062..5849afd93 100644
--- a/hud/environment/scenarios.py
+++ b/hud/environment/scenarios.py
@@ -246,6 +246,18 @@ def _to_prompt_message(item: Any, default_role: str = "user") -> PromptMessage:
role=item.role, # type: ignore[arg-type]
content=TextContent(type="text", text=str(item.content)),
)
+ if hasattr(item, "content"):
+ role = getattr(item, "role", default_role)
+ content = item.content
+ if isinstance(content, str):
+ content = TextContent(type="text", text=content)
+ elif isinstance(content, TextContent) or hasattr(content, "type"):
+ pass
+ elif hasattr(content, "text"):
+ content = TextContent(type="text", text=str(content.text))
+ else:
+ content = TextContent(type="text", text=str(content))
+ return PromptMessage(role=role, content=content) # type: ignore[arg-type]
if isinstance(item, str):
return PromptMessage(
role=default_role, # type: ignore[arg-type]
@@ -294,6 +306,22 @@ def _build_answer_for_generator(session: ScenarioSession) -> Any:
elif isinstance(raw_answer, str):
raw_text = raw_answer
raw_citations = []
+ text = raw_answer.strip()
+ if text.startswith("```"):
+ parts = text.split("```")
+ if len(parts) >= 3:
+ text = parts[1].removeprefix("json").strip()
+ try:
+ parsed_answer = json.loads(text)
+ except (json.JSONDecodeError, TypeError):
+ parsed_answer = None
+ if isinstance(parsed_answer, dict) and (
+ "content" in parsed_answer or "citations" in parsed_answer
+ ):
+ content = parsed_answer.get("content", "")
+ raw_text = content if isinstance(content, str) else json.dumps(content)
+ citations = parsed_answer.get("citations", [])
+ raw_citations = [c for c in citations if isinstance(c, dict)]
else:
raw_text = str(raw_answer) if raw_answer is not None else ""
raw_citations = []
@@ -741,10 +769,13 @@ async def run_scenario_setup(
# Prompt exists remotely; original setup/rendering error.
raise
- # Extract prompt text from response
+ # Extract prompt messages and text from response
+ prompt_messages = (
+ _normalize_prompt_yield(list(result.messages)) if result.messages else None
+ )
prompt_text: str | None = None
- if result.messages:
- first_msg = result.messages[0]
+ if prompt_messages:
+ first_msg = prompt_messages[0]
content = first_msg.content
if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr]
prompt_text = content.text # type: ignore[union-attr]
@@ -793,6 +824,7 @@ async def run_scenario_setup(
allowed_tools=allowed_tools_meta,
returns_schema=returns_schema_meta,
enable_citations=enable_citations_meta,
+ prompt_messages=prompt_messages,
)
self._set_session(session, session_id)
diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py
index 2133823b8..4e0567940 100644
--- a/hud/environment/tests/test_environment.py
+++ b/hud/environment/tests/test_environment.py
@@ -321,10 +321,12 @@ async def investigate(issue: str):
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
+ from hud.types import Trace
+
+ mock_ctx._run.return_value = Trace(content="subagent output", done=True)
mock_run_eval.return_value = mock_ctx
mock_agent = MagicMock()
- mock_agent.run = AsyncMock(return_value=MagicMock(content="subagent output"))
mock_create_agent.return_value = mock_agent
req_meta = RequestParams.Meta.model_validate({"_hud_trace_id": "trace-from-meta"})
req_context = RequestContext(
diff --git a/hud/environment/tests/test_scenarios.py b/hud/environment/tests/test_scenarios.py
index a646a3bd7..ca6256970 100644
--- a/hud/environment/tests/test_scenarios.py
+++ b/hud/environment/tests/test_scenarios.py
@@ -1355,6 +1355,41 @@ async def typed_scenario():
assert isinstance(prompt.meta, dict)
assert prompt.meta.get("enable_citations") is True
+ @pytest.mark.asyncio
+ async def test_structured_answer_parses_json_wrapped_content_and_citations(self) -> None:
+ """Structured scenario parsing unwraps model-emitted content/citations JSON."""
+ env = Environment("test-env")
+
+ class Answer(BaseModel):
+ final: str
+
+ captured = None
+
+ @env.scenario("typed", returns=Answer, enable_citations=True)
+ async def typed_scenario():
+ nonlocal captured
+ captured = yield "Prompt"
+ yield 1.0
+
+ await env.run_scenario_setup("typed", {})
+ await env.submit(
+ "typed",
+ """```json
+{
+ "content": {"final": "done"},
+ "citations": [
+ {"type": "url_citation", "source": "https://example.com", "text": "source"}
+ ]
+}
+```""",
+ )
+ result = await env.run_scenario_evaluate("typed")
+
+ assert result.reward == 1.0
+ assert captured is not None
+ assert captured.content.final == "done"
+ assert captured.citations[0].source == "https://example.com"
+
@pytest.mark.asyncio
async def test_submit_before_setup_raises(self) -> None:
"""Calling submit() before run_scenario_setup() should raise."""
diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py
index 119769fcc..8812ce8c8 100644
--- a/hud/eval/__init__.py
+++ b/hud/eval/__init__.py
@@ -13,12 +13,12 @@
await ctx.call_tool("navigate", url="...")
async with env("checkout", user_id="alice") as ctx:
- await agent.run(ctx.prompt)
+ await ctx.submit("answer")
# Orchestrated with Task objects
tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")]
async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx:
- await agent.run(ctx.prompt)
+ await ctx._run(agent)
# Blank eval for manual reward
async with hud.eval() as ctx:
diff --git a/hud/eval/context.py b/hud/eval/context.py
index f767a5bbc..05d7ce8d5 100644
--- a/hud/eval/context.py
+++ b/hud/eval/context.py
@@ -13,8 +13,12 @@
import logging
import uuid
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Self
+from typing import TYPE_CHECKING, Any, Literal, Self, cast
+import mcp.types as types
+
+from hud.agents.base import AgentContext
+from hud.agents.tools.base import ToolClient
from hud.environment import Environment
from hud.settings import settings
from hud.shared import make_request
@@ -24,15 +28,17 @@
from collections.abc import Generator
from types import TracebackType
+ from hud.agents.tools import CapabilityEntry, ToolMetadata
from hud.eval.task import Task
from hud.tools.types import EvaluationResult
- from hud.types import MCPToolResult
+ from hud.types import MCPToolResult, Trace
from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete
logger = logging.getLogger(__name__)
+
# Contextvar to store current trace headers (for httpx auto-instrumentation)
_current_trace_headers: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar(
"current_trace_headers", default=None
@@ -109,7 +115,7 @@ class EvalContext(Environment):
# With task (scenario sets reward automatically)
tasks = load_tasks("my-org/task:1")
async with hud.eval(tasks) as ctx:
- await agent.run(ctx)
+ await ctx._run(agent)
# reward set by scenario evaluate phase in __aexit__
# Blank eval (manual reward)
@@ -174,7 +180,7 @@ def __init__(
self.answer: str | dict[str, Any] | None = None # Agent's submitted answer
self.system_prompt: str | None = None # From task.agent_config, passed to agent
self.scenario_returns_schema: dict[str, Any] | None = None
- self.scenario_enable_citations: bool = False
+ self.enable_citations: bool = False
# Error tracking
self.error: BaseException | None = None
@@ -374,18 +380,9 @@ async def _run_task_scenario_setup(self) -> None:
if prompt:
self.prompt = prompt
- # If scenario yielded multi-turn messages, store as conversation
session = self._get_session()
self.scenario_returns_schema = session.returns_schema if session else None
- self.scenario_enable_citations = bool(session.enable_citations) if session else False
- if session and session.prompt_messages and len(session.prompt_messages) > 1:
- self.conversation = [
- {
- "role": pm.role,
- "content": getattr(pm.content, "text", str(pm.content)),
- }
- for pm in session.prompt_messages
- ]
+ self.enable_citations = bool(session.enable_citations) if session else False
async def _run_task_scenario_evaluate(self) -> None:
"""Run the task's scenario evaluate phase (if scenario provided)."""
@@ -511,8 +508,7 @@ async def submit(self, answer: str | dict[str, Any]) -> None:
Example:
async with env("checkout", product="laptop") as ctx:
- response = await agent.run(ctx.prompt)
- await ctx.submit(response)
+ await ctx.submit("answer")
# On exit, scenario's evaluate phase receives the answer
"""
if not self._task or not self._task.scenario:
@@ -524,6 +520,90 @@ async def submit(self, answer: str | dict[str, Any]) -> None:
# Delegate to Environment.submit() which handles storage + broadcast
await super().submit(self._task.scenario, answer)
+ async def submit_result(self, result: Trace) -> None:
+ """Record an agent result on the eval context."""
+ if result.isError:
+ error_msg = result.info.get("error") if result.info else result.content
+ self.error = Exception(str(error_msg)) if error_msg else Exception("Agent error")
+ return
+
+ if not result.content:
+ return
+
+ if result.citations:
+ await self.submit({"content": result.content, "citations": result.citations})
+ else:
+ await self.submit(result.content)
+
+ async def _run(self, agent: Any, *, max_steps: int = 10) -> Trace:
+ """Run an agent against this eval context."""
+ await self.list_tools()
+ initial_messages = self.prompt_messages()
+ tool_client = ToolClient(
+ tools=self.as_tools(),
+ tool_handler=self.call_tool,
+ tool_metadata=self._tool_metadata(),
+ )
+
+ agent.enable_citations = bool(getattr(self, "enable_citations", False))
+ result = await agent.run(
+ AgentContext(
+ messages=initial_messages,
+ tool_client=tool_client,
+ ),
+ max_steps=max_steps,
+ )
+ await self.submit_result(result)
+ return result
+
+ def _tool_metadata(self) -> ToolMetadata | None:
+ if environment_capabilities := self.metadata.get("environment_capabilities"):
+ return cast("ToolMetadata", environment_capabilities)
+ if capabilities := self.metadata.get("capabilities"):
+ return {"capabilities": cast("dict[str, str | CapabilityEntry]", capabilities)}
+ return None
+
+ def prompt_messages(self) -> list[types.PromptMessage]:
+ """Return raw MCP prompt messages for an agent run."""
+ session = self._get_session()
+ if session and session.prompt_messages:
+ return session.prompt_messages
+
+ conversation = getattr(self, "conversation", None)
+ if conversation:
+ messages: list[types.PromptMessage] = []
+ for msg in conversation:
+ role = cast("Literal['user', 'assistant']", msg.get("role", "user"))
+ messages.append(
+ types.PromptMessage(
+ role=role,
+ content=types.TextContent(type="text", text=msg.get("content", "")),
+ )
+ )
+ return messages
+
+ prompt = getattr(self, "prompt", None)
+ if not prompt:
+ if self.has_scenario:
+ scenario = self._task.scenario if self._task else "unknown"
+ raise ValueError(
+ f"ctx.prompt is not set.\n\n"
+ f"Scenario '{scenario}' was specified but returned an empty prompt.\n"
+ f"Check that the scenario's setup function returns a non-empty string."
+ )
+ raise ValueError(
+ "ctx.prompt is not set.\n\n"
+ "No scenario was specified in your task file.\n"
+ "Add a 'scenario' field to your task so scenario setup can produce a prompt."
+ )
+
+ return [
+ types.PromptMessage(
+ role="user",
+ content=types.TextContent(type="text", text=prompt),
+ )
+ ]
+
async def _eval_enter(self) -> None:
"""Notify backend that eval has started."""
if not self._trace_enabled:
diff --git a/hud/eval/manager.py b/hud/eval/manager.py
index 655833e2e..7b627cc4e 100644
--- a/hud/eval/manager.py
+++ b/hud/eval/manager.py
@@ -148,12 +148,12 @@ async def run_eval(
env = Environment("my-env").connect_hub("browser")
tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")]
async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx:
- await agent.run(ctx.prompt)
+ await ctx._run(agent)
# Load tasks from file or API
tasks = load_tasks("hud-evals/SheetBench-50")
async with hud.eval(tasks) as ctx:
- await agent.run(ctx)
+ await ctx._run(agent)
# With variants and group
async with hud.eval(
@@ -167,7 +167,7 @@ async def run_eval(
# With concurrency limit
async with hud.eval(tasks, max_concurrent=10) as ctx:
- await agent.run(ctx)
+ await ctx._run(agent)
# Access results after parallel run
for e in ctx.results:
diff --git a/hud/eval/task.py b/hud/eval/task.py
index e13159919..fefcbec73 100644
--- a/hud/eval/task.py
+++ b/hud/eval/task.py
@@ -15,7 +15,7 @@
# With scenario
async with env("checkout", user_id="alice") as ctx:
- await agent.run(ctx.prompt)
+ await ctx.submit("answer")
# Orchestrated via hud.eval
tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")]
@@ -279,7 +279,7 @@ async def run(
agent = create_agent(agent)
async with run_eval(self, trace=trace, quiet=quiet) as ctx:
- result = await agent.run(ctx, max_steps=max_steps)
+ result = await ctx._run(agent, max_steps=max_steps)
if ctx.reward is not None:
result.reward = ctx.reward
diff --git a/hud/services/chat.py b/hud/services/chat.py
index 50177db6c..bd53111ca 100644
--- a/hud/services/chat.py
+++ b/hud/services/chat.py
@@ -89,7 +89,7 @@ class Chat(AgentExecutor):
Each ``send()`` call:
1. Appends the user message to history
2. Creates a Task copy with the full history as scenario args
- 3. Runs ``hud.eval(task)`` -> scenario setup -> ``agent.run(ctx)`` -> evaluate
+ 3. Runs ``hud.eval(task)`` -> scenario setup -> ``ctx._run(agent)`` -> evaluate
4. Appends the assistant response to history
5. Returns the Trace
diff --git a/hud/tests/public_api/test_v5_legacy_aliases.py b/hud/tests/public_api/test_v5_legacy_aliases.py
index 8e94cc281..ea8f3e633 100644
--- a/hud/tests/public_api/test_v5_legacy_aliases.py
+++ b/hud/tests/public_api/test_v5_legacy_aliases.py
@@ -54,12 +54,6 @@ def fake_load_tasks(source: str, *, raw: bool = False) -> list[dict[str, str]]:
assert calls == [("local-or-remote-source", True)]
-def test_agent_response_aliases_inference_result() -> None:
- import hud.types as types
-
- assert types.AgentResponse is types.InferenceResult
-
-
def test_tool_router_aliases_environment_mcp_router() -> None:
import hud.environment as environment
diff --git a/hud/tests/public_api/test_v5_surface_imports.py b/hud/tests/public_api/test_v5_surface_imports.py
index 15cf0f43f..57b6ab8d6 100644
--- a/hud/tests/public_api/test_v5_surface_imports.py
+++ b/hud/tests/public_api/test_v5_surface_imports.py
@@ -81,8 +81,8 @@
"PlaywrightTool",
),
"hud.types": (
+ "AgentResponse",
"AgentType",
- "InferenceResult",
"MCPToolCall",
"MCPToolResult",
"Trace",
@@ -97,12 +97,7 @@
"OpenAIChatAgent",
"create_agent",
),
- "hud.agents.claude": (
- "ClaudeAgent",
- "base64_to_content_block",
- "text_to_content_block",
- "tool_use_content_block",
- ),
+ "hud.agents.claude": ("ClaudeAgent",),
"hud.datasets": (
"display_results",
"load_tasks",
@@ -215,7 +210,6 @@
"hud.tools.agent": ("AgentTool",),
"hud.agents.gemini": ("GeminiAgent",),
"hud.agents.openai": ("OpenAIAgent",),
- "hud.agents.openai_chat": ("OpenAIChatAgent",),
"hud.tools.coding": (
"ApplyPatchTool",
"BashTool",
diff --git a/hud/tests/public_api/test_v5_workflow_contracts.py b/hud/tests/public_api/test_v5_workflow_contracts.py
index cd9df4819..2491baba8 100644
--- a/hud/tests/public_api/test_v5_workflow_contracts.py
+++ b/hud/tests/public_api/test_v5_workflow_contracts.py
@@ -650,7 +650,7 @@ def test_native_grader_helpers_keep_basic_semantics() -> None:
assert f1_score("hello hud", "hello sdk") == 0.5
-def test_eval_context_user_facing_properties_and_tool_surface() -> None:
+def test_eval_context_user_facing_properties_and_tool_helpers() -> None:
ctx = EvalContext(trace=False, quiet=True, variants={"model": "test"})
ctx.prompt = "Do the task"
diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py
index 8ddf6a9fc..d870c5a62 100644
--- a/hud/tests/test_datasets_extended.py
+++ b/hud/tests/test_datasets_extended.py
@@ -39,6 +39,7 @@ async def test_run_dataset_with_task_list(self):
mock_ctx = AsyncMock()
mock_ctx.results = None
mock_ctx.reward = None
+ mock_ctx._run.return_value = Trace(reward=1.0, done=True)
# Create mock agent class and instance (use MagicMock since create() is sync)
mock_agent_instance = AsyncMock()
@@ -57,7 +58,7 @@ async def test_run_dataset_with_task_list(self):
# Should return list with ctx
assert len(results) == 1
- mock_agent_instance.run.assert_called_once()
+ mock_ctx._run.assert_called_once_with(mock_agent_instance, max_steps=5)
@pytest.mark.asyncio
async def test_run_dataset_from_source_string(self):
@@ -70,6 +71,7 @@ async def test_run_dataset_from_source_string(self):
mock_ctx = AsyncMock()
mock_ctx.results = None
+ mock_ctx._run.return_value = Trace(reward=1.0, done=True)
# Create mock agent class and instance (use MagicMock since create() is sync)
mock_agent_instance = AsyncMock()
@@ -101,6 +103,7 @@ async def test_run_dataset_passes_parameters(self):
mock_ctx = AsyncMock()
mock_ctx.results = None
+ mock_ctx._run.return_value = Trace(reward=1.0, done=True)
# Create mock agent class and instance (use MagicMock since create() is sync)
mock_agent_instance = AsyncMock()
diff --git a/hud/tests/test_types.py b/hud/tests/test_types.py
index 55a5c0f89..bc1147ffe 100644
--- a/hud/tests/test_types.py
+++ b/hud/tests/test_types.py
@@ -4,7 +4,7 @@
from mcp.types import ImageContent, TextContent
-from hud.types import InferenceResult, MCPToolCall, MCPToolResult, Trace, TraceStep
+from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace, TraceStep
def test_mcp_tool_call_str_long_args():
@@ -164,17 +164,17 @@ def test_mcp_tool_result_rich():
mock_console.format_tool_result.assert_called_once()
-def test_inference_result_str_with_reasoning():
- """Test InferenceResult __str__ includes reasoning."""
- response = InferenceResult(reasoning="Test reasoning", content="Test content")
+def test_agent_response_str_with_reasoning():
+ """Test AgentResponse __str__ includes reasoning."""
+ response = AgentResponse(reasoning="Test reasoning", content="Test content")
output = str(response)
assert "Reasoning: Test reasoning" in output
assert "Content: Test content" in output
-def test_inference_result_str_with_tool_calls():
- """Test InferenceResult __str__ includes tool calls."""
- response = InferenceResult(
+def test_agent_response_str_with_tool_calls():
+ """Test AgentResponse __str__ includes tool calls."""
+ response = AgentResponse(
tool_calls=[
MCPToolCall(name="tool1", arguments={"a": 1}),
MCPToolCall(name="tool2", arguments={"b": 2}),
@@ -186,38 +186,29 @@ def test_inference_result_str_with_tool_calls():
assert "tool2" in output
-def test_inference_result_str_with_raw():
- """Test InferenceResult __str__ includes raw."""
- response = InferenceResult(raw={"raw_data": "value"})
+def test_agent_response_str_with_raw():
+ """Test AgentResponse __str__ includes raw."""
+ response = AgentResponse(raw={"raw_data": "value"})
output = str(response)
assert "Raw:" in output
-def test_inference_result_citations_default_empty():
- """InferenceResult.citations defaults to empty list."""
- result = InferenceResult(content="hello")
+def test_agent_response_citations_default_empty():
+ """AgentResponse.citations defaults to empty list."""
+ result = AgentResponse(content="hello")
assert result.citations == []
-def test_inference_result_citations_roundtrip():
+def test_agent_response_citations_roundtrip():
"""Citations survive serialize/deserialize."""
cit = {"type": "url_citation", "source": "https://example.com", "title": "Example"}
- result = InferenceResult(content="hello", citations=[cit])
+ result = AgentResponse(content="hello", citations=[cit])
data = result.model_dump(mode="json")
- restored = InferenceResult(**data)
+ restored = AgentResponse(**data)
assert len(restored.citations) == 1
assert restored.citations[0]["source"] == "https://example.com"
-def test_agent_response_alias():
- """AgentResponse is a backwards-compatible alias for InferenceResult."""
- from hud.types import AgentResponse
-
- assert AgentResponse is InferenceResult
- r = AgentResponse(content="test", done=True)
- assert isinstance(r, InferenceResult)
-
-
def test_trace_citations_default_empty():
"""Trace.citations defaults to empty list."""
trace = Trace()
diff --git a/hud/tools/agent.py b/hud/tools/agent.py
index 0d8743fa4..dd5646015 100644
--- a/hud/tools/agent.py
+++ b/hud/tools/agent.py
@@ -216,7 +216,7 @@ async def _run_subagent() -> ToolResult:
else:
agent = self._agent_cls.create(**self._agent_params) # type: ignore
- result = await agent.run(ctx)
+ result = await ctx._run(agent)
content = result.content if hasattr(result, "content") and result.content else ""
return ToolResult(content=[TextContent(type="text", text=content)])
diff --git a/hud/tools/computer/base.py b/hud/tools/computer/base.py
index 9dbe2a27d..cf6770012 100644
--- a/hud/tools/computer/base.py
+++ b/hud/tools/computer/base.py
@@ -88,12 +88,14 @@ def __init__(
self.height = height or computer_settings.DISPLAY_HEIGHT
# Build metadata with resolution info
- meta = {
+ meta: dict[str, object] = {
"resolution": {
"width": self.width,
"height": self.height,
}
}
+ if coordinate_space is not None:
+ meta["coordinate_space"] = coordinate_space
# Initialize base tool with executor as env
super().__init__(
diff --git a/hud/tools/computer/settings.py b/hud/tools/computer/settings.py
index 8d3121500..51a0201ce 100644
--- a/hud/tools/computer/settings.py
+++ b/hud/tools/computer/settings.py
@@ -93,11 +93,6 @@ class ComputerSettings(BaseSettings):
description="Whether to rescale images to the agent width and height",
validation_alias="GEMINI_RESCALE_IMAGES",
)
- GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS: int = Field(
- default=3,
- description="Maximum number of recent turns to keep screenshots for in Gemini agent",
- validation_alias="GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS",
- )
GLM_COMPUTER_WIDTH: int = Field(
default=1024,
description="Width of the display to use for the z-ai/glm4.5v computer tools",
diff --git a/hud/tools/tests/test_agent_tool.py b/hud/tools/tests/test_agent_tool.py
index de8196c38..d85523801 100644
--- a/hud/tools/tests/test_agent_tool.py
+++ b/hud/tools/tests/test_agent_tool.py
@@ -1,220 +1,64 @@
-"""Tests for AgentTool - scenario-to-agent composition."""
+"""Tests for AgentTool's public tool schema behavior."""
from __future__ import annotations
-import inspect
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock
import pytest
from hud.environment import Environment
from hud.eval.task import Task
-from hud.tools.agent import AgentTool, _is_eval_only
-
-
-class TestIsEvalOnly:
- """Tests for _is_eval_only helper function."""
-
- def test_required_param_not_eval_only(self) -> None:
- """Required params (no default) are not eval-only."""
-
- def fn(x: str) -> None:
- pass
-
- sig = inspect.signature(fn)
- param = sig.parameters["x"]
- assert not _is_eval_only(param)
-
- def test_optional_with_value_not_eval_only(self) -> None:
- """Optional params with non-None default are not eval-only."""
-
- def fn(x: str = "default") -> None:
- pass
-
- sig = inspect.signature(fn)
- param = sig.parameters["x"]
- assert not _is_eval_only(param)
-
- def test_optional_none_without_union_not_eval_only(self) -> None:
- """Optional with None default but no None in type is not eval-only."""
-
- def fn(x: str = None) -> None: # type: ignore[assignment] # noqa: RUF013
- pass
-
- sig = inspect.signature(fn)
- param = sig.parameters["x"]
- assert not _is_eval_only(param)
-
- def test_optional_none_with_union_is_eval_only(self) -> None:
- """Params with `X | None = None` pattern are eval-only."""
-
- def fn(x: str | None = None) -> None:
- pass
-
- sig = inspect.signature(fn)
- param = sig.parameters["x"]
- assert _is_eval_only(param)
-
- def test_optional_int_none_is_eval_only(self) -> None:
- """Works with int | None = None too."""
-
- def fn(x: int | None = None) -> None:
- pass
-
- sig = inspect.signature(fn)
- param = sig.parameters["x"]
- assert _is_eval_only(param)
-
- def test_string_annotation_with_none_union(self) -> None:
- """Handles string annotations like 'str | None'."""
- # Simulate string annotation
- param = inspect.Parameter(
- "x",
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- default=None,
- annotation="str | None",
- )
- assert _is_eval_only(param)
-
- def test_string_annotation_without_none(self) -> None:
- """String annotations without None are not eval-only."""
- param = inspect.Parameter(
- "x",
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- default=None,
- annotation="str",
- )
- assert not _is_eval_only(param)
+from hud.tools.agent import AgentTool
class TestAgentToolInit:
- """Tests for AgentTool initialization."""
-
def test_requires_model_or_agent(self) -> None:
- """Must provide either model or agent."""
task = Task(args={})
with pytest.raises(ValueError, match="Must provide either"):
AgentTool(task)
def test_cannot_provide_both_model_and_agent(self) -> None:
- """Cannot provide both model and agent."""
task = Task(args={})
mock_agent = MagicMock()
with pytest.raises(ValueError, match="Cannot provide both"):
AgentTool(task, model="claude", agent=mock_agent) # type: ignore[arg-type]
- def test_accepts_model_string(self) -> None:
- """Can create with model string."""
- task = Task(scenario="test", args={})
- tool = AgentTool(task, model="claude")
-
- assert tool._model == "claude"
- assert tool._agent_cls is None
-
- def test_accepts_agent_class(self) -> None:
- """Can create with custom agent class."""
- task = Task(scenario="test", args={})
- mock_agent_cls = MagicMock()
- tool = AgentTool(task, agent=mock_agent_cls) # type: ignore[arg-type]
-
- assert tool._model is None
- assert tool._agent_cls is mock_agent_cls
-
def test_name_defaults_to_scenario(self) -> None:
- """Tool name defaults to scenario name."""
task = Task(scenario="investigate", args={})
tool = AgentTool(task, model="claude")
assert tool.name == "investigate"
def test_name_can_be_overridden(self) -> None:
- """Tool name can be overridden."""
task = Task(scenario="investigate", args={})
tool = AgentTool(task, model="claude", name="custom_name")
assert tool.name == "custom_name"
-class TestAgentToolParamFiltering:
- """Tests for parameter filtering (eval-only params hidden)."""
-
- def test_filters_eval_only_params(self) -> None:
- """Eval-only params (| None = None) are filtered from visible_params."""
- env = Environment("test")
-
- # Use Union syntax for consistency across Python versions
- @env.scenario()
- async def investigate(
- issue_id: str,
- include_traces: bool = True,
- expected_cause: str | None = None, # Eval only
- ):
- yield {"task": f"Investigate {issue_id}"}
-
- task = env("investigate")
- tool = AgentTool(task, model="claude")
-
- # visible_params should only have issue_id and include_traces
- assert "issue_id" in tool._visible_params
- assert "include_traces" in tool._visible_params
- assert "expected_cause" not in tool._visible_params
-
- def test_all_required_params_visible(self) -> None:
- """All required params are visible."""
- env = Environment("test")
-
- @env.scenario()
- async def search(query: str, limit: int):
- yield {"task": f"Search: {query}"}
-
- task = env("search")
- tool = AgentTool(task, model="claude")
-
- assert "query" in tool._visible_params
- assert "limit" in tool._visible_params
-
- def test_optional_with_default_visible(self) -> None:
- """Optional params with non-None defaults are visible."""
- env = Environment("test")
-
- @env.scenario()
- async def fetch(url: str, request_timeout: int = 30, retries: int = 3):
- yield {"task": f"Fetch {url}"}
-
- task = env("fetch")
- tool = AgentTool(task, model="claude")
-
- assert "url" in tool._visible_params
- assert "request_timeout" in tool._visible_params
- assert "retries" in tool._visible_params
-
-
-class TestAgentToolSchema:
- """Tests for JSON schema generation."""
-
- def test_builds_json_schema(self) -> None:
- """Builds proper JSON schema from visible params."""
+class TestAgentToolMCP:
+ def test_mcp_tool_exposes_required_and_defaulted_scenario_parameters(self) -> None:
env = Environment("test")
@env.scenario()
- async def investigate(issue_id: str, verbose: bool = False):
- yield {"task": f"Investigate {issue_id}"}
+ async def investigate(issue_id: str, verbose: bool = False, limit: int = 10):
+ yield {"task": f"Investigate {issue_id} {verbose} {limit}"}
task = env("investigate")
tool = AgentTool(task, model="claude")
- schema = tool._param_schema
- assert schema is not None
+ schema = tool.mcp.parameters
assert schema["type"] == "object"
- assert "issue_id" in schema["properties"]
- assert "verbose" in schema["properties"]
+ assert set(schema["properties"]) == {"issue_id", "verbose", "limit"}
assert "issue_id" in schema["required"]
assert "verbose" not in schema["required"] # Has default
+ assert "limit" not in schema["required"]
+ assert schema["properties"]["verbose"]["default"] is False
+ assert schema["properties"]["limit"]["default"] == 10
- def test_schema_excludes_eval_only(self) -> None:
- """Schema excludes eval-only params."""
+ def test_mcp_tool_hides_eval_only_parameters(self) -> None:
env = Environment("test")
@env.scenario()
@@ -227,17 +71,11 @@ async def check(
task = env("check")
tool = AgentTool(task, model="claude")
- schema = tool._param_schema
- assert schema is not None
+ schema = tool.mcp.parameters
assert "item_id" in schema["properties"]
assert "expected_status" not in schema["properties"]
-
-class TestAgentToolMCP:
- """Tests for MCP tool integration."""
-
def test_mcp_property_returns_tool(self) -> None:
- """The mcp property returns a FastMCP FunctionTool."""
from fastmcp.tools import FunctionTool
env = Environment("test")
@@ -251,105 +89,3 @@ async def greet(name: str):
mcp_tool = tool.mcp
assert isinstance(mcp_tool, FunctionTool)
-
- def test_mcp_has_filtered_parameters(self) -> None:
- """MCP tool has filtered parameter schema."""
- env = Environment("test")
-
- @env.scenario()
- async def analyze(
- data: str,
- expected_result: str | None = None, # Eval only
- ):
- yield {"task": f"Analyze {data}"}
-
- task = env("analyze")
- tool = AgentTool(task, model="claude")
-
- mcp_tool = tool.mcp
- params = mcp_tool.parameters # FunctionTool uses 'parameters'
-
- assert "data" in params["properties"]
- assert "expected_result" not in params["properties"]
-
-
-class TestAgentToolCall:
- """Tests for AgentTool.__call__."""
-
- @pytest.mark.asyncio
- async def test_filters_kwargs_to_visible_only(self) -> None:
- """Call filters kwargs to visible params only."""
- # Import modules first so patches work
- import hud.agents
- import hud.eval.manager # noqa: F401
-
- env = Environment("test")
-
- @env.scenario()
- async def process(item: str, expected: str | None = None):
- yield {"task": f"Process {item}"}
-
- task = env("process")
- tool = AgentTool(task, model="claude")
-
- # Mock the eval context and agent
- with (
- patch("hud.eval.manager.run_eval") as mock_run_eval,
- patch("hud.agents.create_agent") as mock_create_agent,
- ):
- mock_ctx = AsyncMock()
- mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
- mock_ctx.__aexit__ = AsyncMock(return_value=None)
- mock_run_eval.return_value = mock_ctx
-
- mock_agent = MagicMock()
- mock_agent.run = AsyncMock(return_value=MagicMock(content="result"))
- mock_create_agent.return_value = mock_agent
-
- # Call with both visible and eval-only params
- await tool(item="test", expected="should_be_filtered")
-
- # Check that task was created with filtered args
- call_args = mock_run_eval.call_args
- task_arg = call_args[0][0]
- assert "item" in task_arg.args
- assert "expected" not in task_arg.args # Filtered out
-
- @pytest.mark.asyncio
- async def test_merges_template_args(self) -> None:
- """Call merges kwargs with template args."""
- # Import modules first so patches work
- import hud.agents
- import hud.eval.manager # noqa: F401
-
- env = Environment("test")
-
- @env.scenario()
- async def search(query: str, limit: int = 10):
- yield {"task": f"Search {query}"}
-
- # Create template with some args pre-filled
- task = env("search", limit=5)
- tool = AgentTool(task, model="claude")
-
- with (
- patch("hud.eval.manager.run_eval") as mock_run_eval,
- patch("hud.agents.create_agent") as mock_create_agent,
- ):
- mock_ctx = AsyncMock()
- mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
- mock_ctx.__aexit__ = AsyncMock(return_value=None)
- mock_run_eval.return_value = mock_ctx
-
- mock_agent = MagicMock()
- mock_agent.run = AsyncMock(return_value=MagicMock(content="result"))
- mock_create_agent.return_value = mock_agent
-
- # Call with additional arg
- await tool(query="test query")
-
- # Check merged args
- call_args = mock_run_eval.call_args
- task_arg = call_args[0][0]
- assert task_arg.args["query"] == "test query"
- assert task_arg.args["limit"] == 5 # From template
diff --git a/hud/tools/tests/test_coding_apply_patch.py b/hud/tools/tests/test_coding_apply_patch.py
index 1008c831d..e959dd5cc 100644
--- a/hud/tools/tests/test_coding_apply_patch.py
+++ b/hud/tools/tests/test_coding_apply_patch.py
@@ -1,4 +1,4 @@
-"""Tests for apply_patch compatibility tool and patch parser helpers."""
+"""Tests for the legacy apply_patch compatibility wrapper."""
from __future__ import annotations
@@ -8,15 +8,6 @@
import pytest
from mcp.types import TextContent
-from hud.agents.openai.tools.apply_patch import (
- ActionType,
- DiffError,
- Parser,
- _apply_commit,
- _identify_files_needed,
- _patch_to_commit,
- _text_to_patch,
-)
from hud.tools._legacy import ApplyPatchTool
from hud.tools.coding import EditTool
@@ -42,56 +33,3 @@ async def test_update_file_uses_edit_tool_behavior(self):
assert file_path.read_text() == "new\n"
assert isinstance(result[0], TextContent)
assert "written successfully" in result[0].text
-
-
-class TestPatchParser:
- """Focused tests for shared V4A parser helpers used by EditTool."""
-
- def test_parse_add_file(self):
- lines = [
- "*** Begin Patch",
- "*** Add File: new.txt",
- "+line 1",
- "+line 2",
- "*** End Patch",
- ]
- parser = Parser(current_files={}, lines=lines, index=1)
- parser.parse()
-
- action = parser.patch.actions["new.txt"]
- assert action.type == ActionType.ADD
- assert action.new_file == "line 1\nline 2"
-
- def test_parse_update_file(self):
- text = "*** Begin Patch\n*** Update File: test.txt\n@@\n-old\n+new\n*** End Patch"
-
- patch, fuzz = _text_to_patch(text, {"test.txt": "old\n"})
-
- assert fuzz == 0
- action = patch.actions["test.txt"]
- assert action.type == ActionType.UPDATE
-
- def test_identify_files_needed(self):
- text = "*** Begin Patch\n*** Update File: a.txt\n@@\n-old\n+new\n*** End Patch"
- assert _identify_files_needed(text) == ["a.txt"]
-
- def test_apply_commit_update(self):
- patch, _ = _text_to_patch(
- "*** Begin Patch\n*** Update File: a.txt\n@@\n-old\n+new\n*** End Patch",
- {"a.txt": "old\n"},
- )
- commit = _patch_to_commit(patch, {"a.txt": "old\n"})
- files = {"a.txt": "old\n"}
-
- def write(path: str, content: str | None) -> None:
- files[path] = content or ""
-
- def remove(path: str) -> None:
- del files[path]
-
- _apply_commit(commit, write, remove)
- assert files["a.txt"] == "new\n"
-
- def test_invalid_patch_raises(self):
- with pytest.raises(DiffError):
- _text_to_patch("not a patch", {})
diff --git a/hud/tools/tests/test_computer.py b/hud/tools/tests/test_computer.py
index b4d4c1c7c..4e2fce3d3 100644
--- a/hud/tools/tests/test_computer.py
+++ b/hud/tools/tests/test_computer.py
@@ -175,6 +175,7 @@ def test_glm_computer_is_legacy_generic_registration():
assert comp.name == "glm_computer"
assert "native_tools" not in comp.meta
+ assert comp.meta["coordinate_space"] == 999
@pytest.mark.asyncio
diff --git a/hud/types.py b/hud/types.py
index 7dd2e07ab..ae2d18b52 100644
--- a/hud/types.py
+++ b/hud/types.py
@@ -3,12 +3,29 @@
import json
import uuid
from enum import Enum
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import mcp.types as types
from mcp.types import CallToolRequestParams, CallToolResult
from pydantic import BaseModel, ConfigDict, Field
+if TYPE_CHECKING:
+ from hud.agents.claude import ClaudeAgent
+ from hud.agents.gemini import GeminiAgent
+ from hud.agents.openai import OpenAIAgent
+ from hud.agents.openai_compatible import OpenAIChatAgent
+ from hud.agents.types import ClaudeConfig, GeminiConfig, OpenAIChatConfig, OpenAIConfig
+
+ AgentClass: TypeAlias = type[ClaudeAgent | GeminiAgent | OpenAIAgent | OpenAIChatAgent]
+ AgentConfigClass: TypeAlias = type[
+ ClaudeConfig | GeminiConfig | OpenAIConfig | OpenAIChatConfig
+ ]
+ _AgentTypeInfo: TypeAlias = tuple[AgentClass, AgentConfigClass, str]
+
+# JSON-compatible scalar/container values.
+JsonValue: TypeAlias = str | int | float | bool | None | list["JsonValue"] | dict[str, "JsonValue"]
+JsonObject: TypeAlias = dict[str, JsonValue]
+
class AgentType(str, Enum):
CLAUDE = "claude"
@@ -17,29 +34,25 @@ class AgentType(str, Enum):
OPENAI_COMPATIBLE = "openai_compatible"
@property
- def cls(self) -> type:
- if self == AgentType.CLAUDE:
- from hud.agents.claude import ClaudeAgent
-
- return ClaudeAgent
- elif self == AgentType.OPENAI:
- from hud.agents import OpenAIAgent
+ def cls(self) -> AgentClass:
+ return self._info[0]
- return OpenAIAgent
- elif self == AgentType.GEMINI:
- from hud.agents.gemini import GeminiAgent
-
- return GeminiAgent
- elif self == AgentType.OPENAI_COMPATIBLE:
- from hud.agents.openai_compatible import OpenAIChatAgent
+ @property
+ def config_cls(self) -> AgentConfigClass:
+ """Get config class without importing agent (avoids SDK dependency)."""
+ return self._info[1]
- return OpenAIChatAgent
- else:
- raise ValueError(f"Unsupported agent type: {self}")
+ @property
+ def gateway_provider(self) -> str:
+ """Default provider client used when this agent type is a gateway shortcut."""
+ return self._info[2]
@property
- def config_cls(self) -> type:
- """Get config class without importing agent (avoids SDK dependency)."""
+ def _info(self) -> _AgentTypeInfo:
+ from hud.agents import OpenAIAgent
+ from hud.agents.claude import ClaudeAgent
+ from hud.agents.gemini import GeminiAgent
+ from hud.agents.openai_compatible import OpenAIChatAgent
from hud.agents.types import (
ClaudeConfig,
GeminiConfig,
@@ -47,24 +60,15 @@ def config_cls(self) -> type:
OpenAIConfig,
)
- mapping: dict[AgentType, type] = {
- AgentType.CLAUDE: ClaudeConfig,
- AgentType.OPENAI: OpenAIConfig,
- AgentType.GEMINI: GeminiConfig,
- AgentType.OPENAI_COMPATIBLE: OpenAIChatConfig,
- }
- if self not in mapping:
- raise ValueError(f"Unsupported agent type for config: {self}")
- return mapping[self]
-
-
-class BaseAgentConfig(BaseModel):
- """Agent configuration for LLM-specific settings."""
-
- model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True)
-
- system_prompt: str | None = None
- hosted_tools: list[Any] = Field(default_factory=list)
+ match self:
+ case AgentType.CLAUDE:
+ return ClaudeAgent, ClaudeConfig, "anthropic"
+ case AgentType.OPENAI:
+ return OpenAIAgent, OpenAIConfig, "openai"
+ case AgentType.GEMINI:
+ return GeminiAgent, GeminiConfig, "gemini"
+ case AgentType.OPENAI_COMPATIBLE:
+ return OpenAIChatAgent, OpenAIChatConfig, "openai"
class MCPToolCall(CallToolRequestParams):
@@ -72,6 +76,7 @@ class MCPToolCall(CallToolRequestParams):
id: str = Field(default_factory=lambda: str(uuid.uuid4())) # Unique identifier for reference
annotation: str | None = None # Optional explanation of why this action is taken
+ provider_name: str | None = None # Original provider tool name when it differs from MCP name
def __str__(self) -> str:
"""Format tool call as plain text."""
@@ -149,8 +154,8 @@ def __rich__(self) -> str:
return hud_console.format_tool_result(content_summary, self.isError)
-class InferenceResult(BaseModel):
- """Result of a single LLM inference call.
+class AgentResponse(BaseModel):
+ """Result of a single agent inference call.
Returned by provider agents' ``get_response()`` methods. Carries the
model's text output, any tool calls it wants to make, and provider-
@@ -171,7 +176,7 @@ class InferenceResult(BaseModel):
# --- RESPONSE METADATA ---
# Populated by provider agents when citations are available.
- # Uses dict form of Citation (provider-normalized) so InferenceResult
+ # Uses dict form of Citation (provider-normalized) so AgentResponse
# doesn't depend on hud.tools.types at import time.
citations: list[dict[str, Any]] = Field(default_factory=list)
@@ -194,10 +199,6 @@ def __str__(self) -> str:
return response
-# Backwards-compatible alias (deprecated — use InferenceResult)
-AgentResponse = InferenceResult
-
-
class TraceStep(BaseModel):
"""Canonical data for a single span (shared with telemetry)."""
@@ -262,7 +263,7 @@ class Trace(BaseModel):
content: str | None = Field(default=None)
isError: bool = Field(default=False)
- # Response metadata carried from the final InferenceResult
+ # Response metadata carried from the final AgentResponse
citations: list[dict[str, Any]] = Field(default_factory=list)
# Metadata
@@ -296,7 +297,8 @@ def append(self, step: TraceStep) -> None:
"AgentResponse",
"AgentType",
"HudSpan",
- "InferenceResult",
+ "JsonObject",
+ "JsonValue",
"MCPToolCall",
"MCPToolResult",
"Task",
diff --git a/hud/utils/hud_console.py b/hud/utils/hud_console.py
index 041ea6753..17f526aec 100644
--- a/hud/utils/hud_console.py
+++ b/hud/utils/hud_console.py
@@ -621,81 +621,6 @@ def note(self, message: str, stderr: bool = True) -> None:
"""Print an important note with asterism symbol."""
self.symbol(Symbols.ITEM, message, GOLD, stderr)
- # ------------------------------------------------------------------
- # Agent-facing display methods
- # ------------------------------------------------------------------
-
- def format_tool_discovery(
- self,
- tools: list[Any],
- skipped: list[tuple[Any, str]] | None = None,
- stderr: bool = True,
- ) -> None:
- """Display a table of discovered tools on agent initialization.
-
- Args:
- tools: All available MCP tools
- skipped: List of (tool, reason) for skipped tools
- stderr: Output to stderr (default True)
- """
- console = self._stderr_console if stderr else self._stdout_console
-
- table = Table(
- show_header=True,
- box=None,
- padding=(0, 1),
- title=f"[{GOLD}]Discovered {len(tools)} tools[/{GOLD}]",
- title_style="",
- )
- table.add_column("Tool", style=TEXT, no_wrap=True)
- table.add_column("Available", style=DIM)
-
- for tool in tools:
- name = tool.name if hasattr(tool, "name") else str(tool)
- table.add_row(name, f"[{GREEN}]yes[/{GREEN}]")
-
- console.print(table)
-
- if skipped:
- for tool, reason in skipped:
- name = tool.name if hasattr(tool, "name") else str(tool)
- console.print(f" [{DIM}]⊘ {escape(name)}: {escape(reason)}[/{DIM}]")
-
- def format_step(
- self,
- step: int,
- max_steps: int,
- tool_calls: list[Any],
- tool_results: list[Any],
- elapsed: float | None = None,
- stderr: bool = True,
- ) -> None:
- """Display a compact step summary after tool execution.
-
- Args:
- step: Current step number
- max_steps: Maximum steps (-1 for unlimited)
- tool_calls: List of MCPToolCall objects
- tool_results: List of MCPToolResult objects
- elapsed: Step duration in seconds
- stderr: Output to stderr (default True)
- """
- console = self._stderr_console if stderr else self._stdout_console
-
- step_label = f"Step {step}"
- if max_steps != -1:
- step_label += f"/{max_steps}"
- if elapsed is not None:
- step_label += f" [{elapsed:.1f}s]"
-
- console.print(f"\n[bold {GOLD}]{step_label}[/bold {GOLD}]")
-
- for call, result in zip(tool_calls, tool_results, strict=False):
- call_str = str(call) if hasattr(call, "__rich__") else repr(call)
- result_str = str(result) if hasattr(result, "__rich__") else repr(result)
- console.print(f" {call_str}")
- console.print(f" {result_str}")
-
# Global design instance for convenience
class _ProgressContext:
diff --git a/pyproject.toml b/pyproject.toml
index 17ce20c36..dd20a3664 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -212,6 +212,7 @@ exclude = [
]
pythonVersion = "3.11"
typeCheckingMode = "basic"
+strict = ["hud/agents"]
reportMissingImports = "warning"
[tool.coverage.run]
@@ -248,4 +249,4 @@ testpaths = ["hud", "examples"]
addopts = ""
markers = [
"integration: marks tests as integration tests (require HUD_API_KEY, network access)",
-]
+]
\ No newline at end of file