From 9366a1a4f1ddee8db6962e48c2b0b0781aca32b6 Mon Sep 17 00:00:00 2001 From: Jaideep Date: Sun, 24 May 2026 01:17:11 -0700 Subject: [PATCH] agent updates --- docs/changelog.mdx | 2 +- docs/reference/agents.mdx | 7 +- docs/reference/cli/eval.mdx | 2 +- docs/reference/types.mdx | 6 +- hud/agents/__init__.py | 84 +- hud/agents/base.py | 691 ++----- hud/agents/claude/__init__.py | 10 - hud/agents/claude/agent.py | 741 +++----- hud/agents/claude/tools/__init__.py | 65 +- hud/agents/claude/tools/base.py | 176 +- hud/agents/claude/tools/coding.py | 73 +- hud/agents/claude/tools/computer.py | 175 +- hud/agents/claude/tools/hosted.py | 8 - hud/agents/claude/tools/memory.py | 16 +- hud/agents/claude/tools/settings.py | 2 - hud/agents/gateway.py | 114 +- hud/agents/gemini/agent.py | 635 ++----- hud/agents/gemini/settings.py | 21 + hud/agents/gemini/tools/__init__.py | 139 +- hud/agents/gemini/tools/base.py | 83 +- hud/agents/gemini/tools/coding.py | 48 +- hud/agents/gemini/tools/computer.py | 353 ++-- hud/agents/gemini/tools/filesystem.py | 51 +- hud/agents/gemini/tools/hosted.py | 8 - hud/agents/gemini/tools/memory.py | 15 +- hud/agents/misc/__init__.py | 4 +- hud/agents/misc/response_agent.py | 123 -- hud/agents/misc/response_automation.py | 113 ++ hud/agents/openai/agent.py | 563 ++---- hud/agents/openai/tools/__init__.py | 59 +- hud/agents/openai/tools/apply_patch.py | 329 ++-- hud/agents/openai/tools/base.py | 154 +- hud/agents/openai/tools/coding.py | 68 +- hud/agents/openai/tools/computer.py | 146 +- hud/agents/openai/tools/hosted.py | 7 - hud/agents/openai_compatible/__init__.py | 3 +- hud/agents/openai_compatible/agent.py | 421 ++--- .../openai_compatible/tools/__init__.py | 76 +- hud/agents/openai_compatible/tools/base.py | 180 ++ .../openai_compatible/tools/computer.py | 566 ------ .../openai_compatible/tools/filesystem.py | 163 +- .../openai_compatible/tools/glm_computer.py | 294 +++ .../openai_compatible/tools/qwen_computer.py | 266 +++ 
.../openai_compatible/tools/settings.py | 36 + hud/agents/openai_compatible/tools/types.py | 26 - hud/agents/resolver.py | 74 - hud/agents/tests/conftest.py | 310 +++- hud/agents/tests/test_base.py | 537 ------ hud/agents/tests/test_base_runtime.py | 221 --- hud/agents/tests/test_claude.py | 1605 ----------------- hud/agents/tests/test_gateway_resolution.py | 197 ++ hud/agents/tests/test_gemini.py | 1064 ----------- hud/agents/tests/test_hosted_tools.py | 299 ++- hud/agents/tests/test_openai.py | 824 --------- hud/agents/tests/test_openai_compatible.py | 300 --- .../tests/test_provider_claude_messages.py | 257 +++ .../tests/test_provider_computer_tools.py | 226 +++ .../test_provider_gemini_generate_content.py | 154 ++ .../tests/test_provider_native_tools.py | 147 ++ .../test_provider_openai_compatible_chat.py | 215 +++ .../tests/test_provider_openai_responses.py | 206 +++ .../tests/test_provider_tool_results.py | 174 ++ hud/agents/tests/test_resolver.py | 276 --- hud/agents/tests/test_run_eval.py | 269 --- hud/agents/tests/test_shared_eval_boundary.py | 239 +++ hud/agents/tests/test_shared_run_loop.py | 295 +++ hud/agents/tests/test_shared_tool_registry.py | 176 ++ hud/agents/tools/__init__.py | 24 +- hud/agents/tools/base.py | 240 ++- hud/agents/tools/capabilities.py | 186 +- hud/agents/tools/computer.py | 104 ++ hud/agents/tools/hosted.py | 47 +- hud/agents/tools/registry.py | 57 - hud/agents/types.py | 48 +- hud/cli/rl.py | 12 +- hud/cli/tests/test_eval.py | 1 + hud/cli/utils/version_check.py | 2 +- hud/datasets/runner.py | 4 +- hud/datasets/utils.py | 4 +- hud/environment/environment.py | 2 +- hud/environment/scenarios.py | 38 +- hud/environment/tests/test_environment.py | 4 +- hud/environment/tests/test_scenarios.py | 35 + hud/eval/__init__.py | 4 +- hud/eval/context.py | 112 +- hud/eval/manager.py | 6 +- hud/eval/task.py | 4 +- hud/services/chat.py | 2 +- .../public_api/test_v5_legacy_aliases.py | 6 - .../public_api/test_v5_surface_imports.py | 10 +- 
.../public_api/test_v5_workflow_contracts.py | 2 +- hud/tests/test_datasets_extended.py | 5 +- hud/tests/test_types.py | 41 +- hud/tools/agent.py | 2 +- hud/tools/computer/base.py | 4 +- hud/tools/computer/settings.py | 5 - hud/tools/tests/test_agent_tool.py | 292 +-- hud/tools/tests/test_coding_apply_patch.py | 64 +- hud/tools/tests/test_computer.py | 1 + hud/types.py | 96 +- hud/utils/hud_console.py | 75 - pyproject.toml | 3 +- 102 files changed, 6354 insertions(+), 10375 deletions(-) create mode 100644 hud/agents/gemini/settings.py delete mode 100644 hud/agents/misc/response_agent.py create mode 100644 hud/agents/misc/response_automation.py create mode 100644 hud/agents/openai_compatible/tools/base.py delete mode 100644 hud/agents/openai_compatible/tools/computer.py create mode 100644 hud/agents/openai_compatible/tools/glm_computer.py create mode 100644 hud/agents/openai_compatible/tools/qwen_computer.py create mode 100644 hud/agents/openai_compatible/tools/settings.py delete mode 100644 hud/agents/openai_compatible/tools/types.py delete mode 100644 hud/agents/resolver.py delete mode 100644 hud/agents/tests/test_base.py delete mode 100644 hud/agents/tests/test_base_runtime.py delete mode 100644 hud/agents/tests/test_claude.py create mode 100644 hud/agents/tests/test_gateway_resolution.py delete mode 100644 hud/agents/tests/test_gemini.py delete mode 100644 hud/agents/tests/test_openai.py delete mode 100644 hud/agents/tests/test_openai_compatible.py create mode 100644 hud/agents/tests/test_provider_claude_messages.py create mode 100644 hud/agents/tests/test_provider_computer_tools.py create mode 100644 hud/agents/tests/test_provider_gemini_generate_content.py create mode 100644 hud/agents/tests/test_provider_native_tools.py create mode 100644 hud/agents/tests/test_provider_openai_compatible_chat.py create mode 100644 hud/agents/tests/test_provider_openai_responses.py create mode 100644 hud/agents/tests/test_provider_tool_results.py delete mode 100644 
hud/agents/tests/test_resolver.py delete mode 100644 hud/agents/tests/test_run_eval.py create mode 100644 hud/agents/tests/test_shared_eval_boundary.py create mode 100644 hud/agents/tests/test_shared_run_loop.py create mode 100644 hud/agents/tests/test_shared_tool_registry.py create mode 100644 hud/agents/tools/computer.py delete mode 100644 hud/agents/tools/registry.py diff --git a/docs/changelog.mdx b/docs/changelog.mdx index e99ec3366..7ac4fb8a8 100644 --- a/docs/changelog.mdx +++ b/docs/changelog.mdx @@ -25,7 +25,7 @@ description: "Product updates and release notes for HUD SDK and Platform." - **`hud sync env`** — sync local environment configs with collision detection (replaces `hud link`). - **`hud eval` accepts Python files** — run evaluations directly from `.py` files and directories containing `Task` objects. - **Chat class** — manage multi-turn agent conversations from a single SDK abstraction. -- **GPT-5 support** — `ResponseAgent` defaults to `gpt-5`, with ToolSearch tool support. +- **GPT-5 support** — auto-response classification defaults to `gpt-5`, with ToolSearch tool support. - **Citations** — citation support for Claude, Gemini, and OpenAI responses in chat and agent traces. ### Platform diff --git a/docs/reference/agents.mdx b/docs/reference/agents.mdx index f21c276f9..8ca828b2d 100644 --- a/docs/reference/agents.mdx +++ b/docs/reference/agents.mdx @@ -42,7 +42,7 @@ Abstract base class for all MCP-enabled agents. 
Handles the agent loop, MCP clie |-----------|------|-------------|---------| | `mcp_client` | `AgentMCPClient` | MCP client for server connections | `None` | | `auto_trace` | `bool` | Enable automatic tracing spans | `True` | -| `auto_respond` | `bool` | Use ResponseAgent to decide when to stop/continue | `False` | +| `auto_respond` | `bool` | Use response automation to decide when to stop/continue | `False` | | `verbose` | `bool` | Verbose console logs for development | `False` | **Base Config** (shared by all agents): @@ -63,9 +63,6 @@ async def run(ctx: EvalContext, max_steps: int = 10) -> Trace: async def call_tools(tool_call: MCPToolCall | list[MCPToolCall]) -> list[MCPToolResult]: """Execute tool calls through MCP client.""" - -def get_available_tools() -> list[types.Tool]: - """Get filtered list of available tools.""" ``` ## Pre-built Agents @@ -251,7 +248,7 @@ result = await agent.run(task, max_steps=20) ### Auto-Respond Mode -When `auto_respond=True`, the agent uses a ResponseAgent to decide whether to continue or stop after each model response: +When `auto_respond=True`, the agent uses response automation to decide whether to continue or stop after each model response: ```python agent = ClaudeAgent.create( diff --git a/docs/reference/cli/eval.mdx b/docs/reference/cli/eval.mdx index 13903c044..d79f2596b 100644 --- a/docs/reference/cli/eval.mdx +++ b/docs/reference/cli/eval.mdx @@ -79,7 +79,7 @@ hud eval [SOURCE] [AGENT] [OPTIONS] - Use ResponseAgent to decide when to stop/continue. Default: True for `--full`. + Use response automation to decide when to stop/continue. Default: True for `--full`. 
### Taskset Association diff --git a/docs/reference/types.mdx b/docs/reference/types.mdx index f3d8e091d..bbd5bfad8 100644 --- a/docs/reference/types.mdx +++ b/docs/reference/types.mdx @@ -111,12 +111,12 @@ print(result.reward, result.done) | `trace` | `list[TraceStep]` | Execution trace steps | | `messages` | `list[Any]` | Final conversation state | -## InferenceResult +## AgentResponse Returned by agent `get_response()` methods. Represents the result of a single LLM inference call. ```python -from hud.types import InferenceResult +from hud.types import AgentResponse ``` | Field | Type | Description | @@ -129,8 +129,6 @@ from hud.types import InferenceResult | `info` | `dict[str, Any]` | Provider-specific metadata | | `isError` | `bool` | Error flag | -> **Note:** `AgentResponse` is available as a backwards-compatible alias for `InferenceResult`. - ## AgentType Enum of supported agent types. diff --git a/hud/agents/__init__.py b/hud/agents/__init__.py index 27ca0b327..b17f59bb5 100644 --- a/hud/agents/__init__.py +++ b/hud/agents/__init__.py @@ -1,95 +1,15 @@ from __future__ import annotations -import sys -from types import ModuleType -from typing import Any - -from .base import CategorizedTools, MCPAgent +from .base import MCPAgent from .claude import ClaudeAgent +from .gateway import create_agent from .openai import OpenAIAgent from .openai_compatible import OpenAIChatAgent __all__ = [ - "CategorizedTools", "ClaudeAgent", "MCPAgent", "OpenAIAgent", "OpenAIChatAgent", "create_agent", ] - - -def _install_openai_chat_compat_module() -> None: - module_name = f"{__name__}.openai_chat" - if module_name in sys.modules: - return - - module: Any = ModuleType(module_name, "Compatibility module for OpenAIChatAgent.") - module.OpenAIChatAgent = OpenAIChatAgent - module.__all__ = ["OpenAIChatAgent"] - sys.modules[module_name] = module - - -_install_openai_chat_compat_module() - - -def create_agent(model: str, **kwargs: Any) -> MCPAgent: - """Create an agent for a gateway 
model. - - This routes ALL requests through the HUD gateway. For direct API access - (using your own API keys), use the agent classes directly. - - Args: - model: Model name (e.g., "gpt-5.4", "claude-sonnet-4-6"). - **kwargs: Additional params passed to agent.create(). - - Returns: - Configured MCPAgent instance with gateway routing. - - Example: - ```python - # Gateway routing (recommended) - agent = create_agent("gpt-5.4") - agent = create_agent("claude-sonnet-4-6", temperature=0.7) - - # Direct API access (use agent classes) - from hud.agents.claude import ClaudeAgent - - agent = ClaudeAgent.create(model="claude-sonnet-4-6") - ``` - """ - from hud.agents.gateway import build_gateway_client - from hud.agents.resolver import resolve_cls - - # Resolve class and gateway info - agent_cls, gateway_info = resolve_cls(model) - - # Get model name from gateway info or use input - model_id = model - if gateway_info: - model_id = gateway_info.get("model_name") or model - - # Determine provider: from gateway info, or infer from agent class - if gateway_info: - provider = gateway_info["provider"]["name"] - else: - provider = "openai" - if agent_cls.__name__ == "ClaudeAgent": - provider = "anthropic" - elif agent_cls.__name__ == "GeminiAgent": - provider = "gemini" - - client = build_gateway_client(provider) - - # Set up kwargs - kwargs.setdefault("model", model_id) - - # Use correct client key based on agent type - if agent_cls == OpenAIChatAgent: - kwargs.setdefault("openai_client", client) - else: - # Claude and other agents use model_client and validate_api_key - kwargs.setdefault("model_client", client) - kwargs.setdefault("validate_api_key", False) - - return agent_cls.create(**kwargs) diff --git a/hud/agents/base.py b/hud/agents/base.py index 9e2581c1f..75fb7345f 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -3,543 +3,188 @@ from __future__ import annotations import asyncio -import json import logging -import re from abc import ABC, abstractmethod -from 
dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast -import mcp.types as types - -from hud.tools.types import Citation -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult, Trace -from hud.utils.hud_console import HUDConsole - -from .types import BaseCreateParams +from hud.agents.misc import auto_respond +from hud.types import AgentResponse, Trace if TYPE_CHECKING: - from hud.environment import Environment - from hud.eval.context import EvalContext + import mcp.types as types + from hud.agents.tools import AgentTools + from hud.agents.tools.base import CallTool, ToolClient + from hud.agents.types import AgentConfig +ProviderMessageT = TypeVar("ProviderMessageT") logger = logging.getLogger(__name__) -@dataclass -class CategorizedTools: - """Result of filtering tools for model-facing schemas.""" - - generic: list[types.Tool] = field(default_factory=list) - """MCP tools exposed through generic function calling.""" +@dataclass(frozen=True) +class AgentContext: + """Prompt messages plus optional MCP tool access for one agent run.""" - skipped: list[tuple[types.Tool, str]] = field(default_factory=list) - """Tools intentionally hidden from generic function calling.""" + messages: list[types.PromptMessage] + tool_client: ToolClient | None = None -class MCPAgent(ABC): +class MCPAgent(ABC, Generic[ProviderMessageT]): """ - Base class for MCP-enabled agents. 
- - Agents interact with MCP servers through an EvalContext: - - run(ctx): Main entry point - takes EvalContext from hud.eval() - - ctx.call_tool(): Used internally for all tool execution - - ctx.submit(): Called automatically with agent's final response - - Subclasses implement provider-specific formatting and response fetching - by overriding: `get_system_messages`, `get_response`, `format_blocks`, - and `format_tool_results`. - """ - - metadata: ClassVar[dict[str, Any] | None] = None - required_tools: ClassVar[list[str]] = [] # Tools that must be available - config_cls: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig - - @classmethod - @abstractmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for this agent. - - Subclasses must implement this to return their corresponding AgentType enum value. - This is used for provider-specific configuration and routing. - - Returns: - AgentType enum value for this agent - """ - raise NotImplementedError - - def categorize_tools(self, tools: list[types.Tool] | None = None) -> CategorizedTools: - """Return the MCP tools that should be exposed as generic function tools.""" - if tools is None: - tools = self.get_available_tools() - - return CategorizedTools(generic=list(tools)) + Base class for agents that interact with HUD MCP-backed environments. - def __init__(self, params: BaseCreateParams | None = None, **kwargs: Any) -> None: - if params is None: - import warnings + Agent instances are intended to be run-scoped: create a fresh agent for each + independent evaluation or task run. Provider implementations may keep + conversation IDs, continuation cursors, and prepared tool state on the + instance during a run. - warnings.warn( - f"Passing kwargs to {self.__class__.__name__}() is deprecated. " - f"Use {self.__class__.__name__}.create(...) 
instead.", - DeprecationWarning, - stacklevel=2, - ) - CreateParams = type( - f"{self.config_cls.__name__}CreateParams", - (BaseCreateParams, self.config_cls), - {"__module__": self.config_cls.__module__}, - ) - params = CreateParams(**kwargs) - - config_kwargs = { - k: getattr(params, k) for k in self.config_cls.model_fields if hasattr(params, k) - } - self.config = self.config_cls(**config_kwargs) + Agents interact with environments through per-run tools and tool handlers supplied + by the caller. - # Store execution context (EvalContext/Environment); agents use ctx.call_tool(). - self.ctx: EvalContext | Environment | None = params.ctx - - self.model_name: str = getattr(params, "model_name", "MCPAgent") - self.model: str = getattr(params, "model", None) or "unknown" - self.auto_respond = params.auto_respond + Subclasses implement provider-specific message formatting, response fetching, + and tool result rendering. + """ - self.console = HUDConsole(logger=logger) + def __init__(self, config: AgentConfig) -> None: + self.config = config - if params.verbose: - self.console.set_verbose(True) + self.model_name: str = self.config.model_name + self.model: str = self.config.model self.system_prompt = self.config.system_prompt - self._available_tools: list[types.Tool] | None = None - self._categorized_tools: CategorizedTools = CategorizedTools() - self._initialized: bool = False - - @classmethod - def create(cls, **kwargs: Any) -> MCPAgent: - """ - Factory method to create an agent with typed parameters. - """ - CreateParams = type( - f"{cls.config_cls.__name__}CreateParams", - (BaseCreateParams, cls.config_cls), - {"__module__": cls.config_cls.__module__}, - ) - return cls(params=CreateParams(**kwargs)) - - async def _initialize_from_ctx(self, ctx: EvalContext) -> None: - """Initialize agent from EvalContext - discovers tools and sets up state. - - The agent uses ctx.call_tool() directly - for tool execution (no EnvironmentClient wrapper needed). 
- """ - from hud.eval.context import EvalContext - - if not isinstance(ctx, EvalContext): - raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}") - - # Refresh tools from connections, then get filtered list for agent - await ctx.list_tools() - self._available_tools = ctx.as_tools() - - # Validate required tools are present - available_tool_names = {t.name for t in self._available_tools} - missing_tools = [tool for tool in self.required_tools if tool not in available_tool_names] - if missing_tools: - raise ValueError( - f"Required tools are missing: {missing_tools}. " - f"Available tools: {sorted(available_tool_names)}" - ) - - self._categorized_tools = self.categorize_tools() - - # Show tool discovery table (visible at INFO level) - self.console.format_tool_discovery( - tools=self._available_tools, - skipped=self._categorized_tools.skipped, - ) + self.enable_citations: bool = False - for tool, reason in self._categorized_tools.skipped: - logger.debug("Skipping tool %s: %s", tool.name, reason) + self.auto_respond: bool = config.auto_respond - # Call hook for subclass-specific initialization (e.g., tool format conversion) - self._on_tools_ready() - - self._initialized = True - - def _on_tools_ready(self) -> None: - """Hook called after tools are discovered and validated. - - Subclasses can override this to perform provider-specific setup, - such as converting MCP tools to the provider's format. + @classmethod + def create(cls, **kwargs: object) -> MCPAgent[ProviderMessageT]: + raise NotImplementedError(f"{cls.__name__}.create() must be implemented by subclasses") - Called by _initialize_from_ctx() after _available_tools is populated. 
- """ - return # Default no-op - subclasses override for provider-specific setup + @cached_property + @abstractmethod + def tools(self) -> AgentTools[Any, Any]: + """Provider-specific tool container used by the shared run loop.""" + raise NotImplementedError async def run( self, - ctx: EvalContext, + ctx: AgentContext, *, max_steps: int = 10, ) -> Trace: """ - Run the agent on the given evaluation context. - - The agent uses ctx.prompt as the task and ctx.call_tool() for tool execution. - Automatically calls ctx.submit() with the final answer. + Run the agent loop with prepared messages and optional tools. Args: - ctx: EvalContext from hud.eval() - contains prompt and tools + ctx: Prompt messages and optional environment client max_steps: Maximum number of agent steps (-1 for infinite) - Returns: - Trace with done, content, isError fields - - Example: - ```python - async with hud.eval(task) as ctx: - agent = ClaudeAgent.create() - await agent.run(ctx) - # ctx.reward is set by the scenario's evaluate phase - ``` - """ - from hud.eval.context import EvalContext - - if not isinstance(ctx, EvalContext): - raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}") - - if not ctx.prompt: - if ctx.has_scenario: - # Scenario was specified but prompt is still empty - # (e.g., scenario returned empty string, or edge case not caught in scenarios.py) - scenario = ctx._task.scenario if ctx._task else "unknown" - raise ValueError( - f"ctx.prompt is not set.\n\n" - f"Scenario '{scenario}' was specified but returned an empty prompt.\n" - f"Check that the scenario's setup function returns a non-empty string." - ) - else: - # No scenario specified at all - raise ValueError( - "ctx.prompt is not set.\n\n" - "No scenario was specified in your task file.\n" - "Add a 'scenario' field to your task so scenario setup can produce a prompt." 
- ) - - # Store context for tool calls - self.ctx = ctx - - # Initialize tools from context - if not self._initialized: - await self._initialize_from_ctx(ctx) - - try: - # Build initial context - conversation: list[dict[str, str]] | None = getattr(ctx, "conversation", None) - - if conversation: - # Multi-turn: build alternating role messages - initial_messages = await self._build_conversation_messages(conversation) - else: - # Single-turn: single user message from prompt - initial_messages = await self.format_message(ctx.prompt) - - result = await self._run_context(initial_messages, max_steps=max_steps) - - # Propagate error state to context for platform visibility - if result.isError and hasattr(ctx, "error"): - error_msg = result.info.get("error") if result.info else result.content - ctx.error = Exception(str(error_msg)) if error_msg else Exception("Agent error") - - # Submit final answer to context (only if scenario is running) - if result.content and ctx.has_scenario: - if result.citations: - await ctx.submit( - { - "content": result.content, - "citations": result.citations, - } - ) - else: - await ctx.submit(result.content) - - return result - - except Exception as e: - logger.exception("Error while running agent:") - # Propagate error to context for platform visibility - if hasattr(ctx, "error"): - ctx.error = e - return Trace( - reward=0.0, - done=True, - content=f"Agent failed with error: {e}", - isError=True, - info={"error": str(e)}, - ) - finally: - # Cleanup auto-created resources - await self._cleanup() - - def _map_role(self, role: str) -> str: - """Map a canonical role name to the provider-specific role. - - Override in subclasses where the provider uses different role names. - Default passes through (works for OpenAI and Claude which use "assistant"). 
- """ - return role - - async def _build_conversation_messages(self, conversation: list[dict[str, str]]) -> list[Any]: - """Build provider-formatted messages from a conversation history.""" - result: list[Any] = [] - for msg in conversation: - role = self._map_role(msg.get("role", "user")) - content = msg.get("content", "") - formatted = await self.format_message(content) - for fm in formatted: - if isinstance(fm, dict): - fm["role"] = role - elif hasattr(fm, "role"): - fm.role = role # type: ignore[attr-defined] - result.extend(formatted) - return result - - async def _run_context(self, initial_messages: list[Any], *, max_steps: int = 10) -> Trace: - """ - Run the agent with pre-built messages. This is the core agent loop. - - Args: - initial_messages: Provider-formatted messages (from format_message or conversation) - max_steps: Maximum number of steps (-1 for infinite) - Returns: Trace with reward, done, content fields and trace steps """ - final_response: InferenceResult | None = None - error = None - - messages: list[Any] = [] + tool_handler: CallTool | None = None + if ctx.tool_client is not None: + self.tools.prepare( + model=self.model, + tools=ctx.tool_client.tools, + hosted_tools=self.config.hosted_tools, + tool_metadata=ctx.tool_client.tool_metadata, + ) + tool_handler = ctx.tool_client.tool_handler + messages: list[ProviderMessageT] = [] try: - messages = await self.get_system_messages() - messages.extend(initial_messages) - self.console.debug(f"Messages: {messages}") + messages = await self.format_messages(ctx.messages) + logger.debug("Messages: %s", messages) step_count = 0 while max_steps == -1 or step_count < max_steps: step_count += 1 if max_steps == -1: - self.console.debug(f"Step {step_count} (unlimited)") + logger.debug("Step %s (unlimited)", step_count) else: - self.console.debug(f"Step {step_count}/{max_steps}") + logger.debug("Step %s/%s", step_count, max_steps) try: # 1. 
Get model response response = await self.get_response(messages) - self.console.debug(f"Agent:\n{response}") + logger.debug("Agent:\n%s", response) - # Check if we should stop if response.done or not response.tool_calls: - # Use auto_respond to decide whether to stop - decision: Literal["STOP", "CONTINUE"] = "STOP" - if self.auto_respond and response.content: - try: - from hud.agents.misc import ResponseAgent - - response_agent = ResponseAgent() - decision = await response_agent.determine_response(response.content) - except Exception as e: - self.console.warning_log(f"Auto-respond failed: {e}") - if decision == "STOP": - if ( - getattr(self.ctx, "scenario_enable_citations", False) - and not response.citations - ): - recovered = self._recover_citations_from_content(response) - if recovered: - self.console.info_log( - "Recovered citations from JSON answer payload" - ) - else: - self.console.warning_log( - "Citations required by scenario but missing in final response" # noqa: E501 - ) - self.console.debug("Stopping execution") - final_response = response - break - else: - self.console.debug("Continuing execution") - messages.extend(await self.format_message(decision)) + if follow_up := await auto_respond( + response.content, + enabled=self.auto_respond, + ): + logger.debug("Continuing execution") + messages.extend(await self.format_messages([follow_up])) continue - # 2. Execute tools - tool_calls = response.tool_calls - tool_results = await self.call_tools(tool_calls) - - # 3. Format tool results and add to messages - tool_messages = await self.format_tool_results(tool_calls, tool_results) - messages.extend(tool_messages) - - if logger.isEnabledFor(logging.INFO): - self.console.format_step( - step=step_count, - max_steps=max_steps, - tool_calls=tool_calls, - tool_results=tool_results, + logger.debug("Stopping execution") + return Trace( + done=True, + messages=messages, + content=response.content, + isError=response.isError, + citations=response.citations, ) + # 2. 
Execute tools + tool_messages = await self.tools.execute( + tool_handler, + response.tool_calls, + ) + + messages.extend(cast("list[ProviderMessageT]", tool_messages)) + except Exception as e: - self.console.error_log(f"Step failed: {e}") - error = str(e) - break + logger.exception("Step failed") + return Trace( + done=True, + messages=messages, + content=str(e), + isError=True, + info={"error": str(e)}, + ) except KeyboardInterrupt: - self.console.warning_log("Agent execution interrupted by user") - error = "Interrupted by user" + logger.warning("Agent execution interrupted by user") + return Trace( + done=True, + messages=messages, + content="Interrupted by user", + isError=True, + info={"error": "Interrupted by user"}, + ) except asyncio.CancelledError: - self.console.warning_log("Agent execution cancelled") - error = "Cancelled" + logger.warning("Agent execution cancelled") + return Trace( + done=True, + messages=messages, + content="Cancelled", + isError=True, + info={"error": "Cancelled"}, + ) except Exception as e: - self.console.error_log(f"Unexpected error: {e}") - error = str(e) - - # Build result - if error is not None or ( - final_response and hasattr(final_response, "isError") and final_response.isError - ): - is_error = True - else: - is_error = False - - # Use ctx.reward if already set (e.g., from scenario evaluate), otherwise 0.0 - reward = 0.0 - if self.ctx is not None: - ctx_reward = getattr(self.ctx, "reward", None) - if ctx_reward is not None: - reward = ctx_reward + logger.exception("Unexpected error") + return Trace( + done=True, + messages=messages, + content=str(e), + isError=True, + info={"error": str(e)}, + ) return Trace( - reward=reward, done=True, messages=messages, - content=final_response.content if final_response else error, - isError=is_error, - citations=final_response.citations if final_response else [], - info={"error": error} if error else {}, ) - def _recover_citations_from_content(self, response: InferenceResult) -> bool: - 
"""Try to extract citations from model content when native citations are missing. - - Handles two cases: raw JSON content and fenced ```json blocks. - """ - raw = response.content or "" - if not raw: - return False - - # Try raw content first, then try extracting from fenced block. - for text in dict.fromkeys([raw, self._extract_fenced_json(raw) or ""]): - if not text: - continue - try: - parsed = json.loads(text) - except (json.JSONDecodeError, TypeError): - continue - if not isinstance(parsed, dict): - continue - - raw_citations = parsed.get("citations") - if not isinstance(raw_citations, list) or not raw_citations: - continue - - normalized: list[Citation] = [ - c - for cit in raw_citations - if isinstance(cit, dict) and (c := self._normalize_citation(cit)) is not None - ] - if not normalized: - continue - - content = parsed.get("content") - if isinstance(content, str) and content.strip(): - response.content = content - response.citations = [c.model_dump(exclude={"provider_data"}) for c in normalized] - return True - - return False - - @staticmethod - def _extract_fenced_json(value: str) -> str | None: - """Extract JSON content from a fenced code block.""" - match = re.search(r"```(?:json)?\s*\n(.*?)```", value, re.DOTALL) - return match.group(1).strip() if match else None - - @staticmethod - def _normalize_citation(cit: dict[str, Any]) -> Citation | None: - """Normalize a citation dict to canonical Citation shape. - - Maps common key aliases to canonical names and validates via Citation. - Returns None only if construction fails (e.g. extra-forbid violation). 
- """ - source = cit.get("source") or cit.get("document") or "" - try: - return Citation( - type=cit.get("type", "document_citation"), - text=cit.get("text") or cit.get("cited_text", ""), - source=str(source), - title=cit.get("title") or cit.get("document_title"), - start_index=cit.get("start_index", cit.get("start_char_index")), - end_index=cit.get("end_index", cit.get("end_char_index")), - ) - except Exception: - return None - - async def call_tools( - self, tool_call: MCPToolCall | list[MCPToolCall] | None = None - ) -> list[MCPToolResult]: - """ - Call tools through the bound EvalContext. - - Args: - tool_call: MCPToolCall or list of MCPToolCall - - Returns: - List of MCPToolResult - """ - if tool_call is None: - return [] - - if isinstance(tool_call, MCPToolCall): - tool_call = [tool_call] - - if self.ctx is None: - raise ValueError("Agent not bound to context - call run(ctx) first") - - results: list[MCPToolResult] = [] - for tc in tool_call: - try: - self.console.debug(f"Calling tool: {tc}") - result = await self.ctx.call_tool(tc) - results.append(MCPToolResult(content=result.content, isError=result.isError)) - except TimeoutError as e: - self.console.error_log(f"Tool execution timed out: {e}") - raise - except Exception as e: - self.console.error_log(f"Tool execution failed: {e}") - results.append(_format_error_result(str(e))) - return results - @abstractmethod - async def get_system_messages(self) -> list[types.ContentBlock]: - """ - Get the system prompt. - """ - raise NotImplementedError - - @abstractmethod - async def get_response(self, messages: list[Any]) -> InferenceResult: + async def get_response(self, messages: list[ProviderMessageT]) -> AgentResponse: """ Get response from the model including any tool calls. 
@@ -553,114 +198,6 @@ async def get_response(self, messages: list[Any]) -> InferenceResult: raise NotImplementedError @abstractmethod - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - """ - Format a list of content blocks into a list of messages. - """ + async def format_messages(self, messages: list[types.PromptMessage]) -> list[ProviderMessageT]: + """Format MCP prompt messages into provider messages.""" raise NotImplementedError - - @abstractmethod - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[Any]: - """ - Format tool results into messages for the model. - - Args: - tool_calls: List of MCPToolCall objects that were executed - tool_results: List of MCPToolResult objects from tool execution - - Returns: - List of formatted messages to append to conversation - """ - raise NotImplementedError - - async def format_message( - self, - message: str - | list[str] - | types.ContentBlock - | list[types.ContentBlock] - | list[str | types.ContentBlock], - ) -> list[Any]: # maybe type messages as list[types.ContentBlock] - """ - Convencience function. - - Format a single content message into a list of messages for the model. - """ - blocks: list[types.ContentBlock] = [] - if not isinstance(message, list): - message = [message] - - for m in message: - if isinstance(m, str): - blocks.append(types.TextContent(text=m, type="text")) - elif isinstance(m, types.ContentBlock): - blocks.append(m) - else: - raise ValueError(f"Invalid message type: {type(m)}") - - return await self.format_blocks(blocks) - - def get_available_tools(self) -> list[types.Tool]: - """Get list of available MCP tools for LLM use (excludes lifecycle tools).""" - if self._available_tools is None: - raise RuntimeError( - "Tools have not been initialized. Call initialize() before accessing available tools." 
# noqa: E501 - ) - return self._available_tools - - def get_tool_schemas(self) -> list[dict]: - """Get tool schemas in a format suitable for the model. - - Uses categorized tools so that skipped tools are excluded from schemas - automatically. Falls back to get_available_tools() if called before - categorization. - """ - if self._initialized: - tools = list(self._categorized_tools.generic) - else: - tools = self.get_available_tools() - - schemas = [] - for tool in tools: - schema = { - "name": tool.name, - "description": tool.description, - } - if tool.inputSchema: - schema["parameters"] = tool.inputSchema - schemas.append(schema) - return schemas - - async def _filter_messages( - self, - message_list: list[types.ContentBlock], - include_types: list[ - Literal["text", "image", "audio", "resource_link", "embedded_resource"] - ], - ) -> list[types.ContentBlock]: - """ - Filter a list of messages and return only the messages of the given types. - - Args: - message_list: The list of messages to filter - include_types: List of types to include (None = all types) - - Returns: - List of messages in provider-specific format - """ - return [message for message in message_list if message.type in include_types] - - async def _cleanup(self) -> None: - """Cleanup resources.""" - # Clear context reference - self.ctx = None - - -def _format_error_result(error_message: str) -> MCPToolResult: - return MCPToolResult(content=text_to_blocks(error_message), isError=True) - - -def text_to_blocks(text: str) -> list[types.ContentBlock]: - return [types.TextContent(text=text, type="text")] diff --git a/hud/agents/claude/__init__.py b/hud/agents/claude/__init__.py index ce90d2178..5d1c41a60 100644 --- a/hud/agents/claude/__init__.py +++ b/hud/agents/claude/__init__.py @@ -6,11 +6,6 @@ AsyncAnthropic, AsyncAnthropicBedrock, ClaudeAgent, - base64_to_content_block, - document_to_content_block, - text_document_block, - text_to_content_block, - tool_use_content_block, ) from .tools import 
ClaudeToolSearchTool, ClaudeWebFetchTool, ClaudeWebSearchTool @@ -21,9 +16,4 @@ "ClaudeToolSearchTool", "ClaudeWebFetchTool", "ClaudeWebSearchTool", - "base64_to_content_block", - "document_to_content_block", - "text_document_block", - "text_to_content_block", - "tool_use_content_block", ] diff --git a/hud/agents/claude/agent.py b/hud/agents/claude/agent.py index 71698d9d1..1d5274de4 100644 --- a/hud/agents/claude/agent.py +++ b/hud/agents/claude/agent.py @@ -5,50 +5,44 @@ import copy import json import logging -from inspect import cleandoc -from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast +from functools import cached_property +from typing import TYPE_CHECKING, Literal, cast -import mcp.types as types +import mcp.types as mcp_types from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, Omit from anthropic.types import CacheControlEphemeralParam from anthropic.types.beta import ( BetaBase64ImageSourceParam, BetaBase64PDFSourceParam, - BetaContentBlockParam, BetaImageBlockParam, + BetaMessage, BetaMessageParam, - BetaPlainTextSourceParam, BetaRequestDocumentBlockParam, + BetaTextBlock, BetaTextBlockParam, - BetaToolParam, - BetaToolResultBlockParam, + BetaToolChoiceAutoParam, BetaToolUnionParam, ) +from hud.agents import gateway from hud.agents.base import MCPAgent -from hud.agents.tools import ( - EnvironmentCapability, - call_agent_tools, - capabilities_metadata_from_context, - discover_environment_capabilities, - select_hosted_tools, -) -from hud.agents.types import ClaudeConfig, ClaudeCreateParams +from hud.agents.types import ClaudeConfig from hud.settings import settings -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult -from hud.utils.hud_console import HUDConsole +from hud.tools.types import Citation +from hud.types import AgentResponse, MCPToolCall from hud.utils.types import with_signature -from .tools import ClaudeHostedTool, ClaudeTool, ClaudeToolSearchTool, claude_tools -from 
.tools.settings import claude_tool_settings +from .tools import ClaudeAgentTools if TYPE_CHECKING: - from collections.abc import Sequence + import mcp.types as types + from anthropic.types.beta import BetaTextCitation logger = logging.getLogger(__name__) +ClaudeImageMediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"] -class ClaudeAgent(MCPAgent): +class ClaudeAgent(MCPAgent[BetaMessageParam]): """ Claude agent that uses MCP servers for tool execution. @@ -56,33 +50,21 @@ class ClaudeAgent(MCPAgent): tools through MCP servers instead of direct implementation. """ - metadata: ClassVar[dict[str, Any] | None] = { - "display_width": claude_tool_settings.COMPUTER_WIDTH, - "display_height": claude_tool_settings.COMPUTER_HEIGHT, - } - config_cls: ClassVar[type[BaseAgentConfig]] = ClaudeConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for Claude.""" - return AgentType.CLAUDE - - @with_signature(ClaudeCreateParams) + @with_signature(ClaudeConfig) @classmethod - def create(cls, **kwargs: Any) -> ClaudeAgent: # pyright: ignore[reportIncompatibleMethodOverride] - return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value] + def create(cls, **kwargs: object) -> ClaudeAgent: # pyright: ignore[reportIncompatibleMethodOverride] + return cls(ClaudeConfig.model_validate(kwargs)) - def __init__(self, params: ClaudeCreateParams | None = None, **kwargs: Any) -> None: - super().__init__(params, **kwargs) + def __init__(self, config: ClaudeConfig | None = None) -> None: + config = config or ClaudeConfig() + super().__init__(config) self.config: ClaudeConfig model_client = self.config.model_client if model_client is None: # Default to HUD gateway when HUD_API_KEY is available if settings.api_key: - from hud.agents.gateway import build_gateway_client - - model_client = build_gateway_client("anthropic") + model_client = gateway.build_gateway_client("anthropic") elif settings.anthropic_api_key: model_client = 
AsyncAnthropic(api_key=settings.anthropic_api_key) else: @@ -95,511 +77,262 @@ def __init__(self, params: ClaudeCreateParams | None = None, **kwargs: Any) -> N " access" ) - self.anthropic_client: AsyncAnthropic | AsyncAnthropicBedrock = model_client - self.max_tokens = self.config.max_tokens - self.use_computer_beta = self.config.use_computer_beta - self.hud_console = HUDConsole(logger=logger) - - # these will be initialized in _convert_tools_for_claude - self.has_computer_tool = False - self.tool_mapping: dict[str, str] = {} - self.claude_tools: list[BetaToolUnionParam] = [] - self._claude_native_tools: dict[str, ClaudeTool] = {} - self._environment_capabilities: dict[str, EnvironmentCapability] = {} - self._required_betas: set[str] = set() - self._tool_search_threshold: int | None = None - - def _on_tools_ready(self) -> None: - """Build Claude-specific tool mappings after tools are discovered.""" - self._convert_tools_for_claude() - - def _discover_environment_capabilities( - self, tools: list[types.Tool] - ) -> dict[str, EnvironmentCapability]: - return discover_environment_capabilities( - tools, - env_metadata=capabilities_metadata_from_context(self.ctx), - name_fallbacks=claude_tools.name_fallbacks, + self.anthropic_client: AsyncAnthropic | AsyncAnthropicBedrock = cast( + "AsyncAnthropic | AsyncAnthropicBedrock", model_client ) + self.max_tokens = self.config.max_tokens - async def get_system_messages(self) -> list[types.ContentBlock]: - """No system messages for Claude because applied in get_response""" - return [] - - def _result_from_response_blocks(self, response_blocks: list[Any]) -> InferenceResult: - """Extract text/tool calls/citations from Anthropic response blocks.""" - result = InferenceResult(content="", tool_calls=[], done=True) - text_content = "" - thinking_content = "" - citations: list[dict[str, Any]] = [] - - for block in response_blocks: - block_type = getattr(block, "type", None) - if block_type == "tool_use": - block_input = 
getattr(block, "input", {}) - mcp_name = self.tool_mapping.get( - getattr(block, "name", ""), - getattr(block, "name", ""), - ) - arguments = block_input if isinstance(block_input, dict) else block_input.__dict__ - tool_call = MCPToolCall( - id=getattr(block, "id", ""), - name=mcp_name, - arguments=arguments, - ) - result.tool_calls.append(tool_call) - result.done = False - elif block_type == "text": - text = getattr(block, "text", "") or "" - text_content += text - block_citations = getattr(block, "citations", None) or [] - for cit in block_citations: - cit_dict = { - "type": "document_citation", - "text": getattr(cit, "cited_text", "") or "", - "source": ( - str(idx) - if (idx := getattr(cit, "document_index", None)) is not None - else getattr(cit, "document_title", "") or "" - ), - "title": getattr(cit, "document_title", None), - "start_index": getattr(cit, "start_char_index", None), - "end_index": getattr(cit, "end_char_index", None), - } - normalized = self._normalize_citation(cit_dict) - if normalized is not None: - citations.append(normalized.model_dump(exclude={"provider_data"})) - elif block_type == "thinking": - thinking = getattr(block, "thinking", "") or "" - if thinking: - if thinking_content: - thinking_content += "\n" - thinking_content += thinking - - result.content = text_content - result.citations = citations - if thinking_content: - result.reasoning = thinking_content - return result - - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[BetaMessageParam]: - """Format messages for Claude.""" - # Convert MCP content types to Anthropic content types - anthropic_blocks: list[BetaContentBlockParam] = [] - - for block in blocks: - if isinstance(block, types.TextContent): - # Only include fields that Anthropic expects - anthropic_blocks.append( - BetaTextBlockParam( - type="text", - text=block.text, - ) - ) - elif isinstance(block, types.ImageContent): - # Convert MCP ImageContent to Anthropic format - anthropic_blocks.append( - 
BetaImageBlockParam( + @cached_property + def tools(self) -> ClaudeAgentTools: + return ClaudeAgentTools() + + async def format_messages(self, messages: list[types.PromptMessage]) -> list[BetaMessageParam]: + """Format MCP prompt messages for Claude.""" + formatted: list[BetaMessageParam] = [] + for message in messages: + match message.content: + case mcp_types.TextContent(): + content = BetaTextBlockParam(type="text", text=message.content.text) + case mcp_types.ImageContent(): + content = BetaImageBlockParam( type="image", source=BetaBase64ImageSourceParam( type="base64", - media_type=cast( - "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']", - block.mimeType, - ), - data=block.data, + media_type=cast("ClaudeImageMediaType", message.content.mimeType), + data=message.content.data, + ), + ) + case mcp_types.EmbeddedResource( + resource=mcp_types.BlobResourceContents(mimeType="application/pdf") as resource + ): + content = BetaRequestDocumentBlockParam( + type="document", + source=BetaBase64PDFSourceParam( + type="base64", + media_type="application/pdf", + data=resource.blob, ), ) + case _: + raise ValueError(f"Unknown content block type: {type(message.content)}") + formatted.append( + BetaMessageParam( + role=message.role, + content=[content], ) - else: - raise ValueError(f"Unknown content block type: {type(block)}") - - return [BetaMessageParam(role="user", content=anthropic_blocks)] - - @staticmethod - def _extract_invalid_tool_json(exc: Exception) -> str | None: - """Extract malformed tool JSON payload from Anthropic stream errors. - - Returns None when the exception is unrelated to tool JSON parsing. - """ - message = str(exc) - parse_error_prefix = "Unable to parse tool parameter JSON from model." 
- if parse_error_prefix not in message: - return None - - marker = "JSON: " - marker_index = message.find(marker) - if marker_index == -1: - return "" - - return message[marker_index + len(marker) :].strip() - - @staticmethod - def _build_invalid_tool_json_retry_message(invalid_json: str) -> BetaMessageParam: - """Build a user message prompting the model to re-emit valid tool JSON.""" - wrapped = json.dumps({"INVALID_JSON": invalid_json}, ensure_ascii=True) - retry_text = ( - "Your previous tool-call arguments were invalid JSON and could not be parsed.\n" - "Retry the same intended tool call once with valid JSON arguments only.\n" - "Ensure all strings are quoted and all arrays/objects are valid JSON.\n" - f"Malformed payload (wrapped): {wrapped}" - ) - return BetaMessageParam( - role="user", - content=[text_to_content_block(retry_text)], - ) + ) + return formatted - async def get_response(self, messages: list[BetaMessageParam]) -> InferenceResult: + async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: """Get response from Claude including any tool calls.""" - messages_cached = self._add_prompt_caching(messages) # Betas are collected during provider tool conversion. # Only pass betas when non-empty; an empty list can produce an empty # anthropic-beta header which the API rejects. 
- betas: list[str] | Omit = list(self._required_betas) if self._required_betas else Omit() + betas: list[str] | Omit = ( + list(self.tools.required_betas) if self.tools.required_betas else Omit() + ) + tool_choice = BetaToolChoiceAutoParam(type="auto", disable_parallel_tool_use=True) - effective_tools: list[BetaToolUnionParam] = list(self.claude_tools) - if self._tool_search_threshold is not None: - generic_count = sum( - 1 for t in effective_tools if isinstance(t, dict) and "input_schema" in t - ) - if generic_count > self._tool_search_threshold: + effective_tools: list[BetaToolUnionParam] = list(self.tools.params) + if self.tools.tool_search_threshold is not None: + generic_count = sum(1 for t in effective_tools if "input_schema" in t) + if generic_count > self.tools.tool_search_threshold: logger.debug( "tool_search: %d generic tools > threshold %d, applying defer_loading", generic_count, - self._tool_search_threshold, + self.tools.tool_search_threshold, ) effective_tools = [ - {**t, "defer_loading": True} - if isinstance(t, dict) and "input_schema" in t - else t + {**t, "defer_loading": True} if "input_schema" in t else t for t in effective_tools ] - # Bedrock doesn't support .stream() - use create(stream=True) instead - if isinstance(self.anthropic_client, AsyncAnthropicBedrock): + client = self.anthropic_client + response: BetaMessage | None = None + is_bedrock = isinstance(client, AsyncAnthropicBedrock) + invalid_json_failures = 0 + + for _ in range(1 if is_bedrock else 3): + messages_cached: list[BetaMessageParam] = copy.deepcopy(messages) + cache_control = CacheControlEphemeralParam(type="ephemeral") + if messages_cached and messages_cached[-1].get("role") == "user": + content = messages_cached[-1]["content"] + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block["type"] not in ( + "redacted_thinking", + "thinking", + ): + cast("dict[str, object]", block)["cache_control"] = cache_control + try: - response = await 
self.anthropic_client.beta.messages.create( - model=self.config.model, - system=self.system_prompt if self.system_prompt is not None else Omit(), - max_tokens=self.max_tokens, - messages=messages_cached, - tools=effective_tools, - tool_choice={"type": "auto", "disable_parallel_tool_use": True}, - betas=betas, - ) - messages.append(BetaMessageParam(role="assistant", content=response.content)) - except ModuleNotFoundError: - raise ValueError( - "boto3 is required for AWS Bedrock. Use `pip install hud[bedrock]`" - ) from None - else: - # Regular Anthropic client supports .stream() - response = None - invalid_json_failures = 0 - for _ in range(3): - messages_cached = self._add_prompt_caching(messages) - try: - async with self.anthropic_client.beta.messages.stream( + if isinstance(client, AsyncAnthropicBedrock): + response = await client.beta.messages.create( model=self.config.model, system=self.system_prompt if self.system_prompt is not None else Omit(), max_tokens=self.max_tokens, messages=messages_cached, tools=effective_tools, - tool_choice={"type": "auto", "disable_parallel_tool_use": True}, + tool_choice=tool_choice, + betas=betas, + ) + else: + async with client.beta.messages.stream( + model=self.config.model, + system=self.system_prompt if self.system_prompt is not None else Omit(), + max_tokens=self.max_tokens, + messages=messages_cached, + tools=effective_tools, + tool_choice=tool_choice, betas=betas, ) as stream: - # allow backend to accumulate message content async for _ in stream: pass - # get final message response = await stream.get_final_message() - messages.append( - BetaMessageParam( - role="assistant", - content=response.content, - ) - ) - break - except ValueError as exc: - invalid_json = self._extract_invalid_tool_json(exc) - is_retryable = invalid_json is not None - if not is_retryable: - raise - - invalid_json_failures += 1 - if invalid_json_failures == 1: - logger.warning( - "Claude returned invalid streamed tool JSON; " - "retrying same 
generation once" - ) - continue - - if invalid_json_failures == 2: - logger.warning( - "Claude returned invalid streamed tool JSON twice; " - "retrying once with INVALID_JSON guidance" - ) - messages.append(self._build_invalid_tool_json_retry_message(invalid_json)) - continue - + messages.append(BetaMessageParam(role="assistant", content=response.content)) + break + except ModuleNotFoundError: + if is_bedrock: + raise ValueError( + "boto3 is required for AWS Bedrock. Use `pip install hud-python[bedrock]`" + ) from None + raise + except ValueError as exc: + message = str(exc) + if is_bedrock or "Unable to parse tool parameter JSON from model." not in message: raise - if response is None: - raise ValueError("Claude response missing after stream retries") + marker = "JSON: " + marker_index = message.find(marker) + invalid_json = ( + "" if marker_index == -1 else message[marker_index + len(marker) :].strip() + ) - # Process response - result = self._result_from_response_blocks(list(response.content)) + invalid_json_failures += 1 + if invalid_json_failures == 1: + logger.warning( + "Claude returned invalid streamed tool JSON; retrying same generation once" + ) + continue + + if invalid_json_failures == 2: + wrapped = json.dumps({"INVALID_JSON": invalid_json}, ensure_ascii=True) + retry_text = ( + "Your previous tool-call arguments were invalid JSON and could not be " + "parsed.\n" + "Retry the same intended tool call once with valid JSON arguments only.\n" + "Ensure all strings are quoted and all arrays/objects are valid JSON.\n" + f"Malformed payload (wrapped): {wrapped}" + ) + logger.warning( + "Claude returned invalid streamed tool JSON twice; " + "retrying once with INVALID_JSON guidance" + ) + messages.append( + BetaMessageParam( + role="user", + content=[BetaTextBlockParam(type="text", text=retry_text)], + ) + ) + continue - return result + raise - async def call_tools( - self, tool_call: MCPToolCall | list[MCPToolCall] | None = None - ) -> list[MCPToolResult]: - 
"""Route Claude provider tools to their backing environment tools.""" - return await call_agent_tools(self, self._claude_native_tools, tool_call) - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[BetaMessageParam]: - """Format tool results into Claude messages. - - Handles EmbeddedResource (PDFs), images, and text content. - """ - citations_enabled = bool( - getattr(self.ctx, "scenario_enable_citations", False) if self.ctx else False - ) + if response is None: + raise ValueError("Claude response missing after stream retries") - # Process each tool result - user_content: list[BetaToolResultBlockParam | BetaRequestDocumentBlockParam] = [] - - for tool_call, result in zip(tool_calls, tool_results, strict=True): - tool_use_id = tool_call.id - if not tool_use_id: - self.hud_console.warning(f"No tool_use_id found for {tool_call.name}") - continue - - # Blocks placed inside the tool_result (text, images) - claude_blocks: list[ - BetaTextBlockParam | BetaImageBlockParam | BetaRequestDocumentBlockParam - ] = [] - # Citable document blocks placed as siblings after the tool_result - # so Claude's citation system indexes them properly. 
- sibling_docs: list[BetaRequestDocumentBlockParam] = [] - - if result.isError: - error_msg = "Tool execution failed" - for content in result.content: - if isinstance(content, types.TextContent): - error_msg = content.text - break - claude_blocks.append(text_to_content_block(f"Error: {error_msg}")) - else: - for content in result.content: - if isinstance(content, types.TextContent): - claude_blocks.append(text_to_content_block(content.text)) - if citations_enabled: - sibling_docs.append( - text_document_block(content.text, title=tool_call.name) - ) - elif isinstance(content, types.ImageContent): - claude_blocks.append( - base64_to_content_block(content.data, content.mimeType) + result = AgentResponse(content="", tool_calls=[], done=True) + text_content = "" + thinking_content = "" + citations: list[dict[str, object]] = [] + + for block in response.content: + match block.type: + case "tool_use": + tool_use = block + mcp_name = self.tools.name_map.get(tool_use.name, tool_use.name) + result.tool_calls.append( + MCPToolCall( + id=tool_use.id, + name=mcp_name, + arguments=dict(tool_use.input), + _meta=mcp_types.RequestParams.Meta.model_validate( + {"enable_citations": self.enable_citations} + ), ) - elif isinstance(content, types.EmbeddedResource): - resource = content.resource - if ( - isinstance(resource, types.BlobResourceContents) - and resource.mimeType == "application/pdf" - ): - claude_blocks.append( - document_to_content_block( - base64_data=resource.blob, - ) - ) - if citations_enabled: - sibling_docs.append( - document_to_content_block( - base64_data=resource.blob, - enable_citations=True, - ) - ) - - user_content.append(tool_use_content_block(tool_use_id, claude_blocks)) - user_content.extend(sibling_docs) - - return [ - BetaMessageParam( - role="user", - content=user_content, - ) - ] - - async def create_user_message(self, text: str) -> BetaMessageParam: - """Create a user message in Claude's format.""" - return BetaMessageParam(role="user", content=text) - 
- def _convert_tools_for_claude(self) -> None: - """Convert MCP tools to Claude API tools.""" - self.has_computer_tool = False - self.tool_mapping: dict[str, str] = {} - self.claude_tools: list[BetaToolUnionParam] = [] - self._claude_native_tools = {} - self._required_betas: set[str] = set() - self._tool_search_threshold = None - - categorized = self._categorized_tools - - capabilities = self._discover_environment_capabilities(self.get_available_tools()) - self._environment_capabilities = capabilities - provider_backing_tools: set[str] = set() - - for capability in capabilities.values(): - if capability.name not in claude_tools.capabilities: - continue - claude_tool = claude_tools.tool_for_capability(capability, self.model) - if claude_tool is None: - continue - provider_backing_tools.add(capability.tool_name) - provider_name = getattr(claude_tool, "provider_name", claude_tool.name) - self._claude_native_tools[provider_name] = claude_tool - self.tool_mapping[provider_name] = provider_name - self.claude_tools.append(claude_tool.to_params()) - if claude_tool.required_beta: - self._required_betas.add(claude_tool.required_beta) - if claude_tool.capability == "computer": - self.has_computer_tool = True - logger.debug( - "Activated Claude %s capability from env tool %s", - capability.name, - capability.tool_name, - ) + ) + result.done = False + case "text": + text = cast("BetaTextBlock", block) + text_content += text.text + for citation in text.citations or []: + normalized = self._citation(citation) + citations.append(normalized.model_dump(exclude={"provider_data"})) + case "thinking": + thinking = block + if thinking.thinking: + if thinking_content: + thinking_content += "\n" + thinking_content += thinking.thinking + case _: + continue - configured_hosted = select_hosted_tools( - self.config.hosted_tools, - tool_type=ClaudeHostedTool, - model=self.model, - ) - for hosted_tool in configured_hosted: - self.claude_tools.append(hosted_tool.to_params()) # type: 
ignore[arg-type] - required_beta = getattr(hosted_tool, "required_beta", None) - if required_beta: - self._required_betas.add(required_beta) - if isinstance(hosted_tool, ClaudeToolSearchTool): - self._tool_search_threshold = hosted_tool.threshold - - # Process generic tools - for tool in categorized.generic: - if tool.name in provider_backing_tools: - continue - if tool.description is None or tool.inputSchema is None: - raise ValueError( - cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema. - Add these by: - 1. Adding a docstring to your @mcp.tool decorated function for the description - 2. Using pydantic Field() annotations on function parameters for the schema - """) - ) + result.content = text_content + result.citations = citations + if thinking_content: + result.reasoning = thinking_content - claude_tool = BetaToolParam( - name=tool.name, - description=tool.description, - input_schema=tool.inputSchema, - eager_input_streaming=True, - ) - self.tool_mapping[tool.name] = tool.name - self.claude_tools.append(claude_tool) + return result - # Log actual tools being used - tool_names = sorted(self.tool_mapping.keys()) - self.console.info( - f"Agent initialized with {len(tool_names)} tools: {', '.join(tool_names)}" + @staticmethod + def _citation(citation: BetaTextCitation) -> Citation: + match citation.type: + case "char_location": + char_location = citation + citation_type = "document_citation" + text = char_location.cited_text + source = str(char_location.document_index) + title = char_location.document_title + start_index = char_location.start_char_index + end_index = char_location.end_char_index + case "page_location": + page_location = citation + citation_type = "document_citation" + text = page_location.cited_text + source = str(page_location.document_index) + title = page_location.document_title + start_index = None + end_index = None + case "content_block_location": + block_location = citation + citation_type = "document_citation" + 
text = block_location.cited_text + source = str(block_location.document_index) + title = block_location.document_title + start_index = block_location.start_block_index + end_index = block_location.end_block_index + case "search_result_location": + search_result = citation + citation_type = "search_result_location" + text = search_result.cited_text + source = search_result.source + title = search_result.title + start_index = search_result.start_block_index + end_index = search_result.end_block_index + case "web_search_result_location": + web_result = citation + citation_type = "web_search_result_location" + text = web_result.cited_text + source = web_result.url + title = web_result.title + start_index = None + end_index = None + + return Citation( + type=citation_type, + text=text, + source=source, + title=title, + start_index=start_index, + end_index=end_index, ) - - def _add_prompt_caching(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: - """Add prompt caching to messages.""" - messages_cached = copy.deepcopy(messages) - cache_control = CacheControlEphemeralParam(type="ephemeral") - - # Mark last user message with cache control - if ( - messages_cached - and isinstance(messages_cached[-1], dict) - and messages_cached[-1].get("role") == "user" - ): - last_content = messages_cached[-1]["content"] - # Content is formatted to be list of ContentBlock in format_blocks and format_message - if isinstance(last_content, list): - for block in last_content: - # Only add cache control to dict-like block types that support it - if isinstance(block, dict): - match block["type"]: - case "redacted_thinking" | "thinking": - pass - case _: - block["cache_control"] = cache_control - - return messages_cached - - -def base64_to_content_block( - base64: str, - media_type: str = "image/png", -) -> BetaImageBlockParam: - """Convert base64 image to Claude content block.""" - return BetaImageBlockParam( - type="image", - source=BetaBase64ImageSourceParam( - type="base64", 
- media_type=cast( - "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']", - media_type, - ), - data=base64, - ), - ) - - -def text_to_content_block(text: str) -> BetaTextBlockParam: - """Convert text to Claude content block.""" - return {"type": "text", "text": text} - - -def text_document_block(text: str, *, title: str | None = None) -> BetaRequestDocumentBlockParam: - """Wrap plain text as a citable document block.""" - block = BetaRequestDocumentBlockParam( - type="document", - source=BetaPlainTextSourceParam( - type="text", - media_type="text/plain", - data=text, - ), - citations={"enabled": True}, - ) - if title: - block["title"] = title - return block - - -def document_to_content_block( - base64_data: str, *, enable_citations: bool = False -) -> BetaRequestDocumentBlockParam: - """Convert base64 PDF to Claude document content block.""" - block = BetaRequestDocumentBlockParam( - type="document", - source=BetaBase64PDFSourceParam( - type="base64", - media_type="application/pdf", - data=base64_data, - ), - ) - if enable_citations: - block["citations"] = {"enabled": True} - return block - - -def tool_use_content_block( - tool_use_id: str, - content: Sequence[BetaTextBlockParam | BetaImageBlockParam | BetaRequestDocumentBlockParam], -) -> BetaToolResultBlockParam: - """Create tool result content block.""" - return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content} # pyright: ignore[reportReturnType] diff --git a/hud/agents/claude/tools/__init__.py b/hud/agents/claude/tools/__init__.py index ff341fa43..16796f567 100644 --- a/hud/agents/claude/tools/__init__.py +++ b/hud/agents/claude/tools/__init__.py @@ -2,58 +2,63 @@ from __future__ import annotations -from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar -from hud.agents.tools import AgentToolRegistry +from anthropic.types.beta import BetaToolUnionParam -from .base import ClaudeTool +from hud.agents.tools import AgentTools + +from .base import 
ClaudeFunctionTool, ClaudeTool from .coding import ClaudeBashTool, ClaudeTextEditorTool from .computer import ClaudeComputerTool from .hosted import ClaudeHostedTool, ClaudeToolSearchTool, ClaudeWebFetchTool, ClaudeWebSearchTool from .memory import ClaudeMemoryTool +if TYPE_CHECKING: + from collections.abc import Mapping + + from hud.agents.tools import AgentTool + -@dataclass(frozen=True) -class ClaudeToolRegistry(AgentToolRegistry[ClaudeTool]): - """Registry for Claude harness tools.""" +class ClaudeAgentTools(AgentTools[ClaudeTool, BetaToolUnionParam]): + """Prepared Claude tool state for a run.""" - tool_classes: tuple[type[ClaudeTool], ...] = ( + native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = ( ClaudeComputerTool, ClaudeBashTool, ClaudeTextEditorTool, ClaudeMemoryTool, ) - name_fallbacks: dict[str, tuple[str, ...]] = field( - default_factory=lambda: { - "computer": ("computer", "anthropic_computer", "computer_anthropic"), - "shell": ("bash",), - "editor": ("edit", "str_replace_based_edit_tool", "text_editor"), - "memory": ("memory",), - } - ) + function_tool_class = ClaudeFunctionTool + name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = { + "computer": ("computer", "anthropic_computer", "computer_anthropic"), + "shell": ("bash",), + "editor": ("edit", "str_replace_based_edit_tool", "text_editor"), + "memory": ("memory",), + } - @property - def capabilities(self) -> frozenset[str]: - return frozenset(cls.capability for cls in self.tool_classes) - - @property - def provider_tool_names(self) -> frozenset[str]: - return frozenset(cls.name for cls in self.tool_classes) + def __init__(self) -> None: + super().__init__() + self.required_betas: set[str] = set() + def prepare(self, **kwargs: Any) -> None: + super().prepare(**kwargs) + self.required_betas = { + required_beta for tool in self.values() if (required_beta := tool.required_beta) + } -claude_tools = ClaudeToolRegistry() + @property + def tool_search_threshold(self) -> int | None: 
+ for hosted_tool in self.hosted_tools: + if isinstance(hosted_tool, ClaudeToolSearchTool): + return hosted_tool.threshold + return None __all__ = [ - "ClaudeBashTool", - "ClaudeComputerTool", + "ClaudeAgentTools", "ClaudeHostedTool", - "ClaudeMemoryTool", - "ClaudeTextEditorTool", - "ClaudeTool", - "ClaudeToolRegistry", "ClaudeToolSearchTool", "ClaudeWebFetchTool", "ClaudeWebSearchTool", - "claude_tools", ] diff --git a/hud/agents/claude/tools/base.py b/hud/agents/claude/tools/base.py index ee4b4820e..0cd353cad 100644 --- a/hud/agents/claude/tools/base.py +++ b/hud/agents/claude/tools/base.py @@ -2,27 +2,179 @@ from __future__ import annotations -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from inspect import cleandoc +from typing import TYPE_CHECKING, Any, Literal, cast -from hud.agents import tools as _agent_tools -from hud.agents.tools import AgentTool, AgentToolSpec, CallTool +import mcp.types as types +from anthropic.types.beta import ( + BetaBase64ImageSourceParam, + BetaBase64PDFSourceParam, + BetaImageBlockParam, + BetaMessageParam, + BetaPlainTextSourceParam, + BetaRequestDocumentBlockParam, + BetaTextBlockParam, + BetaToolParam, + BetaToolResultBlockParam, +) + +from hud.agents.tools import AgentTool, AgentToolSpec if TYPE_CHECKING: from anthropic.types.beta import BetaToolUnionParam - from hud.types import MCPToolResult + from hud.types import MCPToolCall, MCPToolResult else: BetaToolUnionParam = Any -ClaudeToolSpec = AgentToolSpec -call_tool = _agent_tools.call_tool +ClaudeImageMediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"] +ClaudeToolResultContent = BetaTextBlockParam | BetaImageBlockParam | BetaRequestDocumentBlockParam + + +@dataclass(frozen=True) +class ClaudeToolSpec(AgentToolSpec): + """Claude provider tool definition.""" + + beta: str | None = None -class ClaudeTool(AgentTool["BetaToolUnionParam"], ABC): +class 
ClaudeTool(AgentTool["BetaToolUnionParam"]): """Agent-side Claude provider tool backed by an environment tool.""" - @abstractmethod - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - """Execute against the environment tool using the agent-provided caller.""" - ... + def __init__(self, *, env_tool_name: str, spec: ClaudeToolSpec) -> None: + super().__init__(env_tool_name=env_tool_name, spec=spec) + self.spec: ClaudeToolSpec = spec + + @property + def required_beta(self) -> str | None: + return self.spec.beta + + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> BetaMessageParam | None: + tool_use_id = call.id + if not tool_use_id: + return None + + result_content = result.content + if result.isError: + error_msg = next( + ( + content.text + for content in result.content + if isinstance(content, types.TextContent) + ), + "Tool execution failed", + ) + result_content = [types.TextContent(type="text", text=f"Error: {error_msg}")] + + claude_blocks: list[ClaudeToolResultContent] = [] + sibling_docs: list[BetaRequestDocumentBlockParam] = [] + enable_citations = bool(getattr(call.meta, "enable_citations", False)) + for content in result_content: + citation_doc = None + match content: + case types.TextContent(): + block = BetaTextBlockParam(type="text", text=content.text) + if enable_citations and not result.isError: + citation_doc = BetaRequestDocumentBlockParam( + type="document", + source=BetaPlainTextSourceParam( + type="text", + media_type="text/plain", + data=content.text, + ), + title=call.name, + citations={"enabled": True}, + ) + case types.ImageContent(): + block = BetaImageBlockParam( + type="image", + source=BetaBase64ImageSourceParam( + type="base64", + media_type=cast("ClaudeImageMediaType", content.mimeType), + data=content.data, + ), + ) + case types.EmbeddedResource( + resource=types.BlobResourceContents(mimeType="application/pdf") as resource + ): + block = BetaRequestDocumentBlockParam( + 
type="document", + source=BetaBase64PDFSourceParam( + type="base64", + media_type="application/pdf", + data=resource.blob, + ), + ) + if enable_citations and not result.isError: + citation_doc = BetaRequestDocumentBlockParam( + type="document", + source=block["source"], + citations={"enabled": True}, + ) + case _: + raise ValueError(f"Unknown content block type: {type(content)}") + claude_blocks.append(block) + if citation_doc is not None: + sibling_docs.append(citation_doc) + + return BetaMessageParam( + role="user", + content=[ + BetaToolResultBlockParam( + type="tool_result", + tool_use_id=tool_use_id, + content=claude_blocks, + ), + *sibling_docs, + ], + ) + + +class ClaudeFunctionTool(ClaudeTool): + """Regular environment tool exposed as a Claude function tool.""" + + name = "function" + capability = "function" + + def __init__( + self, + *, + env_tool_name: str, + description: str, + input_schema: dict[str, Any], + ) -> None: + super().__init__( + env_tool_name=env_tool_name, + spec=ClaudeToolSpec(api_type="function", api_name=env_tool_name), + ) + self.description = description + self.input_schema = input_schema + + @classmethod + def from_tool(cls, tool: types.Tool) -> ClaudeFunctionTool: + if tool.description is None: + raise ValueError( + cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema. + Add these by: + 1. Adding a docstring to your @mcp.tool decorated function for the description + 2. 
Using pydantic Field() annotations on function parameters for the schema + """) + ) + return cls( + env_tool_name=tool.name, + description=tool.description, + input_schema=tool.inputSchema, + ) + + @property + def provider_name(self) -> str: + return self.env_tool_name + + def to_params(self) -> BetaToolUnionParam: + return BetaToolParam( + name=self.provider_name, + description=self.description, + input_schema=self.input_schema, + eager_input_streaming=True, + ) diff --git a/hud/agents/claude/tools/coding.py b/hud/agents/claude/tools/coding.py index f9b4331dd..fc66467c8 100644 --- a/hud/agents/claude/tools/coding.py +++ b/hud/agents/claude/tools/coding.py @@ -8,7 +8,7 @@ from hud.types import MCPToolResult -from .base import CallTool, ClaudeTool, ClaudeToolSpec, call_tool +from .base import ClaudeTool, ClaudeToolSpec if TYPE_CHECKING: from anthropic.types.beta import ( @@ -16,6 +16,8 @@ BetaToolTextEditor20250728Param, ) + from hud.agents.tools.base import CallTool + CLAUDE_BASH_SPEC = ClaudeToolSpec( api_type="bash_20250124", @@ -29,30 +31,18 @@ ), ) -CLAUDE_TEXT_EDITOR_SPECS: tuple[ClaudeToolSpec, ...] 
= ( - ClaudeToolSpec( - api_type="text_editor_20250728", - api_name="str_replace_based_edit_tool", - supported_models=( - "*claude-opus-4-7*", - "*claude-opus-4-6*", - "*claude-sonnet-4-5*", - "*claude-sonnet-4-6*", - "*claude-haiku-4-5*", - ), +CLAUDE_TEXT_EDITOR_SPEC = ClaudeToolSpec( + api_type="text_editor_20250728", + api_name="str_replace_based_edit_tool", + supported_models=( + "*claude-opus-4-7*", + "*claude-opus-4-6*", + "*claude-sonnet-4-5*", + "*claude-sonnet-4-6*", + "*claude-haiku-4-5*", ), ) -CLAUDE_TEXT_EDITOR_SPEC = CLAUDE_TEXT_EDITOR_SPECS[0] - -CLAUDE_TEXT_EDITOR_NAMES = { - "text_editor_20250728": "str_replace_based_edit_tool", -} - -CLAUDE_TEXT_EDITOR_COMMANDS = { - "text_editor_20250728": frozenset({"view", "create", "str_replace", "insert"}), -} - class ClaudeBashTool(ClaudeTool): """Claude bash provider tool backed by an environment shell tool.""" @@ -81,7 +71,7 @@ def to_params(self) -> BetaToolBash20250124Param: async def execute( self, - caller: CallTool, + call_tool: CallTool, arguments: dict[str, Any], ) -> MCPToolResult: if not arguments.get("restart") and "command" not in arguments: @@ -94,7 +84,7 @@ async def execute( ], isError=True, ) - return await call_tool(caller, self.env_tool_name, arguments) + return await super().execute(call_tool, arguments) class ClaudeTextEditorTool(ClaudeTool): @@ -105,9 +95,8 @@ class ClaudeTextEditorTool(ClaudeTool): @classmethod def default_spec(cls, model: str) -> ClaudeToolSpec | None: - for spec in CLAUDE_TEXT_EDITOR_SPECS: - if spec.supports_model(model): - return spec + if CLAUDE_TEXT_EDITOR_SPEC.supports_model(model): + return CLAUDE_TEXT_EDITOR_SPEC return None def __init__(self, *, env_tool_name: str, spec: ClaudeToolSpec) -> None: @@ -115,7 +104,7 @@ def __init__(self, *, env_tool_name: str, spec: ClaudeToolSpec) -> None: @property def provider_name(self) -> str: - return CLAUDE_TEXT_EDITOR_NAMES.get(self.spec.api_type, self.spec.api_name) + return self.spec.api_name def to_params(self) -> 
BetaToolTextEditor20250728Param: return cast( @@ -128,25 +117,10 @@ def to_params(self) -> BetaToolTextEditor20250728Param: async def execute( self, - caller: CallTool, + call_tool: CallTool, arguments: dict[str, Any], ) -> MCPToolResult: - command = arguments.get("command") - allowed_commands = CLAUDE_TEXT_EDITOR_COMMANDS.get(self.spec.api_type) - if allowed_commands is not None and command not in allowed_commands: - return MCPToolResult( - content=[ - TextContent( - type="text", - text=( - f"{self.spec.api_type} does not support command {command!r}. " - f"Supported commands: {', '.join(sorted(allowed_commands))}" - ), - ) - ], - isError=True, - ) - return await call_tool(caller, self.env_tool_name, _claude_editor_arguments(arguments)) + return await super().execute(call_tool, _claude_editor_arguments(arguments)) def _claude_editor_arguments(arguments: dict[str, Any]) -> dict[str, Any]: @@ -170,12 +144,3 @@ def _claude_editor_arguments(arguments: dict[str, Any]) -> dict[str, Any]: } case _: return dict(arguments) - - -__all__ = [ - "CLAUDE_BASH_SPEC", - "CLAUDE_TEXT_EDITOR_SPEC", - "CLAUDE_TEXT_EDITOR_SPECS", - "ClaudeBashTool", - "ClaudeTextEditorTool", -] diff --git a/hud/agents/claude/tools/computer.py b/hud/agents/claude/tools/computer.py index 6953e2fde..7ca775c15 100644 --- a/hud/agents/claude/tools/computer.py +++ b/hud/agents/claude/tools/computer.py @@ -9,13 +9,19 @@ import base64 import logging from io import BytesIO -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, cast -from mcp.types import ImageContent, TextContent +from mcp.types import ImageContent +from hud.agents.tools.computer import ( + computer_error_result, + computer_tool_info, + execute_computer_calls, + first_image_data, +) from hud.types import MCPToolResult -from .base import CallTool, ClaudeTool, ClaudeToolSpec, call_tool +from .base import ClaudeTool, ClaudeToolSpec from .settings import claude_tool_settings if TYPE_CHECKING: @@ -25,6 +31,7 
@@ ) from hud.agents.tools import EnvironmentCapability + from hud.agents.tools.base import CallTool logger = logging.getLogger(__name__) @@ -86,8 +93,6 @@ ), ) -_AUTO_SCREENSHOT_OFF_SPECS = {"computer_20251124"} - class ClaudeComputerTool(ClaudeTool): """Translate Claude native computer calls into environment computer calls.""" @@ -107,62 +112,36 @@ def __init__( *, env_tool_name: str, spec: ClaudeToolSpec, - model: str, display_width: int, display_height: int, - schema: Literal["hud", "anthropic"], ) -> None: - super().__init__(env_tool_name=env_tool_name, spec=self._resolve_spec(spec, model)) + super().__init__(env_tool_name=env_tool_name, spec=spec) self.display_width = display_width self.display_height = display_height - self.schema = schema @classmethod def from_capability( cls, capability: EnvironmentCapability, - spec: ClaudeToolSpec, model: str, - ) -> ClaudeComputerTool: - tool = capability.tool - props = tool.inputSchema.get("properties", {}) if isinstance(tool.inputSchema, dict) else {} - schema: Literal["hud", "anthropic"] = ( - "anthropic" if {"coordinate", "scroll_direction"} & set(props) else "hud" - ) - - metadata_resolution = capability.metadata.get("resolution", {}) - if not isinstance(metadata_resolution, dict): - metadata_resolution = {} - resolution = (tool.meta or {}).get("resolution", {}) if tool.meta else {} - display_width = int( - metadata_resolution.get("width") - or resolution.get("width") - or claude_tool_settings.COMPUTER_WIDTH - ) - display_height = int( - metadata_resolution.get("height") - or resolution.get("height") - or claude_tool_settings.COMPUTER_HEIGHT + ) -> ClaudeComputerTool | None: + spec = cls.default_spec(model) + if spec is None: + return None + + computer_info = computer_tool_info( + capability.tool, + default_width=claude_tool_settings.COMPUTER_WIDTH, + default_height=claude_tool_settings.COMPUTER_HEIGHT, ) return cls( env_tool_name=capability.tool_name, spec=spec, - model=model, - display_width=display_width, - 
display_height=display_height, - schema=schema, + display_width=computer_info.display_width, + display_height=computer_info.display_height, ) - @staticmethod - def _resolve_spec(spec: ClaudeToolSpec, model: str) -> ClaudeToolSpec: - if spec.api_type and spec.api_type.startswith("computer_"): - return spec - for candidate in CLAUDE_COMPUTER_SPECS: - if candidate.supports_model(model): - return candidate - return spec - def to_params( self, ) -> BetaToolComputerUse20250124Param | BetaToolComputerUse20251124Param: @@ -191,47 +170,20 @@ def to_params( async def execute( self, - caller: CallTool, - arguments: dict[str, Any], - ) -> MCPToolResult: - if self.schema == "anthropic": - return await self._call_env(caller, self._as_anthropic_arguments(arguments)) - return await self._call_env_tool(caller, arguments) - - async def _call_env( - self, - caller: CallTool, - arguments: dict[str, Any], - ) -> MCPToolResult: - return await call_tool(caller, self.env_tool_name, arguments) - - async def _call_env_tool( - self, - caller: CallTool, + call_tool: CallTool, arguments: dict[str, Any], ) -> MCPToolResult: action = arguments.get("action") if action == "zoom": - return await self._zoom(caller, arguments) - - calls = self._env_calls(arguments) - result = MCPToolResult(content=[], isError=False) - for call in calls: - result = await self._call_env(caller, call) - if result.isError: - return result - return result - - def _as_anthropic_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]: - args = dict(arguments) - if ( - self.spec.api_type in _AUTO_SCREENSHOT_OFF_SPECS - and args.get("action") != "screenshot" - and "take_screenshot_on_click" not in args - ): - args["take_screenshot_on_click"] = False - return args + return await self._zoom(call_tool, arguments) + + return await execute_computer_calls( + call_tool, + env_tool_name=self.env_tool_name, + calls=self._env_calls(arguments), + ensure_screenshot=False, + ) def _env_calls(self, arguments: dict[str, Any]) -> 
list[dict[str, Any]]: action = arguments.get("action") @@ -239,8 +191,10 @@ def _env_calls(self, arguments: dict[str, Any]) -> list[dict[str, Any]]: text = arguments.get("text") def xy() -> tuple[int | None, int | None]: - if isinstance(coordinate, list) and len(coordinate) >= 2: - return coordinate[0], coordinate[1] + if isinstance(coordinate, list): + coords = cast("list[Any]", coordinate) + if len(coords) >= 2: + return int(coords[0]), int(coords[1]) return None, None if action == "screenshot": @@ -317,17 +271,21 @@ def xy() -> tuple[int | None, int | None]: ] if action in ("left_click_drag", "drag"): start = arguments.get("start_coordinate") - path = [] - if isinstance(start, list) and len(start) >= 2: - path.append({"x": start[0], "y": start[1]}) - if isinstance(coordinate, list) and len(coordinate) >= 2: - if not path: - return [ - {"action": "mouse_down", "button": "left"}, - {"action": "move", "x": coordinate[0], "y": coordinate[1]}, - {"action": "mouse_up", "button": "left"}, - ] - path.append({"x": coordinate[0], "y": coordinate[1]}) + path: list[dict[str, Any]] = [] + if isinstance(start, list): + start_coords = cast("list[Any]", start) + if len(start_coords) >= 2: + path.append({"x": start_coords[0], "y": start_coords[1]}) + if isinstance(coordinate, list): + end_coords = cast("list[Any]", coordinate) + if len(end_coords) >= 2: + if not path: + return [ + {"action": "mouse_down", "button": "left"}, + {"action": "move", "x": end_coords[0], "y": end_coords[1]}, + {"action": "mouse_up", "button": "left"}, + ] + path.append({"x": end_coords[0], "y": end_coords[1]}) return [{"action": "drag", "path": path, "hold_keys": self._hold_keys(text)}] if action == "wait": duration = arguments.get("duration") or 0 @@ -351,28 +309,23 @@ def xy() -> tuple[int | None, int | None]: async def _zoom( self, - caller: CallTool, + call_tool: CallTool, arguments: dict[str, Any], ) -> MCPToolResult: region = arguments.get("region") - if not isinstance(region, (list, tuple)) or 
len(region) != 4: - return MCPToolResult( - content=[TextContent(type="text", text="region must be [x0, y0, x1, y1]")], - isError=True, - ) + region_value = cast("list[Any] | tuple[Any, ...]", region) + if not isinstance(region, (list, tuple)) or len(region_value) != 4: + return computer_error_result("region must be [x0, y0, x1, y1]") - screenshot = await self._call_env(caller, {"action": "screenshot"}) + screenshot = await super().execute(call_tool, {"action": "screenshot"}) if screenshot.isError: return screenshot - image_data = _first_image(screenshot) + image_data = first_image_data(screenshot) if image_data is None: - return MCPToolResult( - content=[TextContent(type="text", text="screenshot returned no image")], - isError=True, - ) + return computer_error_result("screenshot returned no image") try: - x0, y0, x1, y1 = (int(v) for v in region) + x0, y0, x1, y1 = (int(v) for v in region_value) image = ImageContent( type="image", mimeType="image/png", @@ -381,7 +334,7 @@ async def _zoom( return MCPToolResult(content=[image], isError=False) except Exception as exc: logger.warning("Claude computer zoom failed: %s", exc) - return MCPToolResult(content=[TextContent(type="text", text=str(exc))], isError=True) + return computer_error_result(str(exc)) @staticmethod def _keys(text: str | None) -> list[str]: @@ -419,13 +372,6 @@ def _map_key(key: str) -> str: return ANTHROPIC_TO_CLA_KEYS.get(key, ANTHROPIC_TO_CLA_KEYS.get(key.capitalize(), key.lower())) -def _first_image(result: MCPToolResult) -> str | None: - for block in result.content or []: - if isinstance(block, ImageContent): - return block.data - return None - - def _crop_png(image_data: str, region: tuple[int, int, int, int]) -> str: from PIL import Image # type: ignore[import-not-found] @@ -434,6 +380,3 @@ def _crop_png(image_data: str, region: tuple[int, int, int, int]) -> str: buffer = BytesIO() crop.save(buffer, format="PNG") return base64.b64encode(buffer.getvalue()).decode("ascii") - - -__all__ = 
["CLAUDE_COMPUTER_SPECS", "ClaudeComputerTool"] diff --git a/hud/agents/claude/tools/hosted.py b/hud/agents/claude/tools/hosted.py index f9b19a593..050afedaa 100644 --- a/hud/agents/claude/tools/hosted.py +++ b/hud/agents/claude/tools/hosted.py @@ -112,11 +112,3 @@ def _validate_domain_filters( ) -> None: if allowed_domains and blocked_domains: raise ValueError("Use either allowed_domains or blocked_domains, not both.") - - -__all__ = [ - "ClaudeHostedTool", - "ClaudeToolSearchTool", - "ClaudeWebFetchTool", - "ClaudeWebSearchTool", -] diff --git a/hud/agents/claude/tools/memory.py b/hud/agents/claude/tools/memory.py index 53d8c42d5..373c4f3c7 100644 --- a/hud/agents/claude/tools/memory.py +++ b/hud/agents/claude/tools/memory.py @@ -2,15 +2,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, cast -from .base import CallTool, ClaudeTool, ClaudeToolSpec, call_tool +from .base import ClaudeTool, ClaudeToolSpec if TYPE_CHECKING: from anthropic.types.beta import BetaToolUnionParam - from hud.types import MCPToolResult - CLAUDE_MEMORY_SPEC = ClaudeToolSpec( api_type="memory_20250818", @@ -48,13 +46,3 @@ def to_params(self) -> BetaToolUnionParam: "name": self.name, }, ) - - async def execute( - self, - caller: CallTool, - arguments: dict[str, Any], - ) -> MCPToolResult: - return await call_tool(caller, self.env_tool_name, arguments) - - -__all__ = ["CLAUDE_MEMORY_SPEC", "ClaudeMemoryTool"] diff --git a/hud/agents/claude/tools/settings.py b/hud/agents/claude/tools/settings.py index 041c436c4..9a301d006 100644 --- a/hud/agents/claude/tools/settings.py +++ b/hud/agents/claude/tools/settings.py @@ -34,5 +34,3 @@ class ClaudeToolSettings(BaseSettings): claude_tool_settings = ClaudeToolSettings() - -__all__ = ["ClaudeToolSettings", "claude_tool_settings"] diff --git a/hud/agents/gateway.py b/hud/agents/gateway.py index 4d0973f8f..c78db083b 100644 --- a/hud/agents/gateway.py +++ b/hud/agents/gateway.py @@ -2,10 
+2,44 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any +import httpx +from openai import AsyncOpenAI +from pydantic import BaseModel, Field -def build_gateway_client(provider: str) -> Any: +from hud.settings import settings +from hud.types import AgentType + +if TYPE_CHECKING: + from typing import TypeAlias + + from anthropic import AsyncAnthropic, AsyncAnthropicBedrock + from google.genai import Client as GenaiClient + + from hud.agents.base import MCPAgent + + GatewayClient: TypeAlias = AsyncAnthropic | AsyncAnthropicBedrock | GenaiClient | AsyncOpenAI + + +class GatewayProviderInfo(BaseModel): + name: str | None = None + default_sdk_agent_type: str | None = None + + +class GatewayModelInfo(BaseModel): + id: str | None = None + name: str | None = None + model_name: str | None = None + sdk_agent_type: str | None = None + provider: GatewayProviderInfo = Field(default_factory=GatewayProviderInfo) + + +class GatewayModelsResponse(BaseModel): + models: list[GatewayModelInfo] + + +def build_gateway_client(provider: str) -> GatewayClient: """Build a client configured for HUD gateway routing. Args: @@ -14,10 +48,10 @@ def build_gateway_client(provider: str) -> Any: Returns: Configured async client for the provider. """ - from hud.settings import settings - provider = provider.lower() + # Anthropic and Gemini SDKs are optional extras; keep those imports on the + # provider branch so importing gateway utilities does not require both. if provider == "anthropic": from anthropic import AsyncAnthropic @@ -37,6 +71,74 @@ def build_gateway_client(provider: str) -> Any: ) # OpenAI-compatible (openai, azure, together, groq, fireworks, etc.) 
- from openai import AsyncOpenAI - return AsyncOpenAI(api_key=settings.api_key, base_url=settings.hud_gateway_url) + + +def _fetch_gateway_models() -> list[GatewayModelInfo]: + """Fetch available models from HUD API.""" + if not settings.api_key: + return [] + + try: + resp = httpx.get( + f"{settings.hud_api_url}/models/", + headers={"Authorization": f"Bearer {settings.api_key}"}, + timeout=10.0, + ) + resp.raise_for_status() + payload: object = resp.json() + if not isinstance(payload, dict) or "models" not in payload: + return [] + return GatewayModelsResponse.model_validate(payload).models + except Exception: + return [] + + +def create_agent(model: str, **kwargs: Any) -> MCPAgent[Any]: + """Create an agent routed through the HUD gateway. + + For direct API access with provider API keys, instantiate the agent classes directly. + """ + agent_type = next((candidate for candidate in AgentType if candidate.value == model), None) + if agent_type is not None: + model_id = model + provider_name = agent_type.gateway_provider + else: + for gateway_model in _fetch_gateway_models(): + if model in ( + gateway_model.id, + gateway_model.name, + gateway_model.model_name, + ): + agent_str = ( + gateway_model.sdk_agent_type or gateway_model.provider.default_sdk_agent_type + ) + if agent_str == "operator": + raise ValueError( + "Operator agent is no longer supported; use openai with a supported " + "OpenAI computer model." + ) + if agent_str == "gemini_cua": + raise ValueError( + "Gemini CUA agent is no longer supported; use gemini with a supported " + "Gemini computer-use model." 
+ ) + if not isinstance(agent_str, str): + raise ValueError(f"Model '{model}' has invalid agent type metadata") + + agent_type = AgentType(agent_str) + model_id = gateway_model.model_name or model + provider_name = gateway_model.provider.name or "openai" + break + else: + raise ValueError(f"Model '{model}' not found") + + client = build_gateway_client(provider_name) + kwargs.setdefault("model", model_id) + if agent_type == AgentType.OPENAI_COMPATIBLE: + kwargs.setdefault("openai_client", client) + else: + kwargs.setdefault("model_client", client) + kwargs.setdefault("validate_api_key", False) + + return agent_type.cls.create(**kwargs) diff --git a/hud/agents/gemini/agent.py b/hud/agents/gemini/agent.py index 9ae2d3c2a..d4f83480f 100644 --- a/hud/agents/gemini/agent.py +++ b/hud/agents/gemini/agent.py @@ -2,40 +2,30 @@ from __future__ import annotations +import base64 import logging -from typing import Any, ClassVar, cast +from functools import cached_property +from typing import Any, cast import mcp.types as types from google import genai from google.genai import types as genai_types +from hud.agents import gateway from hud.agents.base import MCPAgent -from hud.agents.tools import ( - EnvironmentCapability, - call_agent_tools, - capabilities_metadata_from_context, - discover_environment_capabilities, - select_hosted_tools, -) -from hud.agents.types import GeminiConfig, GeminiCreateParams +from hud.agents.types import GeminiConfig from hud.settings import settings -from hud.tools.computer import computer_settings -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult -from hud.utils.hud_console import HUDConsole +from hud.tools.types import Citation +from hud.types import AgentResponse from hud.utils.types import with_signature -from .tools import ( - GeminiComputerTool, - GeminiHostedTool, - GeminiTool, - gemini_tools, - normalize_gemini_computer_use_args, -) +from .settings import gemini_agent_settings +from .tools import 
GeminiAgentTools logger = logging.getLogger(__name__) -class GeminiAgent(MCPAgent): +class GeminiAgent(MCPAgent[genai_types.Content]): """ Gemini agent that uses MCP servers for tool execution. @@ -43,38 +33,25 @@ class GeminiAgent(MCPAgent): tools through MCP servers instead of direct implementation. """ - metadata: ClassVar[dict[str, Any] | None] = None - config_cls: ClassVar[type[BaseAgentConfig]] = GeminiConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for Gemini.""" - return AgentType.GEMINI - - @with_signature(GeminiCreateParams) + @with_signature(GeminiConfig) @classmethod - def create(cls, **kwargs: Any) -> GeminiAgent: # pyright: ignore[reportIncompatibleMethodOverride] - return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value] + def create(cls, **kwargs: object) -> GeminiAgent: # pyright: ignore[reportIncompatibleMethodOverride] + return cls(GeminiConfig.model_validate(kwargs)) - def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> None: - super().__init__(params, **kwargs) + def __init__(self, config: GeminiConfig | None = None) -> None: + config = config or GeminiConfig() + super().__init__(config) self.config: GeminiConfig model_client = self.config.model_client if model_client is None: if settings.api_key: - from hud.agents.gateway import build_gateway_client - - model_client = build_gateway_client("gemini") + model_client = gateway.build_gateway_client("gemini") elif settings.gemini_api_key: model_client = genai.Client(api_key=settings.gemini_api_key) if self.config.validate_api_key: try: - list( - model_client.models.list( - config=genai_types.ListModelsConfig(page_size=1) - ) - ) + next(iter(model_client.models.list()), None) except Exception as e: raise ValueError(f"Gemini API key is invalid: {e}") from e else: @@ -87,90 +64,76 @@ def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> N " access" ) - self.gemini_client: genai.Client = 
model_client + self.gemini_client: genai.Client = cast("genai.Client", model_client) self.temperature = self.config.temperature self.top_p = self.config.top_p self.top_k = self.config.top_k self.max_output_tokens = self.config.max_output_tokens self.thinking_level = self.config.thinking_level self.include_thoughts = self.config.include_thoughts - self.hud_console = HUDConsole(logger=logger) - # Track mapping from Gemini tool names to MCP tool names - self._gemini_to_mcp_tool_map: dict[str, str] = {} - self._computer_tool_name: str | None = None - self._gemini_native_tools: dict[str, GeminiTool] = {} - self._environment_capabilities: dict[str, EnvironmentCapability] = {} self.excluded_predefined_functions = list(self.config.excluded_predefined_functions) self.max_recent_turn_with_screenshots = ( - computer_settings.GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS - ) - self.gemini_tools: genai_types.ToolListUnion = [] - - def _on_tools_ready(self) -> None: - """Build Gemini-specific tool mappings after tools are discovered.""" - self._convert_tools_for_gemini() - - def _discover_environment_capabilities( - self, tools: list[types.Tool] - ) -> dict[str, EnvironmentCapability]: - return discover_environment_capabilities( - tools, - env_metadata=capabilities_metadata_from_context(self.ctx), - name_fallbacks=gemini_tools.name_fallbacks, + gemini_agent_settings.MAX_RECENT_TURN_WITH_SCREENSHOTS ) - async def get_system_messages(self) -> list[genai_types.Content]: - """No system messages for Gemini because applied in get_response""" - return [] - - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_types.Content]: - """Format messages for Gemini.""" - # Convert MCP content types to Gemini content types - gemini_parts: list[genai_types.Part] = [] - - for block in blocks: - if isinstance(block, types.TextContent): - gemini_parts.append(genai_types.Part(text=block.text)) - elif isinstance(block, types.ImageContent): - # Convert MCP ImageContent to Gemini 
format - # Need to decode base64 string to bytes - import base64 - - image_bytes = base64.b64decode(block.data) - gemini_parts.append( - genai_types.Part.from_bytes(data=image_bytes, mime_type=block.mimeType) - ) - else: - # For other types, try to handle but log a warning - self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning") + @cached_property + def tools(self) -> GeminiAgentTools: + return GeminiAgentTools( + excluded_predefined_functions=self.excluded_predefined_functions, + ) - return [genai_types.Content(role="user", parts=gemini_parts)] + async def format_messages( + self, messages: list[types.PromptMessage] + ) -> list[genai_types.Content]: + """Format MCP prompt messages for Gemini.""" + return [ + genai_types.Content( + role="model" if str(message.role) == "assistant" else str(message.role), + parts=[_format_content(message.content)], + ) + for message in messages + ] - async def get_response(self, messages: list[genai_types.Content]) -> InferenceResult: + async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse: """Get response from Gemini including any tool calls.""" - self._remove_old_screenshots(messages) - tools = self.gemini_tools + # Drop screenshots from older computer tool responses to keep context small. 
+ screenshot_turns: list[list[genai_types.FunctionResponse]] = [] + for content in reversed(messages): + if content.role != "user": + continue - citations_enabled = bool( - getattr(self.ctx, "scenario_enable_citations", False) if self.ctx else False - ) - if citations_enabled and not self._has_google_search_tool(): + turn_responses: list[genai_types.FunctionResponse] = [] + for part in content.parts or []: + function_response = part.function_response + if ( + function_response is not None + and function_response.parts + and function_response.name in self.tools.predefined_computer_functions + ): + turn_responses.append(function_response) + + if turn_responses: + screenshot_turns.append(turn_responses) + + for old_turn in screenshot_turns[self.max_recent_turn_with_screenshots :]: + for function_response in old_turn: + function_response.parts = None + + # Configure Gemini generation options. + tools = cast("genai_types.ToolListUnion", self.tools.params) + if self.enable_citations and not any(tool.google_search for tool in self.tools.params): tools = [*list(tools), genai_types.Tool(google_search=genai_types.GoogleSearch())] thinking_config = None if self.thinking_level is not None or self.include_thoughts: - thinking_level = ( - genai_types.ThinkingLevel(self.thinking_level.upper()) - if self.thinking_level is not None - else None - ) thinking_config = genai_types.ThinkingConfig( - thinking_level=thinking_level, + thinking_level=genai_types.ThinkingLevel(self.thinking_level.upper()) + if self.thinking_level is not None + else None, include_thoughts=self.include_thoughts, ) - # Build generate content config generate_config = genai_types.GenerateContentConfig( temperature=self.temperature, top_p=self.top_p, @@ -181,396 +144,120 @@ async def get_response(self, messages: list[genai_types.Content]) -> InferenceRe thinking_config=thinking_config, ) - # Use async API to avoid blocking the event loop - response = await self.gemini_client.aio.models.generate_content( + 
api_response = await self.gemini_client.aio.models.generate_content( model=self.config.model, contents=cast("Any", messages), config=generate_config, ) - - # Append assistant response (including any function_call) so that - # subsequent FunctionResponse messages correspond to a prior FunctionCall - if response.candidates and len(response.candidates) > 0 and response.candidates[0].content: - messages.append(response.candidates[0].content) - - # Process response - result = InferenceResult(content="", tool_calls=[], done=True) - collected_tool_calls: list[MCPToolCall] = [] - - if not response.candidates: - detail_parts = [] - for attr in ("prompt_feedback", "usage_metadata"): - value = getattr(response, attr, None) - if value is None: - continue - if hasattr(value, "model_dump_json"): - value_repr = value.model_dump_json() - elif hasattr(value, "model_dump"): - value_repr = repr(value.model_dump()) - else: - value_repr = repr(value) - detail_parts.append(f"{attr}={value_repr}") + if not api_response.candidates: + detail_parts: list[str] = [] + if api_response.prompt_feedback is not None: + detail_parts.append( + f"prompt_feedback={api_response.prompt_feedback.model_dump_json()}" + ) + if api_response.usage_metadata is not None: + detail_parts.append( + f"usage_metadata={api_response.usage_metadata.model_dump_json()}" + ) details = "; ".join(detail_parts) if detail_parts else "no response metadata" raise RuntimeError( f"Gemini response returned no candidates for model {self.config.model}. 
{details}" ) - candidate = response.candidates[0] - - # Extract text content and function calls - text_content = "" - thinking_content = "" - - if candidate.content and candidate.content.parts: - for part in candidate.content.parts: - if part.function_call: - tool_call = self._extract_tool_call(part) - if tool_call is not None: - collected_tool_calls.append(tool_call) - elif part.thought is True and part.text: - if thinking_content: - thinking_content += "\n" - thinking_content += part.text - elif part.text: - text_content += part.text - - # Assign collected tool calls and mark done status - if collected_tool_calls: - result.tool_calls = collected_tool_calls - result.done = False - - result.content = text_content - if thinking_content: - result.reasoning = thinking_content - - # Extract grounding citations from groundingMetadata - grounding_meta = getattr(candidate, "grounding_metadata", None) - if grounding_meta: - citations: list[dict[str, Any]] = [] - - # Build a lookup from chunk index → source info - chunks = getattr(grounding_meta, "grounding_chunks", None) or [] - chunk_sources: list[dict[str, Any]] = [] - for chunk in chunks: - web = getattr(chunk, "web", None) - if web: - chunk_sources.append( - { - "source": getattr(web, "uri", "") or "", - "title": getattr(web, "title", None), - } - ) - else: - chunk_sources.append({"source": "", "title": None}) - - # Walk groundingSupports for text-segment anchoring - supports = getattr(grounding_meta, "grounding_supports", None) or [] - seen_chunk_indices: set[int] = set() - for support in supports: - segment = getattr(support, "segment", None) - support_chunk_indices = getattr(support, "grounding_chunk_indices", None) or [] - segment_text = getattr(segment, "text", "") or "" if segment else "" - start_idx = getattr(segment, "start_index", None) if segment else None - end_idx = getattr(segment, "end_index", None) if segment else None - - for idx in support_chunk_indices: - seen_chunk_indices.add(idx) - source_info = 
chunk_sources[idx] if idx < len(chunk_sources) else {} - citations.append( - { - "type": "grounding", - "text": segment_text, - "source": source_info.get("source", ""), - "title": source_info.get("title"), - "start_index": start_idx, - "end_index": end_idx, - } - ) - - # Include any chunks not referenced by a support entry - for idx, src in enumerate(chunk_sources): - if idx not in seen_chunk_indices and src.get("source"): - citations.append( - { - "type": "grounding", - "text": "", - "source": src["source"], - "title": src.get("title"), - } - ) - - result.citations = citations + candidate = api_response.candidates[0] - return result - - def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None: - """Extract an MCPToolCall from a function call part. + # Append assistant response (including any function_call) so that + # subsequent FunctionResponse messages correspond to a prior FunctionCall + content = candidate.content + if content is not None: + messages.append(content) + + # Normalize text, thoughts, tool calls, and citations. + result = AgentResponse(content="", tool_calls=[], done=True) + text_parts: list[str] = [] + thought_parts: list[str] = [] + + parts = [] + if content is not None: + parts = content.parts or [] + for part in parts: + function_call = part.function_call + if function_call is not None: + result.tool_calls.append(self.tools.tool_call(function_call)) + result.done = False + continue - Subclasses can override to customize tool call extraction (e.g., normalizing - computer use calls to a different schema). 
- """ - if not part.function_call: - return None + if not part.text: + continue - func_name = part.function_call.name or "" - raw_args = dict(part.function_call.args) if part.function_call.args else {} - mcp_tool_name = self._gemini_to_mcp_tool_map.get(func_name) + if part.thought is True: + thought_parts.append(part.text) + else: + text_parts.append(part.text) - if mcp_tool_name: - return MCPToolCall( - name=mcp_tool_name, - arguments=raw_args, - ) + result.content = "".join(text_parts) + if thought_parts: + result.reasoning = "\n".join(thought_parts) - if self._computer_tool_name and func_name in gemini_tools.predefined_computer_functions: - return MCPToolCall( - name=self._computer_tool_name, - arguments=normalize_gemini_computer_use_args(func_name, raw_args), - gemini_name=func_name, # type: ignore[arg-type] - ) + grounding_meta = candidate.grounding_metadata + if grounding_meta is not None: + # TODO: Also normalize candidate.citation_metadata for URL-context citation spans. + result.citations = [ + citation.model_dump(exclude={"provider_data"}) + for citation in _grounding_citations(grounding_meta) + ] - if func_name in self._gemini_native_tools: - return MCPToolCall( - name=func_name, - arguments=raw_args, - ) + return result - return MCPToolCall( - name=func_name, - arguments=raw_args, - ) - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[genai_types.Content]: - """Format tool results into Gemini messages.""" - # Process each tool result - function_responses = [] - - for tool_call, result in zip(tool_calls, tool_results, strict=True): - # Get the Gemini function name from metadata - gemini_name = getattr(tool_call, "gemini_name", tool_call.name) - - # Convert MCP tool results to Gemini format - response_dict: dict[str, Any] = {} - is_computer_call = ( - self._computer_tool_name is not None and tool_call.name == self._computer_tool_name +def _format_content( + content: types.ContentBlock, +) 
-> genai_types.Part: + match content: + case types.TextContent(text=text): + return genai_types.Part(text=text) + case types.ImageContent(data=data, mimeType=mime_type): + return genai_types.Part.from_bytes( + data=base64.b64decode(data), + mime_type=mime_type or "image/png", ) - - if result.isError: - # Extract error message from content - error_msg = "Tool execution failed" - for content in result.content: - if isinstance(content, types.TextContent): - if content.text.startswith("__URL__:"): - continue - error_msg = content.text - break - response_dict["error"] = error_msg - if is_computer_call: - response_dict["url"] = self._extract_url(result) or "about:blank" - else: - # Process success content - response_dict["success"] = True - - screenshot_parts: list[genai_types.FunctionResponsePart] = [] - if is_computer_call: - url = self._extract_url(result) - for content in result.content: - if isinstance(content, types.ImageContent): - import base64 - - image_bytes = base64.b64decode(content.data) - screenshot_parts.append( - genai_types.FunctionResponsePart( - inline_data=genai_types.FunctionResponseBlob( - mime_type=content.mimeType or "image/png", - data=image_bytes, - ) - ) - ) - elif isinstance(content, types.TextContent) and content.text.startswith( - "__GEMINI_SAFETY_BLOCKED__:" - ): - response_dict.pop("success", None) - response_dict["blocked"] = True - response_dict["reason"] = content.text.replace( - "__GEMINI_SAFETY_BLOCKED__:", "", 1 - ) - - response_dict["url"] = url or "about:blank" - safety_decision = ( - tool_call.arguments.get("safety_decision") if tool_call.arguments else None + case _: + raise ValueError(f"Unknown content block type: {type(content)}") + + +def _grounding_citations( + grounding_meta: genai_types.GroundingMetadata, +) -> list[Citation]: + citations: list[Citation] = [] + chunk_sources: list[tuple[str, str | None]] = [] + for chunk in grounding_meta.grounding_chunks or []: + if chunk.web is None: + chunk_sources.append(("", None)) + 
else: + chunk_sources.append((chunk.web.uri or "", chunk.web.title)) + + seen_chunk_indices: set[int] = set() + for support in grounding_meta.grounding_supports or []: + segment = support.segment + segment_text = segment.text or "" if segment else "" + start_idx = segment.start_index if segment else None + end_idx = segment.end_index if segment else None + + for idx in support.grounding_chunk_indices or []: + seen_chunk_indices.add(idx) + source, title = chunk_sources[idx] if 0 <= idx < len(chunk_sources) else ("", None) + citations.append( + Citation( + type="grounding", + text=segment_text, + source=source, + title=title, + start_index=start_idx, + end_index=end_idx, ) - if safety_decision and not result.isError and not response_dict.get("blocked"): - response_dict["safety_acknowledgement"] = True - else: - # Add text content to response - for content in result.content: - if isinstance(content, types.TextContent): - response_dict["output"] = content.text - break - - # Create function response - function_response = genai_types.FunctionResponse( - name=gemini_name, - response=response_dict, - parts=screenshot_parts if screenshot_parts else None, - ) - function_responses.append(function_response) - - # Return as a user message containing all function responses - return [ - genai_types.Content( - role="user", - parts=[genai_types.Part(function_response=fr) for fr in function_responses], - ) - ] - - @staticmethod - def _extract_url(result: MCPToolResult) -> str | None: - for content in result.content: - if isinstance(content, types.TextContent) and content.text.startswith("__URL__:"): - return content.text.replace("__URL__:", "", 1) - return None - - async def call_tools( - self, tool_call: MCPToolCall | list[MCPToolCall] | None = None - ) -> list[MCPToolResult]: - """Route Gemini-owned native tool calls through provider translators.""" - return await call_agent_tools(self, self._gemini_native_tools, tool_call) - - def _map_role(self, role: str) -> str: - """Gemini 
uses 'model' instead of 'assistant' for non-user turns.""" - if role == "assistant": - return "model" - return role - - async def create_user_message(self, text: str) -> genai_types.Content: - """Create a user message in Gemini's format.""" - return genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) - - def _has_google_search_tool(self) -> bool: - """Check if google_search is already in the tool list.""" - return any(getattr(tool, "google_search", None) is not None for tool in self.gemini_tools) - - def _convert_tools_for_gemini(self) -> None: - """Convert MCP tools to Gemini tool format.""" - self._gemini_to_mcp_tool_map = {} - self._computer_tool_name = None - self._gemini_native_tools = {} - self.gemini_tools = [] - - categorized = self._categorized_tools - - capabilities = self._discover_environment_capabilities(self.get_available_tools()) - self._environment_capabilities = capabilities - provider_backing_tools: set[str] = set() - - for capability in capabilities.values(): - if capability.name not in gemini_tools.capabilities: - continue - for gemini_tool in gemini_tools.tools_for_capability(capability, self.model): - provider_backing_tools.add(gemini_tool.env_tool_name) - if isinstance(gemini_tool, GeminiComputerTool): - self._computer_tool_name = gemini_tool.name - self._gemini_native_tools[gemini_tool.name] = gemini_tool - gemini_tool.excluded_predefined_functions = ( - self._computer_use_excluded_function_names(gemini_tool.env_tool_name) - ) - self.gemini_tools.append(gemini_tool.to_params()) - continue - - self._gemini_native_tools[gemini_tool.name] = gemini_tool - self.gemini_tools.append(gemini_tool.to_params()) - - configured_hosted = select_hosted_tools( - self.config.hosted_tools, - tool_type=GeminiHostedTool, - model=self.model, - ) - self.gemini_tools.extend(tool.to_params() for tool in configured_hosted) - - # Process generic function tools - for tool in categorized.generic: - if tool.name in provider_backing_tools: - continue - 
gemini_tool = self._to_gemini_tool(tool) - if gemini_tool: - self._gemini_to_mcp_tool_map[tool.name] = tool.name - self.gemini_tools.append(gemini_tool) - - # Log actual tools being used - tool_names = sorted( - { - *self._gemini_to_mcp_tool_map.keys(), - *self._gemini_native_tools.keys(), - } - ) - self.console.info( - f"Agent initialized with {len(tool_names)} tools: {', '.join(tool_names)}" - ) - - def _computer_use_excluded_function_names(self, computer_tool_name: str) -> list[str]: - excluded = [ - *self.excluded_predefined_functions, - *self._colliding_predefined_function_names(computer_tool_name), - ] - return sorted(set(excluded)) - - def _colliding_predefined_function_names(self, computer_tool_name: str) -> list[str]: - """Exclude predefined computer actions shadowed by generic MCP tools.""" - generic_names = { - tool.name for tool in self._categorized_tools.generic if tool.name != computer_tool_name - } - return sorted(set(gemini_tools.predefined_computer_functions) & generic_names) - - def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None: - """Drop older Gemini Computer Use screenshots to keep context growth bounded.""" - if self._computer_tool_name is None: - return - - turn_with_screenshots_found = 0 - for content in reversed(messages): - if content.role != "user" or not content.parts: - continue - - has_screenshot = any( - part.function_response - and part.function_response.parts - and part.function_response.name in gemini_tools.predefined_computer_functions - for part in content.parts ) - if not has_screenshot: - continue - turn_with_screenshots_found += 1 - if turn_with_screenshots_found <= self.max_recent_turn_with_screenshots: - continue - - for part in content.parts: - if ( - part.function_response - and part.function_response.parts - and part.function_response.name in gemini_tools.predefined_computer_functions - ): - part.function_response.parts = None - - def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool 
| None: - """Convert a single MCP tool to Gemini function tool format. - - Args: - tool: MCP tool to convert - - Returns: - Gemini Tool with function declaration - """ - if tool.description is None or tool.inputSchema is None: - raise ValueError(f"MCP tool {tool.name} requires both a description and inputSchema.") - - function_decl = genai_types.FunctionDeclaration( - name=tool.name, - description=tool.description, - parameters_json_schema=tool.inputSchema, - ) - return genai_types.Tool(function_declarations=[function_decl]) + for idx, (source, title) in enumerate(chunk_sources): + if idx not in seen_chunk_indices and source: + citations.append(Citation(type="grounding", text="", source=source, title=title)) + return citations diff --git a/hud/agents/gemini/settings.py b/hud/agents/gemini/settings.py new file mode 100644 index 000000000..2a7c89b6e --- /dev/null +++ b/hud/agents/gemini/settings.py @@ -0,0 +1,21 @@ +"""Gemini agent settings.""" + +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class GeminiAgentSettings(BaseSettings): + """Gemini provider defaults owned by the agent.""" + + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="allow") + + MAX_RECENT_TURN_WITH_SCREENSHOTS: int = Field( + default=3, + description="Maximum number of recent turns to keep screenshots for in Gemini agent", + validation_alias="GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS", + ) + + +gemini_agent_settings = GeminiAgentSettings() diff --git a/hud/agents/gemini/tools/__init__.py b/hud/agents/gemini/tools/__init__.py index 33c31d9ea..ba9583915 100644 --- a/hud/agents/gemini/tools/__init__.py +++ b/hud/agents/gemini/tools/__init__.py @@ -2,30 +2,20 @@ from __future__ import annotations -from dataclasses import dataclass, field - -from hud.agents.tools import AgentToolRegistry - -from .base import GeminiTool -from .coding import ( - GEMINI_EDIT_SPEC, - GEMINI_SHELL_SPEC, 
- GEMINI_WRITE_SPEC, - GeminiEditTool, - GeminiShellTool, - GeminiWriteTool, -) +from typing import TYPE_CHECKING, ClassVar + +from google.genai import types as genai_types + +from hud.agents.tools import AgentTool, AgentTools +from hud.types import MCPToolCall + +from .base import GeminiFunctionTool +from .coding import GeminiEditTool, GeminiShellTool, GeminiWriteTool from .computer import ( - GEMINI_COMPUTER_SPEC, PREDEFINED_COMPUTER_USE_FUNCTIONS, GeminiComputerTool, - normalize_gemini_computer_use_args, ) from .filesystem import ( - GEMINI_GLOB_SPEC, - GEMINI_LIST_SPEC, - GEMINI_READ_SPEC, - GEMINI_SEARCH_SPEC, GeminiGlobTool, GeminiListTool, GeminiReadTool, @@ -37,14 +27,20 @@ GeminiHostedTool, GeminiUrlContextTool, ) -from .memory import GEMINI_MEMORY_SPEC, GeminiMemoryTool +from .memory import GeminiMemoryTool + +if TYPE_CHECKING: + from collections.abc import Mapping + + import mcp.types as types + + from hud.agents.tools import ToolMetadata -@dataclass(frozen=True) -class GeminiToolRegistry(AgentToolRegistry[GeminiTool]): - """Registry for Gemini harness tools.""" +class GeminiAgentTools(AgentTools[AgentTool[genai_types.Tool], genai_types.Tool]): + """Prepared Gemini tool state for a run.""" - tool_classes: tuple[type[GeminiTool], ...] 
= ( + native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = ( GeminiComputerTool, GeminiShellTool, GeminiEditTool, @@ -55,52 +51,79 @@ class GeminiToolRegistry(AgentToolRegistry[GeminiTool]): GeminiListTool, GeminiMemoryTool, ) - name_fallbacks: dict[str, tuple[str, ...]] = field( - default_factory=lambda: { - "computer": ("computer", "gemini_computer", "computer_gemini"), - "shell": ("bash",), - "editor": ("edit",), - "filesystem": ("read", "grep", "glob", "list"), - "memory": ("memory",), - } - ) + function_tool_class = GeminiFunctionTool + name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = { + "computer": ("computer", "gemini_computer", "computer_gemini"), + "shell": ("bash",), + "editor": ("edit",), + "filesystem": ("read", "grep", "glob", "list"), + "memory": ("memory",), + } + + def __init__(self, *, excluded_predefined_functions: list[str] | None = None) -> None: + super().__init__() + self.excluded_predefined_functions = list(excluded_predefined_functions or []) @property - def api_types(self) -> frozenset[str]: - return frozenset(cls.name for cls in self.tool_classes) + def computer_tool_name(self) -> str | None: + return "computer_use" if "computer_use" in self else None @property def predefined_computer_functions(self) -> frozenset[str]: return frozenset(PREDEFINED_COMPUTER_USE_FUNCTIONS) + def tool_call(self, function_call: genai_types.FunctionCall) -> MCPToolCall: + name = function_call.name or "" + arguments = dict(function_call.args) if function_call.args else {} + + if mcp_tool_name := self.name_map.get(name): + return MCPToolCall(name=mcp_tool_name, arguments=arguments) + + if self.computer_tool_name and name in self.predefined_computer_functions: + computer_tool = self.get(self.computer_tool_name) + if isinstance(computer_tool, GeminiComputerTool): + return computer_tool.tool_call(name, arguments) + + return MCPToolCall(name=name, arguments=arguments) + + def select_tools( + self, + tools: list[types.Tool], + model: str, + *, 
+ tool_metadata: ToolMetadata | None = None, + excluded_predefined_functions: list[str] | None = None, + ) -> tuple[list[AgentTool[genai_types.Tool]], list[types.Tool]]: + provider_tools, user_tools = super().select_tools( + tools, + model, + tool_metadata=tool_metadata, + ) + user_tool_names = {tool.name for tool in user_tools} + configured_exclusions = ( + excluded_predefined_functions + if excluded_predefined_functions is not None + else self.excluded_predefined_functions + ) + colliding_exclusions = sorted(self.predefined_computer_functions & user_tool_names) + exclusions = sorted({*configured_exclusions, *colliding_exclusions}) + if not exclusions: + return provider_tools, user_tools + return ( + [ + tool.with_excluded_predefined_functions(exclusions) + if isinstance(tool, GeminiComputerTool) + else tool + for tool in provider_tools + ], + user_tools, + ) -gemini_tools = GeminiToolRegistry() __all__ = [ - "GEMINI_COMPUTER_SPEC", - "GEMINI_EDIT_SPEC", - "GEMINI_GLOB_SPEC", - "GEMINI_LIST_SPEC", - "GEMINI_MEMORY_SPEC", - "GEMINI_READ_SPEC", - "GEMINI_SEARCH_SPEC", - "GEMINI_SHELL_SPEC", - "GEMINI_WRITE_SPEC", + "GeminiAgentTools", "GeminiCodeExecutionTool", - "GeminiComputerTool", - "GeminiEditTool", - "GeminiGlobTool", "GeminiGoogleSearchTool", "GeminiHostedTool", - "GeminiListTool", - "GeminiMemoryTool", - "GeminiReadTool", - "GeminiSearchTool", - "GeminiShellTool", - "GeminiTool", - "GeminiToolRegistry", "GeminiUrlContextTool", - "GeminiWriteTool", - "gemini_tools", - "normalize_gemini_computer_use_args", ] diff --git a/hud/agents/gemini/tools/base.py b/hud/agents/gemini/tools/base.py index 6d8612ca8..a52081d4a 100644 --- a/hud/agents/gemini/tools/base.py +++ b/hud/agents/gemini/tools/base.py @@ -2,20 +2,20 @@ from __future__ import annotations -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar +import mcp.types as types from google.genai import types as genai_types -from hud.agents.tools import AgentTool, AgentToolSpec, 
CallTool, call_tool - -GeminiToolSpec = AgentToolSpec +from hud.agents.tools import AgentTool, AgentToolSpec +if TYPE_CHECKING: + from hud.types import MCPToolCall, MCPToolResult -class GeminiTool(AgentTool[Any]): - """Gemini provider tool backed by an environment tool.""" +GeminiToolSpec = AgentToolSpec -class GeminiFunctionTool(GeminiTool): +class GeminiTool(AgentTool[genai_types.Tool]): """Gemini function declaration backed by an environment tool.""" description: ClassVar[str] @@ -25,12 +25,77 @@ def to_params(self) -> genai_types.Tool: return genai_types.Tool( function_declarations=[ genai_types.FunctionDeclaration( - name=self.name, + name=self.provider_name, description=self.description, parameters_json_schema=self.parameters, ) ] ) + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> genai_types.Content: + text = next( + (content.text for content in result.content if isinstance(content, types.TextContent)), + None, + ) + response: dict[str, Any] = ( + {"error": text or "Tool execution failed"} if result.isError else {"success": True} + ) + if text is not None and not result.isError: + response["output"] = text + return genai_types.Content( + role="user", + parts=[ + genai_types.Part( + function_response=genai_types.FunctionResponse( + name=call.provider_name or call.name, + response=response, + ) + ) + ], + ) + + +class GeminiFunctionTool(GeminiTool): + """Regular environment tool exposed as a Gemini function declaration.""" + + name = "function" + capability = "function" + + def __init__( + self, + *, + env_tool_name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + super().__init__( + env_tool_name=env_tool_name, + spec=GeminiToolSpec(api_type="function", api_name=env_tool_name), + ) + self._description = description + self._parameters = parameters + + @classmethod + def from_tool(cls, tool: types.Tool) -> GeminiFunctionTool: + if tool.description is None: + raise ValueError(f"MCP tool {tool.name} requires a 
description.") + return cls( + env_tool_name=tool.name, + description=tool.description, + parameters=tool.inputSchema, + ) + + @property + def provider_name(self) -> str: + return self.env_tool_name -__all__ = ["CallTool", "GeminiFunctionTool", "GeminiTool", "GeminiToolSpec", "call_tool"] + def to_params(self) -> genai_types.Tool: + return genai_types.Tool( + function_declarations=[ + genai_types.FunctionDeclaration( + name=self.provider_name, + description=self._description, + parameters_json_schema=self._parameters, + ) + ] + ) diff --git a/hud/agents/gemini/tools/coding.py b/hud/agents/gemini/tools/coding.py index 6817764f3..50a2eec1b 100644 --- a/hud/agents/gemini/tools/coding.py +++ b/hud/agents/gemini/tools/coding.py @@ -6,16 +6,17 @@ from typing import TYPE_CHECKING, Any, ClassVar if TYPE_CHECKING: + from hud.agents.tools.base import CallTool from hud.types import MCPToolResult -from .base import CallTool, GeminiFunctionTool, GeminiToolSpec, call_tool +from .base import GeminiTool, GeminiToolSpec GEMINI_SHELL_SPEC = GeminiToolSpec(api_type="run_shell_command", api_name="run_shell_command") GEMINI_EDIT_SPEC = GeminiToolSpec(api_type="replace", api_name="replace") GEMINI_WRITE_SPEC = GeminiToolSpec(api_type="write_file", api_name="write_file") -class GeminiShellTool(GeminiFunctionTool): +class GeminiShellTool(GeminiTool): """Translate Gemini CLI shell calls into the generic bash env primitive.""" name = "run_shell_command" @@ -39,17 +40,17 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_SHELL_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: command = arguments.get("command") if not isinstance(command, str) or not command: raise ValueError("command is required") dir_path = arguments.get("dir_path") if isinstance(dir_path, str) and dir_path: command = f"cd {shlex.quote(dir_path)} && {command}" - 
return await call_tool(caller, self.env_tool_name, {"command": command}) + return await super().execute(call_tool, {"command": command}) -class GeminiEditTool(GeminiFunctionTool): +class GeminiEditTool(GeminiTool): """Translate Gemini CLI replace calls into the generic edit env primitive.""" name = "replace" @@ -74,19 +75,21 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_EDIT_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: file_path = _required_str(arguments, "file_path") old_string = arguments.get("old_string") new_string = arguments.get("new_string") if old_string == "": - return await call_tool( - caller, - self.env_tool_name, - {"command": "create", "path": file_path, "file_text": new_string or ""}, + return await super().execute( + call_tool, + { + "command": "create", + "path": file_path, + "file_text": new_string or "", + }, ) - return await call_tool( - caller, - self.env_tool_name, + return await super().execute( + call_tool, { "command": "replace", "path": file_path, @@ -96,7 +99,7 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR ) -class GeminiWriteTool(GeminiFunctionTool): +class GeminiWriteTool(GeminiTool): """Translate Gemini CLI write_file calls into the generic edit env primitive.""" name = "write_file" @@ -116,10 +119,9 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_WRITE_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - return await call_tool( - caller, - self.env_tool_name, + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + return await super().execute( + call_tool, { "command": "write", "path": _required_str(arguments, "file_path"), @@ -133,13 +135,3 @@ def _required_str(arguments: dict[str, Any], key: str) -> str: if not 
isinstance(value, str) or not value: raise ValueError(f"{key} is required") return value - - -__all__ = [ - "GEMINI_EDIT_SPEC", - "GEMINI_SHELL_SPEC", - "GEMINI_WRITE_SPEC", - "GeminiEditTool", - "GeminiShellTool", - "GeminiWriteTool", -] diff --git a/hud/agents/gemini/tools/computer.py b/hud/agents/gemini/tools/computer.py index b52680e49..cf8684c68 100644 --- a/hud/agents/gemini/tools/computer.py +++ b/hud/agents/gemini/tools/computer.py @@ -2,18 +2,21 @@ from __future__ import annotations +import base64 import platform -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from google.genai import types as genai_types from mcp.types import ImageContent, TextContent -from hud.types import MCPToolResult +from hud.agents.tools import AgentTool +from hud.agents.tools.computer import computer_error_result, execute_computer_calls +from hud.types import MCPToolCall, MCPToolResult -from .base import CallTool, GeminiTool, GeminiToolSpec, call_tool +from .base import GeminiToolSpec if TYPE_CHECKING: - from hud.agents.tools import EnvironmentCapability + from hud.agents.tools.base import CallTool SUPPORTED_GEMINI_COMPUTER_USE_MODELS = ( "gemini-2.5-computer-use-preview-10-2025", @@ -22,6 +25,7 @@ GEMINI_COORDINATE_SPACE = 1000 GEMINI_DRAG_INSET = 25 +IS_MAC = platform.system().lower() == "darwin" PREDEFINED_COMPUTER_USE_FUNCTIONS = ( "open_web_browser", @@ -38,6 +42,8 @@ "key_combination", "drag_and_drop", ) +GEMINI_URL_PREFIX = "__URL__:" +GEMINI_SAFETY_BLOCKED_PREFIX = "__GEMINI_SAFETY_BLOCKED__:" GEMINI_COMPUTER_SPEC = GeminiToolSpec( api_type="computer_use", @@ -46,51 +52,7 @@ ) -def normalize_gemini_computer_use_args(action: str, raw_args: dict[str, Any]) -> dict[str, Any]: - """Normalize Gemini Computer Use function-call args to agent-tool args.""" - normalized_args: dict[str, Any] = {"action": action} - - coord = raw_args.get("coordinate") or raw_args.get("coordinates") - if isinstance(coord, list | tuple) and len(coord) >= 2: - try: - 
normalized_args["x"] = int(coord[0]) - normalized_args["y"] = int(coord[1]) - except (TypeError, ValueError): - pass - - dest = ( - raw_args.get("destination") - or raw_args.get("destination_coordinate") - or raw_args.get("destinationCoordinate") - ) - if isinstance(dest, list | tuple) and len(dest) >= 2: - try: - normalized_args["destination_x"] = int(dest[0]) - normalized_args["destination_y"] = int(dest[1]) - except (TypeError, ValueError): - pass - - for key in ( - "text", - "press_enter", - "clear_before_typing", - "safety_decision", - "direction", - "magnitude", - "url", - "keys", - "x", - "y", - "destination_x", - "destination_y", - ): - if key in raw_args: - normalized_args[key] = raw_args[key] - - return normalized_args - - -class GeminiComputerTool(GeminiTool): +class GeminiComputerTool(AgentTool[genai_types.Tool]): """Translate Gemini Computer Use calls into generic environment computer calls.""" name = "computer_use" @@ -102,19 +64,24 @@ def default_spec(cls, model: str) -> GeminiToolSpec | None: return GEMINI_COMPUTER_SPEC return None - @classmethod - def from_capability( - cls, - capability: EnvironmentCapability, + def __init__( + self, + *, + env_tool_name: str, spec: GeminiToolSpec, - model: str, - ) -> GeminiComputerTool: - del model - return cls(env_tool_name=capability.tool_name, spec=spec) - - def __init__(self, *, env_tool_name: str, spec: GeminiToolSpec) -> None: + excluded_predefined_functions: list[str] | None = None, + ) -> None: super().__init__(env_tool_name=env_tool_name, spec=spec) - self.excluded_predefined_functions: list[str] = [] + self.excluded_predefined_functions = excluded_predefined_functions or [] + + def with_excluded_predefined_functions( + self, excluded_predefined_functions: list[str] + ) -> GeminiComputerTool: + return GeminiComputerTool( + env_tool_name=self.env_tool_name, + spec=self.spec, + excluded_predefined_functions=excluded_predefined_functions, + ) def to_params(self) -> genai_types.Tool: return 
genai_types.Tool( @@ -124,31 +91,100 @@ def to_params(self) -> genai_types.Tool: ) ) - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + def tool_call(self, function_name: str, raw_args: dict[str, Any]) -> MCPToolCall: + return MCPToolCall( + name=self.name, + arguments={"action": function_name, **raw_args}, + provider_name=function_name, + ) + + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> genai_types.Content: + text = next( + ( + content.text + for content in result.content + if isinstance(content, TextContent) + and not content.text.startswith(GEMINI_URL_PREFIX) + ), + None, + ) + response: dict[str, Any] = ( + {"error": text or "Tool execution failed"} if result.isError else {"success": True} + ) + if text is not None and not result.isError: + response["output"] = text + + url = None + parts: list[genai_types.FunctionResponsePart] = [] + for content in result.content: + match content: + case ImageContent(data=data, mimeType=mime_type): + parts.append( + genai_types.FunctionResponsePart( + inline_data=genai_types.FunctionResponseBlob( + mime_type=mime_type or "image/png", + data=base64.b64decode(data), + ) + ) + ) + case TextContent(text=text) if text.startswith(GEMINI_URL_PREFIX): + url = text.removeprefix(GEMINI_URL_PREFIX) + case TextContent(text=text) if text.startswith(GEMINI_SAFETY_BLOCKED_PREFIX): + response.pop("success", None) + response["blocked"] = True + response["reason"] = text.removeprefix(GEMINI_SAFETY_BLOCKED_PREFIX) + case _: + continue + + response["url"] = url or "about:blank" + safety_decision = call.arguments.get("safety_decision") if call.arguments else None + if safety_decision and not result.isError and not response.get("blocked"): + response["safety_acknowledgement"] = True + + return genai_types.Content( + role="user", + parts=[ + genai_types.Part( + function_response=genai_types.FunctionResponse( + name=call.provider_name or call.name, + response=response, + parts=parts or 
None, + ) + ) + ], + ) + + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: action = arguments.get("action") if not isinstance(action, str): - return _error_result("action is required") - if _requires_confirmation(arguments.get("safety_decision")): - return _blocked_result( - "Gemini Computer Use action requires user confirmation before execution." + return computer_error_result("action is required") + safety_decision = arguments.get("safety_decision") + if ( + isinstance(safety_decision, dict) + and cast("dict[str, Any]", safety_decision).get("decision") == "require_confirmation" + ): + return MCPToolResult( + content=[ + TextContent( + type="text", + text=( + f"{GEMINI_SAFETY_BLOCKED_PREFIX}" + "Gemini Computer Use action requires user confirmation before " + "execution." + ), + ) + ], + isError=False, ) - result = MCPToolResult(content=[], isError=False) - for call in self._env_calls(action, arguments): - result = await call_tool(caller, self.env_tool_name, call) - if result.isError: - return result - - if action != "open_web_browser" and not _has_image(result): - screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"}) - if not screenshot.isError and screenshot.content: - result = MCPToolResult( - content=[*result.content, *screenshot.content], - isError=result.isError, - ) - return result + return await execute_computer_calls( + call_tool, + env_tool_name=self.env_tool_name, + calls=self._computer_actions(action, arguments), + ensure_screenshot=action != "open_web_browser", + ) - def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: + def _computer_actions(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: if action == "open_web_browser": return [{"action": "screenshot"}] if action == "click_at": @@ -165,7 +201,12 @@ def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, A ] ) if arguments.get("clear_before_typing", 
True): - calls.extend(_clear_text_calls()) + calls.extend( + [ + {"action": "press", "keys": ["cmd", "a"] if IS_MAC else ["ctrl", "a"]}, + {"action": "press", "keys": ["backspace" if IS_MAC else "delete"]}, + ] + ) calls.append( { "action": "write", @@ -175,133 +216,77 @@ def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, A ) return calls if action in ("scroll_document", "scroll_at"): - call = _scroll_call(arguments) + direction = arguments.get("direction") + magnitude = arguments.get("magnitude") or 800 + if direction == "down": + call = {"action": "scroll", "scroll_x": None, "scroll_y": magnitude} + elif direction == "up": + call = {"action": "scroll", "scroll_x": None, "scroll_y": -magnitude} + elif direction == "right": + call = {"action": "scroll", "scroll_x": magnitude, "scroll_y": None} + elif direction == "left": + call = {"action": "scroll", "scroll_x": -magnitude, "scroll_y": None} + else: + raise ValueError("direction must be one of up, down, left, right") if action == "scroll_at": call.update({"x": arguments.get("x"), "y": arguments.get("y")}) return [call] if action == "wait_5_seconds": return [{"action": "wait", "time": 5000}] if action == "go_back": - return [{"action": "press", "keys": ["cmd", "["] if _is_mac() else ["alt", "left"]}] + return [{"action": "press", "keys": ["cmd", "["] if IS_MAC else ["alt", "left"]}] if action == "go_forward": - return [{"action": "press", "keys": ["cmd", "]"] if _is_mac() else ["alt", "right"]}] + return [{"action": "press", "keys": ["cmd", "]"] if IS_MAC else ["alt", "right"]}] if action == "search": target = arguments.get("url") or "https://www.google.com" - return [*_address_bar_calls(), {"action": "write", "text": target, "enter_after": True}] + return [ + {"action": "press", "keys": ["cmd", "l"] if IS_MAC else ["ctrl", "l"]}, + {"action": "write", "text": target, "enter_after": True}, + ] if action == "navigate": return [ - *_address_bar_calls(), + {"action": "press", "keys": ["cmd", 
"l"] if IS_MAC else ["ctrl", "l"]}, {"action": "write", "text": arguments.get("url"), "enter_after": True}, ] if action == "key_combination": - return [{"action": "press", "keys": _normalize_key_combination(arguments.get("keys"))}] + keys = arguments.get("keys") + if not isinstance(keys, str): + raise ValueError("keys must be a '+'-separated string") + aliases = { + "control": "ctrl", + "cmd": "cmd", + "command": "cmd", + "meta": "cmd" if IS_MAC else "ctrl", + "return": "enter", + } + normalized_keys = [ + aliases.get(key, key) for part in keys.split("+") if (key := part.strip().lower()) + ] + return [{"action": "press", "keys": normalized_keys}] if action == "drag_and_drop": + max_drag_coordinate = max( + GEMINI_COORDINATE_SPACE - GEMINI_DRAG_INSET, + GEMINI_DRAG_INSET, + ) + + def drag_coordinate(value: Any) -> Any: + if not isinstance(value, int | float) or not 0 <= value <= GEMINI_COORDINATE_SPACE: + return value + return min(max(int(value), GEMINI_DRAG_INSET), max_drag_coordinate) + return [ { "action": "drag", "path": [ { - "x": _inset_drag_coordinate(arguments.get("x")), - "y": _inset_drag_coordinate(arguments.get("y")), + "x": drag_coordinate(arguments.get("x")), + "y": drag_coordinate(arguments.get("y")), }, { - "x": _inset_drag_coordinate(arguments.get("destination_x")), - "y": _inset_drag_coordinate(arguments.get("destination_y")), + "x": drag_coordinate(arguments.get("destination_x")), + "y": drag_coordinate(arguments.get("destination_y")), }, ], } ] raise ValueError(f"Unknown Gemini computer action: {action}") - - -def _scroll_call(arguments: dict[str, Any]) -> dict[str, Any]: - direction = arguments.get("direction") - magnitude = arguments.get("magnitude") or 800 - if direction == "down": - return {"action": "scroll", "scroll_x": None, "scroll_y": magnitude} - if direction == "up": - return {"action": "scroll", "scroll_x": None, "scroll_y": -magnitude} - if direction == "right": - return {"action": "scroll", "scroll_x": magnitude, "scroll_y": None} - 
if direction == "left": - return {"action": "scroll", "scroll_x": -magnitude, "scroll_y": None} - raise ValueError("direction must be one of up, down, left, right") - - -def _inset_drag_coordinate(value: Any) -> Any: - """Keep Gemini normalized drag endpoints away from display edges.""" - if not isinstance(value, int | float) or not 0 <= value <= GEMINI_COORDINATE_SPACE: - return value - max_value = max(GEMINI_COORDINATE_SPACE - GEMINI_DRAG_INSET, GEMINI_DRAG_INSET) - return min(max(int(value), GEMINI_DRAG_INSET), max_value) - - -def _clear_text_calls() -> list[dict[str, Any]]: - is_mac = _is_mac() - return [ - {"action": "press", "keys": ["cmd", "a"] if is_mac else ["ctrl", "a"]}, - {"action": "press", "keys": ["backspace" if is_mac else "delete"]}, - ] - - -def _normalize_key_combination(keys: Any) -> list[str] | Any: - if isinstance(keys, str): - return [_normalize_key(key) for key in keys.split("+") if key.strip()] - if isinstance(keys, list): - return [_normalize_key(key) if isinstance(key, str) else key for key in keys] - return keys - - -def _normalize_key(key: str) -> str: - normalized = key.strip().lower() - aliases = { - "control": "ctrl", - "cmd": "cmd", - "command": "cmd", - "meta": "cmd" if _is_mac() else "ctrl", - "return": "enter", - } - return aliases.get(normalized, normalized) - - -def _requires_confirmation(safety_decision: Any) -> bool: - if not isinstance(safety_decision, dict): - return False - return safety_decision.get("decision") == "require_confirmation" - - -def _address_bar_calls() -> list[dict[str, Any]]: - return [{"action": "press", "keys": ["cmd", "l"] if _is_mac() else ["ctrl", "l"]}] - - -def _is_mac() -> bool: - return platform.system().lower() == "darwin" - - -def _has_image(result: MCPToolResult) -> bool: - return any(isinstance(block, ImageContent) for block in result.content) - - -def _error_result(message: str) -> MCPToolResult: - return MCPToolResult( - content=[TextContent(type="text", text=message)], - isError=True, - ) - 
- -def _blocked_result(message: str) -> MCPToolResult: - return MCPToolResult( - content=[TextContent(type="text", text=f"__GEMINI_SAFETY_BLOCKED__:{message}")], - isError=False, - ) - - -__all__ = [ - "GEMINI_COMPUTER_SPEC", - "GEMINI_COORDINATE_SPACE", - "GEMINI_DRAG_INSET", - "PREDEFINED_COMPUTER_USE_FUNCTIONS", - "SUPPORTED_GEMINI_COMPUTER_USE_MODELS", - "GeminiComputerTool", - "normalize_gemini_computer_use_args", -] diff --git a/hud/agents/gemini/tools/filesystem.py b/hud/agents/gemini/tools/filesystem.py index dc4750ee8..8ba89bd39 100644 --- a/hud/agents/gemini/tools/filesystem.py +++ b/hud/agents/gemini/tools/filesystem.py @@ -5,11 +5,12 @@ from typing import TYPE_CHECKING, Any, ClassVar if TYPE_CHECKING: + from hud.agents.tools.base import CallTool from hud.types import MCPToolResult from hud.agents.tools import GroupedCapabilityMixin -from .base import CallTool, GeminiFunctionTool, GeminiToolSpec, call_tool +from .base import GeminiTool, GeminiToolSpec GEMINI_READ_SPEC = GeminiToolSpec(api_type="read_file", api_name="read_file") GEMINI_SEARCH_SPEC = GeminiToolSpec(api_type="grep_search", api_name="grep_search") @@ -17,7 +18,7 @@ GEMINI_LIST_SPEC = GeminiToolSpec(api_type="list_directory", api_name="list_directory") -class GeminiFilesystemTool(GroupedCapabilityMixin, GeminiFunctionTool): +class GeminiFilesystemTool(GroupedCapabilityMixin, GeminiTool): """Gemini function tool backed by one filesystem environment primitive.""" capability = "filesystem" @@ -45,16 +46,15 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_READ_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: start = arguments.get("start_line") end = arguments.get("end_line") offset = int(start) - 1 if isinstance(start, int) and start > 0 else None limit = None if offset is not None and isinstance(start, int) and isinstance(end, 
int) and end >= start: limit = end - start + 1 - return await call_tool( - caller, - self.env_tool_name, + return await super().execute( + call_tool, { "filePath": _required_str(arguments, "file_path"), "offset": offset, @@ -84,10 +84,9 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_SEARCH_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - return await call_tool( - caller, - self.env_tool_name, + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + return await super().execute( + call_tool, { "pattern": _required_str(arguments, "pattern"), "path": arguments.get("dir_path"), @@ -120,10 +119,9 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_GLOB_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - return await call_tool( - caller, - self.env_tool_name, + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + return await super().execute( + call_tool, { "pattern": _required_str(arguments, "pattern"), "path": arguments.get("dir_path"), @@ -156,11 +154,13 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_LIST_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - return await call_tool( - caller, - self.env_tool_name, - {"path": _required_str(arguments, "dir_path"), "ignore": arguments.get("ignore")}, + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + return await super().execute( + call_tool, + { + "path": _required_str(arguments, "dir_path"), + "ignore": arguments.get("ignore"), + }, ) @@ -169,16 +169,3 @@ def _required_str(arguments: dict[str, Any], key: str) -> str: if not isinstance(value, str) or not value: raise ValueError(f"{key} is required") return value - - -__all__ = [ - "GEMINI_GLOB_SPEC", - "GEMINI_LIST_SPEC", - 
"GEMINI_READ_SPEC", - "GEMINI_SEARCH_SPEC", - "GeminiFilesystemTool", - "GeminiGlobTool", - "GeminiListTool", - "GeminiReadTool", - "GeminiSearchTool", -] diff --git a/hud/agents/gemini/tools/hosted.py b/hud/agents/gemini/tools/hosted.py index 25a993a7d..138e1d4de 100644 --- a/hud/agents/gemini/tools/hosted.py +++ b/hud/agents/gemini/tools/hosted.py @@ -40,11 +40,3 @@ class GeminiCodeExecutionTool(GeminiHostedTool): def to_params(self) -> genai_types.Tool: return genai_types.Tool(code_execution=genai_types.ToolCodeExecution()) - - -__all__ = [ - "GeminiCodeExecutionTool", - "GeminiGoogleSearchTool", - "GeminiHostedTool", - "GeminiUrlContextTool", -] diff --git a/hud/agents/gemini/tools/memory.py b/hud/agents/gemini/tools/memory.py index 8aeb14e50..8d91dc2fb 100644 --- a/hud/agents/gemini/tools/memory.py +++ b/hud/agents/gemini/tools/memory.py @@ -6,14 +6,15 @@ from typing import TYPE_CHECKING, Any, ClassVar if TYPE_CHECKING: + from hud.agents.tools.base import CallTool from hud.types import MCPToolResult -from .base import CallTool, GeminiFunctionTool, GeminiToolSpec, call_tool +from .base import GeminiTool, GeminiToolSpec GEMINI_MEMORY_SPEC = GeminiToolSpec(api_type="save_memory", api_name="save_memory") -class GeminiMemoryTool(GeminiFunctionTool): +class GeminiMemoryTool(GeminiTool): """Translate Gemini save_memory calls into the file-backed memory env primitive.""" name = "save_memory" @@ -32,21 +33,17 @@ def default_spec(cls, model: str) -> GeminiToolSpec: del model return GEMINI_MEMORY_SPEC - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: fact = arguments.get("fact") if not isinstance(fact, str) or not fact.strip(): raise ValueError("fact is required") text = fact.strip() digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:12] - return await call_tool( - caller, - self.env_tool_name, + return await super().execute( + 
call_tool, { "command": "create", "path": f"/memories/gemini-{digest}.md", "file_text": f"{text}\n", }, ) - - -__all__ = ["GEMINI_MEMORY_SPEC", "GeminiMemoryTool"] diff --git a/hud/agents/misc/__init__.py b/hud/agents/misc/__init__.py index 522faac53..8a048c64d 100644 --- a/hud/agents/misc/__init__.py +++ b/hud/agents/misc/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations -from .response_agent import ResponseAgent +from .response_automation import auto_respond -__all__ = ["ResponseAgent"] +__all__ = ["auto_respond"] diff --git a/hud/agents/misc/response_agent.py b/hud/agents/misc/response_agent.py deleted file mode 100644 index 52f9bfde4..000000000 --- a/hud/agents/misc/response_agent.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Literal - -from openai import AsyncOpenAI -from openai.types.responses import ResponseOutputText - -from hud.settings import settings -from hud.telemetry import instrument - -logger = logging.getLogger(__name__) - -ResponseType = Literal["STOP", "CONTINUE"] - -DEFAULT_SYSTEM_PROMPT = """\ -You are an assistant that helps determine the appropriate response to an agent's message. - -You will receive messages from an agent that is performing tasks for a user. -Your job is to analyze these messages and respond with one of the following: - -- STOP: If the agent indicates it has successfully completed a task or is stuck, - struggling or says it cannot complete the task, even if phrased as a question - like "I have entered the right values into this form. Would you like me to do - anything else?" or "Here is the website. Is there any other information you - need?" or if the agent has strongly determined it wants to stop the task like - "The task is infeasible. Can I help you with something else?" - -- CONTINUE: If the agent is asking for clarification before proceeding with a task - like "I'm about to clear cookies from this website. Would you like me to proceed?" 
- or "I've entered the right values into this form. Would you like me to continue - with the rest of the task?" - -Respond ONLY with one of these two options.""" - - -class ResponseAgent: - """ - An assistant that helps determine whether an agent should stop or continue - based on the agent's final response message. - """ - - def __init__( - self, - model: str = "gpt-5", - system_prompt: str | None = None, - ) -> None: - """ - Initialize the ResponseAgent. - - Args: - model: The model to use via HUD inference gateway (default: "gpt-5"). - Supports any model available through inference.hud.ai. - system_prompt: Optional custom system prompt for determining responses. - """ - api_key = settings.api_key - if not api_key: - raise ValueError( - "HUD API key is required for auto_respond. Set HUD_API_KEY environment variable." - ) - - self.client: AsyncOpenAI = AsyncOpenAI( - base_url=settings.hud_gateway_url, - api_key=api_key, - ) - self.model = model - self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT - - @instrument( - category="agent", - name="response_agent", - internal_type="user-message", - ) - async def determine_response(self, agent_message: str) -> ResponseType: - """ - Determine whether the agent should stop or continue based on its message. - - Args: - agent_message: The message from the agent - - Returns: - ResponseType: Either "STOP" or "CONTINUE" - """ - try: - response = await self.client.responses.create( - model=self.model, - instructions=self.system_prompt, - input=[ - { - "role": "user", - "content": ( - f"Agent message: {agent_message}\n\nWhat is the appropriate response?" 
- ), - }, - ], - reasoning={"effort": "low"}, - max_output_tokens=256, - extra_headers={"Trace-Id": ""}, - ) - - text_parts: list[str] = [] - for item in response.output: - if item.type == "message": - text_parts.extend( - content.text - for content in item.content - if isinstance(content, ResponseOutputText) - ) - - response_text = "".join(text_parts) - if not response_text: - return "CONTINUE" - - response_text = response_text.strip().upper() - - if "STOP" in response_text: - return "STOP" - else: - return "CONTINUE" - - except Exception as e: - logger.warning("Auto-respond failed: %s", e) - return "CONTINUE" diff --git a/hud/agents/misc/response_automation.py b/hud/agents/misc/response_automation.py new file mode 100644 index 000000000..91621843f --- /dev/null +++ b/hud/agents/misc/response_automation.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import logging +from functools import cache +from typing import Literal + +import mcp.types as types +from openai import AsyncOpenAI +from openai.types.responses import ResponseOutputText + +from hud.settings import settings +from hud.telemetry import instrument + +logger = logging.getLogger(__name__) + +ResponseType = Literal["STOP", "CONTINUE"] + +DEFAULT_SYSTEM_PROMPT = """\ +You are an assistant that helps determine the appropriate response to an agent's message. + +You will receive messages from an agent that is performing tasks for a user. +Your job is to analyze these messages and respond with one of the following: + +- STOP: If the agent indicates it has successfully completed a task or is stuck, + struggling or says it cannot complete the task, even if phrased as a question + like "I have entered the right values into this form. Would you like me to do + anything else?" or "Here is the website. Is there any other information you + need?" or if the agent has strongly determined it wants to stop the task like + "The task is infeasible. Can I help you with something else?" 
+ +- CONTINUE: If the agent is asking for clarification before proceeding with a task + like "I'm about to clear cookies from this website. Would you like me to proceed?" + or "I've entered the right values into this form. Would you like me to continue + with the rest of the task?" + +Respond ONLY with one of these two options.""" + + +async def auto_respond( + content: str | None, + *, + enabled: bool, +) -> types.PromptMessage | None: + if not enabled or not content: + return None + + try: + decision = await _determine_response(content) + except Exception as exc: + logger.warning("Auto-respond failed: %s", exc) + return None + + if decision == "STOP": + return None + + return types.PromptMessage( + role="user", + content=types.TextContent(text=decision, type="text"), + ) + + +@cache +def _client() -> AsyncOpenAI: + api_key = settings.api_key + if not api_key: + raise ValueError( + "HUD API key is required for auto_respond. Set HUD_API_KEY environment variable." + ) + + return AsyncOpenAI( + base_url=settings.hud_gateway_url, + api_key=api_key, + ) + + +@instrument( + category="agent", + name="response_automation", + internal_type="user-message", +) +async def _determine_response( + agent_message: str, + *, + model: str = "gpt-5", + system_prompt: str = DEFAULT_SYSTEM_PROMPT, +) -> ResponseType: + response = await _client().responses.create( + model=model, + instructions=system_prompt, + input=[ + { + "role": "user", + "content": f"Agent message: {agent_message}\n\nWhat is the appropriate response?", + }, + ], + reasoning={"effort": "low"}, + max_output_tokens=256, + extra_headers={"Trace-Id": ""}, + ) + + text_parts: list[str] = [] + for item in response.output: + if item.type == "message": + text_parts.extend( + content.text for content in item.content if isinstance(content, ResponseOutputText) + ) + + response_text = "".join(text_parts) + if not response_text: + return "CONTINUE" + + response_text = response_text.strip().upper() + return "STOP" if "STOP" in 
response_text else "CONTINUE" diff --git a/hud/agents/openai/agent.py b/hud/agents/openai/agent.py index 89bdfa50a..34ab08c27 100644 --- a/hud/agents/openai/agent.py +++ b/hud/agents/openai/agent.py @@ -2,86 +2,59 @@ from __future__ import annotations -import copy import json import logging -from inspect import cleandoc -from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast +from functools import cached_property +from typing import Any, Literal, cast import mcp.types as types from openai import AsyncOpenAI, Omit, OpenAI from openai.types.responses import ( - FunctionToolParam, - ResponseComputerToolCallOutputScreenshotParam, - ResponseFunctionCallOutputItemListParam, ResponseIncludable, - ResponseInputFileContentParam, - ResponseInputImageContentParam, ResponseInputImageParam, ResponseInputMessageContentListParam, ResponseInputParam, - ResponseInputTextContentParam, ResponseInputTextParam, ResponseOutputText, ToolParam, ) +from openai.types.responses.easy_input_message_param import EasyInputMessageParam from openai.types.responses.response_create_params import ToolChoice # noqa: TC002 from openai.types.responses.response_input_param import ( - ComputerCallOutput, - ComputerCallOutputAcknowledgedSafetyCheck, - FunctionCallOutput, Message, + ResponseInputItemParam, ) from openai.types.shared_params.reasoning import Reasoning # noqa: TC002 +from hud.agents import gateway from hud.agents.base import MCPAgent -from hud.agents.tools import ( - EnvironmentCapability, - call_agent_tools, - capabilities_metadata_from_context, - discover_environment_capabilities, - select_hosted_tools, -) -from hud.agents.types import OpenAIConfig, OpenAICreateParams +from hud.agents.types import OpenAIConfig from hud.settings import settings -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult, Trace -from hud.utils.strict_schema import ensure_strict_json_schema +from hud.types import AgentResponse, MCPToolCall from hud.utils.types import 
with_signature -from .tools import OpenAIHostedTool, OpenAIToolSearchTool, openai_tools - -if TYPE_CHECKING: - from .tools import OpenAITool +from .tools import OpenAIAgentTools logger = logging.getLogger(__name__) -class OpenAIAgent(MCPAgent): +class OpenAIAgent(MCPAgent[ResponseInputItemParam]): """Generic OpenAI agent that can execute MCP tools through the Responses API.""" - metadata: ClassVar[dict[str, Any] | None] = None - config_cls: ClassVar[type[BaseAgentConfig]] = OpenAIConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for OpenAI.""" - return AgentType.OPENAI - - @with_signature(OpenAICreateParams) + @with_signature(OpenAIConfig) @classmethod - def create(cls, **kwargs: Any) -> OpenAIAgent: # pyright: ignore[reportIncompatibleMethodOverride] - return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value] + def create(cls, **kwargs: object) -> OpenAIAgent: # pyright: ignore[reportIncompatibleMethodOverride] + return cls(OpenAIConfig.model_validate(kwargs)) - def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> None: - super().__init__(params, **kwargs) + def __init__(self, config: OpenAIConfig | None = None) -> None: + config = config or OpenAIConfig() + super().__init__(config) self.config: OpenAIConfig model_client = self.config.model_client if model_client is None: if settings.api_key: - from hud.agents.gateway import build_gateway_client - - model_client = build_gateway_client("openai") + model_client = gateway.build_gateway_client("openai") elif settings.openai_api_key: model_client = AsyncOpenAI(api_key=settings.openai_api_key) if self.config.validate_api_key: @@ -99,7 +72,7 @@ def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> N " access" ) - self.openai_client: AsyncOpenAI = model_client + self.openai_client: AsyncOpenAI = cast("AsyncOpenAI", model_client) self._model = self.config.model self.max_output_tokens = self.config.max_output_tokens 
self.temperature = self.config.temperature @@ -109,178 +82,40 @@ def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> N self.text = self.config.text self.truncation: Literal["auto", "disabled"] | None = self.config.truncation - self._openai_tools: list[ToolParam] = [] - self._openai_native_tools: dict[str, OpenAITool] = {} - self._tool_name_map: dict[str, str] = {} - self._environment_capabilities: dict[str, EnvironmentCapability] = {} - self._tool_search_threshold: int | None = None - self.last_response_id: str | None = None self._message_cursor = 0 - self.pending_call_id: str | None = None - self.pending_safety_checks: list[Any] = [] - - def _on_tools_ready(self) -> None: - """Build OpenAI-specific tool mappings after tools are discovered.""" - self._convert_tools_for_openai() - - def _discover_environment_capabilities( - self, tools: list[types.Tool] - ) -> dict[str, EnvironmentCapability]: - return discover_environment_capabilities( - tools, - env_metadata=capabilities_metadata_from_context(self.ctx), - name_fallbacks=openai_tools.name_fallbacks, - ) - - def _to_function_tool(self, tool: types.Tool) -> FunctionToolParam | None: - """Convert an MCP tool to OpenAI function tool format. - - Args: - tool: MCP tool to convert - - Returns: - OpenAI function tool parameter - """ - if tool.description is None or tool.inputSchema is None: - raise ValueError( - cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema. - Add these by: - 1. Adding a docstring to your @mcp.tool decorated function for the description - 2. 
Using pydantic Field() annotations on function parameters for the schema - """) - ) - - try: - strict_schema = ensure_strict_json_schema(copy.deepcopy(tool.inputSchema)) - except Exception as e: - self.console.warning_log(f"Failed to convert tool '{tool.name}' schema to strict: {e}") - return None - - return FunctionToolParam( - type="function", - name=tool.name, - description=tool.description, - parameters=strict_schema, - strict=True, - ) - - def _convert_tools_for_openai(self) -> None: - """Convert MCP tools into OpenAI Responses tool definitions.""" - self._openai_tools = [] - self._openai_native_tools = {} - self._tool_name_map = {} - self._tool_search_threshold = None - - categorized = self._categorized_tools - capabilities = self._discover_environment_capabilities(self.get_available_tools()) - self._environment_capabilities = capabilities - provider_backing_tools: set[str] = set() - - for capability in capabilities.values(): - if capability.name not in openai_tools.capabilities: - continue - openai_tool = openai_tools.tool_for_capability(capability, self.model) - if openai_tool is None: - continue - provider_backing_tools.add(capability.tool_name) - self._openai_native_tools[openai_tool.name] = openai_tool - self._tool_name_map[openai_tool.name] = openai_tool.name - self._openai_tools.append(openai_tool.to_params()) - - configured_hosted = select_hosted_tools( - self.config.hosted_tools, - tool_type=OpenAIHostedTool, - model=self.model, - ) - for hosted_tool in configured_hosted: - self._openai_tools.append(hosted_tool.to_params()) - if isinstance(hosted_tool, OpenAIToolSearchTool): - self._tool_search_threshold = hosted_tool.threshold - - # Process generic tools (function tools) - for tool in categorized.generic: - if tool.name in provider_backing_tools: - continue - openai_tool = self._to_function_tool(tool) - if openai_tool: - self._tool_name_map[tool.name] = tool.name - self._openai_tools.append(openai_tool) - - # Log actual tools being used - tool_names 
= sorted(self._tool_name_map.keys()) - self.console.info( - f"Agent initialized with {len(tool_names)} tools: {', '.join(tool_names)}" - ) - - def _extract_tool_call(self, item: Any) -> MCPToolCall | None: - """Extract an MCPToolCall from a response output item. - - Subclasses can override to customize tool call extraction (e.g., routing - computer_call to a different tool name). - """ - if item.type == "function_call": - tool_name = item.name or "" - target_name = self._tool_name_map.get(tool_name, tool_name) - arguments = json.loads(item.arguments) - return MCPToolCall(name=target_name, arguments=arguments, id=item.call_id) - elif item.type == "computer_call": - self.pending_safety_checks = item.pending_safety_checks or [] - target_name = self._tool_name_map.get("computer", "computer") - if hasattr(item, "actions") and item.actions: - arguments = {"actions": [a.to_dict() for a in item.actions]} - else: - arguments = item.action.to_dict() - return MCPToolCall(name=target_name, arguments=arguments, id=item.call_id) - elif item.type == "shell_call": - target_name = "shell" - return MCPToolCall(name=target_name, arguments=item.action.to_dict(), id=item.call_id) - return None - - async def call_tools( - self, tool_call: MCPToolCall | list[MCPToolCall] | None = None - ) -> list[MCPToolResult]: - """Route OpenAI provider tools through their agent-owned adapters.""" - return await call_agent_tools(self, self._openai_native_tools, tool_call) - - async def _run_context( - self, context: list[types.ContentBlock], *, max_steps: int = 10 - ) -> Trace: - """Reset internal state before delegating to the base loop.""" - self._reset_response_state() - return await super()._run_context(context, max_steps=max_steps) - - def _reset_response_state(self) -> None: - self.last_response_id = None - self._message_cursor = 0 - self.pending_call_id = None - self.pending_safety_checks = [] - - async def get_system_messages(self) -> list[types.ContentBlock]: - """System messages are provided 
via the `instructions` field.""" - return [] + @cached_property + def tools(self) -> OpenAIAgentTools: + return OpenAIAgentTools() + + async def format_messages( + self, messages: list[types.PromptMessage] + ) -> list[ResponseInputItemParam]: + """Convert MCP prompt messages into OpenAI Responses input items.""" + formatted_messages: list[ResponseInputItemParam] = [] + for message in messages: + match message.content: + case types.TextContent() as block: + content: ResponseInputMessageContentListParam = [ + ResponseInputTextParam(type="input_text", text=block.text) + ] + case types.ImageContent() as block: + mime_type = getattr(block, "mimeType", "image/png") + content = [ + ResponseInputImageParam( + type="input_image", + image_url=f"data:{mime_type};base64,{block.data}", + detail="auto", + ) + ] + case _: + content = [ResponseInputTextParam(type="input_text", text="")] - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Message]: - """Convert MCP content blocks into OpenAI user messages.""" - content: ResponseInputMessageContentListParam = [] - for block in blocks: - if isinstance(block, types.TextContent): - content.append(ResponseInputTextParam(type="input_text", text=block.text)) - elif isinstance(block, types.ImageContent): - mime_type = getattr(block, "mimeType", "image/png") - content.append( - ResponseInputImageParam( - type="input_image", - image_url=f"data:{mime_type};base64,{block.data}", - detail="auto", - ) - ) - if not content: - content.append(ResponseInputTextParam(type="input_text", text="")) - return [Message(role="user", content=content)] + formatted_messages.append(EasyInputMessageParam(role=message.role, content=content)) + return formatted_messages - async def get_response(self, messages: ResponseInputParam) -> InferenceResult: + async def get_response(self, messages: list[ResponseInputItemParam]) -> AgentResponse: """Send the latest input items to OpenAI's Responses API.""" new_items: ResponseInputParam = 
messages[self._message_cursor :] if not new_items: @@ -291,33 +126,29 @@ async def get_response(self, messages: ResponseInputParam) -> InferenceResult: ) ] else: - self.console.debug("No new messages to send to OpenAI.") - return InferenceResult(content="", tool_calls=[], done=True) + logger.debug("No new messages to send to OpenAI.") + return AgentResponse(content="", tool_calls=[], done=True) - scenario_enable_citations = bool( - getattr(self.ctx, "scenario_enable_citations", False) if self.ctx is not None else False - ) include_param: list[ResponseIncludable] | Omit = Omit() - if scenario_enable_citations: + if self.enable_citations: include_param = ["web_search_call.action.sources"] - effective_tools: list[ToolParam] = list(self._openai_tools) - if self._tool_search_threshold is not None: - fn_count = sum( - 1 for t in effective_tools if isinstance(t, dict) and t.get("type") == "function" - ) - if fn_count > self._tool_search_threshold: + effective_tools: list[ToolParam] = list(self.tools.params) + if self.tools.tool_search_threshold is not None: + fn_count = sum(1 for t in effective_tools if t.get("type") == "function") + if fn_count > self.tools.tool_search_threshold: logger.debug( "tool_search: %d function tools > threshold %d, applying defer_loading", fn_count, - self._tool_search_threshold, + self.tools.tool_search_threshold, + ) + effective_tools = cast( + "list[ToolParam]", + [ + {**t, "defer_loading": True} if t.get("type") == "function" else t + for t in effective_tools + ], ) - effective_tools = [ # type: ignore[assignment] - {**t, "defer_loading": True} - if isinstance(t, dict) and t.get("type") == "function" - else t - for t in effective_tools - ] response = await self.openai_client.responses.create( model=self._model, @@ -340,227 +171,89 @@ async def get_response(self, messages: ResponseInputParam) -> InferenceResult: self.last_response_id = response.id self._message_cursor = len(messages) - agent_response = InferenceResult(content="", 
tool_calls=[], done=True) text_chunks: list[str] = [] reasoning_chunks: list[str] = [] - - citations: list[dict[str, Any]] = [] + citations: list[dict[str, object]] = [] + tool_calls: list[MCPToolCall] = [] for item in response.output: - if item.type == "message": - for content_block in item.content: - if isinstance(content_block, ResponseOutputText): + match item.type: + case "message": + for content_block in item.content: + if not isinstance(content_block, ResponseOutputText): + continue if content_block.text: text_chunks.append(content_block.text) - # Extract citations from annotations - if content_block.annotations: - for ann in content_block.annotations: - ann_type = getattr(ann, "type", "") - if ann_type == "url_citation": - cit_obj = getattr(ann, "url_citation", ann) + for ann in content_block.annotations or []: + match ann.type: + case "url_citation": + citation = ann citations.append( { "type": "url_citation", - "text": getattr(cit_obj, "title", "") or "", - "source": getattr(cit_obj, "url", "") or "", - "title": getattr(cit_obj, "title", None), - "start_index": getattr(ann, "start_index", None), - "end_index": getattr(ann, "end_index", None), + "text": citation.title, + "source": citation.url, + "title": citation.title, + "start_index": citation.start_index, + "end_index": citation.end_index, } ) - elif ann_type == "file_citation": - cit_obj = getattr(ann, "file_citation", ann) + case "file_citation": + citation = ann citations.append( { "type": "file_citation", - "text": getattr(cit_obj, "filename", "") or "", - "source": getattr(cit_obj, "file_id", "") or "", - "title": getattr(cit_obj, "filename", None), - "start_index": getattr(ann, "start_index", None), - "end_index": getattr(ann, "end_index", None), + "text": citation.filename, + "source": citation.file_id, + "title": citation.filename, } ) - elif item.type == "reasoning": - reasoning_chunks.append("".join(summary.text for summary in item.summary)) - else: - tool_call = self._extract_tool_call(item) 
- if tool_call is not None: - agent_response.tool_calls.append(tool_call) - - if agent_response.tool_calls: - agent_response.done = False - - agent_response.content = "".join(text_chunks) - agent_response.citations = citations - if reasoning_chunks: - agent_response.reasoning = "\n".join(reasoning_chunks) - return agent_response - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[Any]: - """Convert MCP tool outputs into Responses input items. - - Detects computer tool results and formats them as ComputerCallOutput - with screenshots. Non-computer calls are formatted as FunctionCallOutput. - """ - computer_tool_name = self._tool_name_map.get("computer") - has_computer_call = bool(computer_tool_name) and any( - c.name == computer_tool_name for c in tool_calls - ) - has_native_call = any(c.name in self._openai_native_tools for c in tool_calls) - if not has_computer_call and not has_native_call: - return list(await self._format_function_results(tool_calls, tool_results)) - - remaining_calls: list[MCPToolCall] = [] - remaining_results: list[MCPToolResult] = [] - computer_outputs: list[ComputerCallOutput] = [] - native_outputs: list[dict[str, Any]] = [] - ordering: list[tuple[str, int]] = [] - - for call, result in zip(tool_calls, tool_results, strict=False): - if call.name == computer_tool_name: - screenshot = self._extract_latest_screenshot(result) - if not screenshot: - raise ValueError( - "Computer tool result missing screenshot. " - "The tool must always return a screenshot for computer_call_output." 
- ) - call_id = call.id or self.pending_call_id - if not call_id: - self.console.warning_log("Computer tool call missing ID; skipping output.") - continue - acknowledged_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] = [] - for check in self.pending_safety_checks: - if hasattr(check, "model_dump"): - acknowledged_checks.append(check.model_dump()) # type: ignore[arg-type] - elif isinstance(check, dict): - acknowledged_checks.append(check) # type: ignore[arg-type] - output_payload = ComputerCallOutput( - type="computer_call_output", - call_id=call_id, - output=cast( - "ResponseComputerToolCallOutputScreenshotParam", - { - "type": "computer_screenshot", - "image_url": f"data:image/png;base64,{screenshot}", - "detail": "original", - }, - ), - ) - if acknowledged_checks: - output_payload["acknowledged_safety_checks"] = acknowledged_checks - computer_outputs.append(output_payload) - self.pending_call_id = None - self.pending_safety_checks = [] - ordering.append(("computer", len(computer_outputs) - 1)) - elif call.name in self._openai_native_tools: - native_outputs.append( - self._openai_native_tools[call.name].format_result(call, result) - ) - ordering.append(("native", len(native_outputs) - 1)) - else: - remaining_calls.append(call) - remaining_results.append(result) - ordering.append(("function", len(remaining_calls) - 1)) - - formatted: list[Any] = [] - function_outputs: list[FunctionCallOutput] = [] - if remaining_calls: - function_outputs = await self._format_function_results( - remaining_calls, remaining_results - ) - - for kind, idx in ordering: - if kind == "computer" and idx < len(computer_outputs): - formatted.append(computer_outputs[idx]) - elif kind == "native" and idx < len(native_outputs): - formatted.append(native_outputs[idx]) - elif kind == "function" and idx < len(function_outputs): - formatted.append(function_outputs[idx]) - return formatted - - def _extract_latest_screenshot(self, result: MCPToolResult) -> str | None: - """Extract the latest 
screenshot from a tool result.""" - if not result.content: - return None - for content in reversed(result.content): - if isinstance(content, types.ImageContent): - return content.data - if isinstance(content, types.TextContent) and result.isError: - self.console.error_log(f"Computer tool error: {content.text}") - return None - - async def _format_function_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[FunctionCallOutput]: - """Convert MCP tool outputs into function call output items.""" - formatted: list[FunctionCallOutput] = [] - for call, result in zip(tool_calls, tool_results, strict=False): - if not call.id: - self.console.warning_log(f"Tool '{call.name}' missing call_id; skipping output.") - continue - - output_items: ResponseFunctionCallOutputItemListParam = [] - if result.isError: - output_items.append( - ResponseInputTextParam(type="input_text", text="[tool_error] true") - ) - - if result.structuredContent is not None: - output_items.append( - ResponseInputTextParam( - type="input_text", text=json.dumps(result.structuredContent, default=str) - ) - ) - - for block in result.content: - match block: - case types.TextContent(): - output_items.append( - ResponseInputTextContentParam(type="input_text", text=block.text) + case _: + continue + case "reasoning": + reasoning_chunks.append("".join(summary.text for summary in item.summary)) + case "function_call": + tool_name = item.name or "" + tool_calls.append( + MCPToolCall( + name=self.tools.name_map.get(tool_name, tool_name), + arguments=json.loads(item.arguments), + id=item.call_id, ) - case types.ImageContent(): - mime_type = getattr(block, "mimeType", "image/png") - output_items.append( - ResponseInputImageContentParam( - type="input_image", - image_url=f"data:{mime_type};base64,{block.data}", - ) - ) - case types.ResourceLink(): - output_items.append( - ResponseInputFileContentParam( - type="input_file", file_url=str(block.uri) - ) + ) + case "computer_call": + 
if item.actions: + arguments = {"actions": [action.to_dict() for action in item.actions]} + elif item.action is not None: + arguments = item.action.to_dict() + else: + raise ValueError("OpenAI computer_call missing action") + call: dict[str, Any] = { + "name": self.tools.name_map.get("computer", "computer"), + "arguments": arguments, + "id": item.call_id, + } + if item.pending_safety_checks: + call["pending_safety_checks"] = [ + check.model_dump() if hasattr(check, "model_dump") else check + for check in item.pending_safety_checks + ] + tool_calls.append(MCPToolCall.model_validate(call)) + case "shell_call": + tool_calls.append( + MCPToolCall( + name="shell", + arguments=item.action.to_dict(), + id=item.call_id, ) - case types.EmbeddedResource(): - match block.resource: - case types.TextResourceContents(): - output_items.append( - ResponseInputTextContentParam( - type="input_text", text=block.resource.text - ) - ) - case types.BlobResourceContents(): - output_items.append( - ResponseInputFileContentParam( - type="input_file", file_data=block.resource.blob - ) - ) - case _: - self.console.warning_log( - f"Unknown resource type: {type(block.resource)}" - ) - case _: - self.console.warning_log(f"Unknown content block type: {type(block)}") - - if not output_items: - output_items.append(ResponseInputTextParam(type="input_text", text="")) + ) + case _: + continue - formatted.append( - FunctionCallOutput( - type="function_call_output", call_id=call.id, output=output_items - ), - ) - return formatted + return AgentResponse( + content="".join(text_chunks), + reasoning="\n".join(reasoning_chunks) if reasoning_chunks else None, + citations=citations, + tool_calls=tool_calls, + done=not tool_calls, + ) diff --git a/hud/agents/openai/tools/__init__.py b/hud/agents/openai/tools/__init__.py index 1c1ffe271..b2e5222d7 100644 --- a/hud/agents/openai/tools/__init__.py +++ b/hud/agents/openai/tools/__init__.py @@ -2,55 +2,46 @@ from __future__ import annotations -from dataclasses 
import dataclass, field +from typing import TYPE_CHECKING, ClassVar -from hud.agents.tools import AgentToolRegistry +from openai.types.responses import ToolParam -from .base import OpenAITool -from .coding import ( - OPENAI_SHELL_SPEC, - OpenAIShellTool, -) -from .computer import OPENAI_COMPUTER_SPEC, OpenAIComputerTool +from hud.agents.tools import AgentTool, AgentTools + +from .base import OpenAIFunctionTool, OpenAITool +from .coding import OpenAIShellTool +from .computer import OpenAIComputerTool from .hosted import OpenAICodeInterpreterTool, OpenAIHostedTool, OpenAIToolSearchTool +if TYPE_CHECKING: + from collections.abc import Mapping + -@dataclass(frozen=True) -class OpenAIToolRegistry(AgentToolRegistry[OpenAITool]): - """Registry for OpenAI harness tools.""" +class OpenAIAgentTools(AgentTools[OpenAITool, ToolParam]): + """Prepared OpenAI Responses tool state for a run.""" - tool_classes: tuple[type[OpenAITool], ...] = ( + native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = ( OpenAIComputerTool, OpenAIShellTool, ) - name_fallbacks: dict[str, tuple[str, ...]] = field( - default_factory=lambda: { - "computer": ("computer", "openai_computer"), - "shell": ("bash",), - "editor": ("edit",), - } - ) + function_tool_class = OpenAIFunctionTool + name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = { + "computer": ("computer", "openai_computer"), + "shell": ("bash",), + "editor": ("edit",), + } @property - def api_types(self) -> frozenset[str]: - return frozenset(cls.name for cls in self.tool_classes) - - @property - def roles(self) -> frozenset[str]: - return self.capabilities - + def tool_search_threshold(self) -> int | None: + for hosted_tool in self.hosted_tools: + if isinstance(hosted_tool, OpenAIToolSearchTool): + return hosted_tool.threshold + return None -openai_tools = OpenAIToolRegistry() __all__ = [ - "OPENAI_COMPUTER_SPEC", - "OPENAI_SHELL_SPEC", + "OpenAIAgentTools", "OpenAICodeInterpreterTool", - "OpenAIComputerTool", 
"OpenAIHostedTool", - "OpenAIShellTool", - "OpenAITool", - "OpenAIToolRegistry", "OpenAIToolSearchTool", - "openai_tools", ] diff --git a/hud/agents/openai/tools/apply_patch.py b/hud/agents/openai/tools/apply_patch.py index 90913df5e..03fffa654 100644 --- a/hud/agents/openai/tools/apply_patch.py +++ b/hud/agents/openai/tools/apply_patch.py @@ -1,10 +1,10 @@ +# pyright: reportUnusedFunction=false """OpenAI apply_patch parser helpers.""" from __future__ import annotations from dataclasses import dataclass, field -from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from collections.abc import Callable @@ -14,45 +14,24 @@ class DiffError(ValueError): """Exception raised when diff parsing or application fails.""" -class ActionType(str, Enum): - ADD = "add" - DELETE = "delete" - UPDATE = "update" - - -@dataclass -class FileChange: - type: ActionType - old_content: str | None = None - new_content: str | None = None - move_path: str | None = None - - -@dataclass -class Commit: - changes: dict[str, FileChange] = field(default_factory=dict) +ActionType = Literal["add", "delete", "update"] @dataclass class Chunk: orig_index: int = -1 # line index of the first line in the original file - del_lines: list[str] = field(default_factory=list) - ins_lines: list[str] = field(default_factory=list) + del_lines: list[str] = field(default_factory=list[str]) + ins_lines: list[str] = field(default_factory=list[str]) @dataclass class PatchAction: type: ActionType new_file: str | None = None - chunks: list[Chunk] = field(default_factory=list) + chunks: list[Chunk] = field(default_factory=list[Chunk]) move_path: str | None = None -@dataclass -class Patch: - actions: dict[str, PatchAction] = field(default_factory=dict) - - class Parser: """Parser for V4A diff format.""" @@ -60,7 +39,7 @@ def __init__(self, current_files: dict[str, str], lines: list[str], index: int = self.current_files = current_files self.lines = lines 
self.index = index - self.patch = Patch() + self.actions: dict[str, PatchAction] = {} self.fuzz = 0 def is_done(self, prefixes: tuple[str, ...] | None = None) -> bool: @@ -68,19 +47,11 @@ def is_done(self, prefixes: tuple[str, ...] | None = None) -> bool: return True return prefixes is not None and self.lines[self.index].startswith(prefixes) - def startswith(self, prefix: str | tuple[str, ...]) -> bool: - if self.index >= len(self.lines): - raise DiffError(f"Unexpected end of patch at index {self.index}") - return self.lines[self.index].startswith(prefix) - - def read_str(self, prefix: str = "", return_everything: bool = False) -> str: + def read_str(self, prefix: str = "") -> str: if self.index >= len(self.lines): return "" # At EOF, no match possible if self.lines[self.index].startswith(prefix): - if return_everything: - text = self.lines[self.index] - else: - text = self.lines[self.index][len(prefix) :] + text = self.lines[self.index][len(prefix) :] self.index += 1 return text return "" @@ -89,7 +60,7 @@ def parse(self) -> None: while not self.is_done(("*** End Patch",)): path = self.read_str("*** Update File: ") if path: - if path in self.patch.actions: + if path in self.actions: raise DiffError(f"Update File Error: Duplicate Path: {path}") move_to = self.read_str("*** Move to: ") if path not in self.current_files: @@ -97,33 +68,33 @@ def parse(self) -> None: text = self.current_files[path] action = self.parse_update_file(text) action.move_path = move_to if move_to else None - self.patch.actions[path] = action + self.actions[path] = action continue path = self.read_str("*** Delete File: ") if path: - if path in self.patch.actions: + if path in self.actions: raise DiffError(f"Delete File Error: Duplicate Path: {path}") if path not in self.current_files: raise DiffError(f"Delete File Error: Missing File: {path}") - self.patch.actions[path] = PatchAction(type=ActionType.DELETE) + self.actions[path] = PatchAction(type="delete") continue path = self.read_str("*** 
Add File: ") if path: - if path in self.patch.actions: + if path in self.actions: raise DiffError(f"Add File Error: Duplicate Path: {path}") - self.patch.actions[path] = self.parse_add_file() + self.actions[path] = self.parse_add_file() continue raise DiffError(f"Unknown Line: {self.lines[self.index]}") - if self.index >= len(self.lines) or not self.startswith("*** End Patch"): + if self.index >= len(self.lines) or not self.lines[self.index].startswith("*** End Patch"): raise DiffError("Missing End Patch") self.index += 1 def parse_update_file(self, text: str) -> PatchAction: - action = PatchAction(type=ActionType.UPDATE) + action = PatchAction(type="update") lines = text.split("\n") index = 0 @@ -136,27 +107,28 @@ def parse_update_file(self, text: str) -> PatchAction: "*** End of File", ) ): - def_str = self.read_str("@@ ") - section_str = "" - if not def_str and self.lines[self.index] == "@@": - section_str = self.lines[self.index] + section_anchor = self.read_str("@@ ") + has_section_marker = False + if not section_anchor and self.lines[self.index] == "@@": + has_section_marker = True self.index += 1 - if not (def_str or section_str or index == 0): + if not (section_anchor or has_section_marker or index == 0): raise DiffError(f"Invalid Line:\n{self.lines[self.index]}") - if def_str.strip(): + if section_anchor.strip(): found = False - if not [s for s in lines[:index] if s == def_str]: - for i, s in enumerate(lines[index:], index): - if s == def_str: + if not any(line == section_anchor for line in lines[:index]): + for i, line in enumerate(lines[index:], index): + if line == section_anchor: index = i + 1 found = True break - if not found and not [s for s in lines[:index] if s.strip() == def_str.strip()]: - for i, s in enumerate(lines[index:], index): - if s.strip() == def_str.strip(): + stripped_anchor = section_anchor.strip() + if not found and not any(line.strip() == stripped_anchor for line in lines[:index]): + for i, line in enumerate(lines[index:], index): + 
if line.strip() == stripped_anchor: index = i + 1 self.fuzz += 1 found = True @@ -174,9 +146,9 @@ def parse_update_file(self, text: str) -> PatchAction: self.fuzz += fuzz - for ch in chunks: - ch.orig_index += new_index - action.chunks.append(ch) + for chunk in chunks: + chunk.orig_index += new_index + action.chunks.append(chunk) index = new_index + len(next_chunk_context) self.index = end_patch_index @@ -184,16 +156,15 @@ def parse_update_file(self, text: str) -> PatchAction: return action def parse_add_file(self) -> PatchAction: - lines = [] + lines: list[str] = [] while not self.is_done( ("*** End Patch", "*** Update File:", "*** Delete File:", "*** Add File:") ): - s = self.read_str() - if not s.startswith("+"): - raise DiffError(f"Invalid Add File Line: {s}") - s = s[1:] - lines.append(s) - return PatchAction(type=ActionType.ADD, new_file="\n".join(lines)) + line = self.read_str() + if not line.startswith("+"): + raise DiffError(f"Invalid Add File Line: {line}") + lines.append(line[1:]) + return PatchAction(type="add", new_file="\n".join(lines)) def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]: old: list[str] = [] @@ -204,9 +175,23 @@ def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]: orig_index = self.index index = self.index + def flush_chunk() -> None: + nonlocal del_lines, ins_lines + if not (ins_lines or del_lines): + return + chunks.append( + Chunk( + orig_index=len(old) - len(del_lines), + del_lines=del_lines, + ins_lines=ins_lines, + ) + ) + del_lines = [] + ins_lines = [] + while index < len(self.lines): - s = self.lines[index] - if s.startswith( + line = self.lines[index] + if line.startswith( ( "@@", "*** End Patch", @@ -217,56 +202,40 @@ def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]: ) ): break - if s == "***": + if line == "***": break - elif s.startswith("***"): - raise DiffError(f"Invalid Line: {s}") + elif line.startswith("***"): + raise DiffError(f"Invalid Line: {line}") 
index += 1 last_mode = mode - if s == "": - s = " " + if line == "": + line = " " - if s[0] == "+": + if line[0] == "+": mode = "add" - elif s[0] == "-": + elif line[0] == "-": mode = "delete" - elif s[0] == " ": + elif line[0] == " ": mode = "keep" else: - raise DiffError(f"Invalid Line: {s}") + raise DiffError(f"Invalid Line: {line}") - s = s[1:] + line = line[1:] if mode == "keep" and last_mode != mode: - if ins_lines or del_lines: - chunks.append( - Chunk( - orig_index=len(old) - len(del_lines), - del_lines=del_lines, - ins_lines=ins_lines, - ) - ) - del_lines = [] - ins_lines = [] + flush_chunk() if mode == "delete": - del_lines.append(s) - old.append(s) + del_lines.append(line) + old.append(line) elif mode == "add": - ins_lines.append(s) + ins_lines.append(line) elif mode == "keep": - old.append(s) + old.append(line) - if ins_lines or del_lines: - chunks.append( - Chunk( - orig_index=len(old) - len(del_lines), - del_lines=del_lines, - ins_lines=ins_lines, - ) - ) + flush_chunk() if index < len(self.lines) and self.lines[index] == "*** End of File": index += 1 @@ -278,116 +247,82 @@ def _peek_next_section(self) -> tuple[list[str], list[Chunk], int, bool]: return old, chunks, index, False -def _find_context_core(lines: list[str], context: list[str], start: int) -> tuple[int, int]: +def _find_context(lines: list[str], context: list[str], start: int, eof: bool) -> tuple[int, int]: if not context: return start, 0 - # Prefer identical - for i in range(start, len(lines)): - if lines[i : i + len(context)] == context: - return i, 0 - - # RStrip is ok - for i in range(start, len(lines)): - if [s.rstrip() for s in lines[i : i + len(context)]] == [s.rstrip() for s in context]: - return i, 1 - - # Fine, Strip is ok too - for i in range(start, len(lines)): - if [s.strip() for s in lines[i : i + len(context)]] == [s.strip() for s in context]: - return i, 100 - - return -1, 0 - + search_starts = [len(lines) - len(context), start] if eof else [start] + rstripped_context = 
[line.rstrip() for line in context] + stripped_context = [line.strip() for line in context] -def _find_context(lines: list[str], context: list[str], start: int, eof: bool) -> tuple[int, int]: - if eof: - new_index, fuzz = _find_context_core(lines, context, len(lines) - len(context)) - if new_index != -1: - return new_index, fuzz - new_index, fuzz = _find_context_core(lines, context, start) - return new_index, fuzz + 10000 - return _find_context_core(lines, context, start) - - -def _get_updated_file(text: str, action: PatchAction, path: str) -> str: - assert action.type == ActionType.UPDATE - orig_lines = text.split("\n") - dest_lines = [] - orig_index = 0 - - for chunk in action.chunks: - if chunk.orig_index > len(orig_lines): - raise DiffError( - f"_get_updated_file: {path}: chunk.orig_index {chunk.orig_index} " - f"> len(lines) {len(orig_lines)}" - ) - if orig_index > chunk.orig_index: - raise DiffError( - f"_get_updated_file: {path}: orig_index {orig_index} " - f"> chunk.orig_index {chunk.orig_index}" - ) + for attempt, search_start in enumerate(search_starts): + fuzz_offset = 10000 if eof and attempt > 0 else 0 - dest_lines.extend(orig_lines[orig_index : chunk.orig_index]) - orig_index = chunk.orig_index + for i in range(search_start, len(lines)): + candidate = lines[i : i + len(context)] + if candidate == context: + return i, fuzz_offset - if chunk.ins_lines: - dest_lines.extend(chunk.ins_lines) + for i in range(search_start, len(lines)): + candidate = lines[i : i + len(context)] + if [line.rstrip() for line in candidate] == rstripped_context: + return i, fuzz_offset + 1 - orig_index += len(chunk.del_lines) + for i in range(search_start, len(lines)): + candidate = lines[i : i + len(context)] + if [line.strip() for line in candidate] == stripped_context: + return i, fuzz_offset + 100 - dest_lines.extend(orig_lines[orig_index:]) - return "\n".join(dest_lines) + return -1, 0 -def _text_to_patch(text: str, orig: dict[str, str]) -> tuple[Patch, int]: +def 
_text_to_patch(text: str, orig: dict[str, str]) -> tuple[dict[str, PatchAction], int]: lines = text.strip().split("\n") if len(lines) < 2 or not lines[0].startswith("*** Begin Patch") or lines[-1] != "*** End Patch": raise DiffError("Invalid patch text") parser = Parser(current_files=orig, lines=lines, index=1) parser.parse() - return parser.patch, parser.fuzz + return parser.actions, parser.fuzz + + +def _apply_patch( + patch: dict[str, PatchAction], + orig: dict[str, str], + write_fn: Callable[[str, str | None], None], + remove_fn: Callable[[str], None], +) -> None: + for path, action in patch.items(): + match action.type: + case "delete": + remove_fn(path) + case "add": + write_fn(path, action.new_file) + case "update": + orig_lines = orig[path].split("\n") + dest_lines: list[str] = [] + orig_index = 0 + + for chunk in action.chunks: + if chunk.orig_index > len(orig_lines): + raise DiffError( + f"_apply_patch: {path}: chunk.orig_index {chunk.orig_index} " + f"> len(lines) {len(orig_lines)}" + ) + if orig_index > chunk.orig_index: + raise DiffError( + f"_apply_patch: {path}: orig_index {orig_index} " + f"> chunk.orig_index {chunk.orig_index}" + ) + dest_lines.extend(orig_lines[orig_index : chunk.orig_index]) + dest_lines.extend(chunk.ins_lines) + orig_index = chunk.orig_index + len(chunk.del_lines) -def _identify_files_needed(text: str) -> list[str]: - lines = text.strip().split("\n") - result = set() - for line in lines: - if line.startswith("*** Update File: "): - result.add(line[len("*** Update File: ") :]) - if line.startswith("*** Delete File: "): - result.add(line[len("*** Delete File: ") :]) - return list(result) - - -def _patch_to_commit(patch: Patch, orig: dict[str, str]) -> Commit: - commit = Commit() - for path, action in patch.actions.items(): - if action.type == ActionType.DELETE: - commit.changes[path] = FileChange(type=ActionType.DELETE, old_content=orig[path]) - elif action.type == ActionType.ADD: - commit.changes[path] = 
FileChange(type=ActionType.ADD, new_content=action.new_file) - elif action.type == ActionType.UPDATE: - new_content = _get_updated_file(text=orig[path], action=action, path=path) - commit.changes[path] = FileChange( - type=ActionType.UPDATE, - old_content=orig[path], - new_content=new_content, - move_path=action.move_path, - ) - return commit - - -def _apply_commit(commit: Commit, write_fn: Callable, remove_fn: Callable) -> None: - for path, change in commit.changes.items(): - if change.type == ActionType.DELETE: - remove_fn(path) - elif change.type == ActionType.ADD: - write_fn(path, change.new_content) - elif change.type == ActionType.UPDATE: - if change.move_path: - write_fn(change.move_path, change.new_content) - remove_fn(path) - else: - write_fn(path, change.new_content) + dest_lines.extend(orig_lines[orig_index:]) + new_content = "\n".join(dest_lines) + if action.move_path: + write_fn(action.move_path, new_content) + remove_fn(path) + else: + write_fn(path, new_content) diff --git a/hud/agents/openai/tools/base.py b/hud/agents/openai/tools/base.py index f5074bb4c..523a5087e 100644 --- a/hud/agents/openai/tools/base.py +++ b/hud/agents/openai/tools/base.py @@ -2,41 +2,155 @@ from __future__ import annotations +import copy +import json +import logging from abc import ABC -from typing import TYPE_CHECKING, Any +from inspect import cleandoc +from typing import TYPE_CHECKING, Any, cast -from mcp.types import TextContent +from mcp import types +from openai.types.responses import ( + FunctionToolParam, + ResponseFunctionCallOutputItemListParam, + ResponseInputFileContentParam, + ResponseInputImageContentParam, + ResponseInputTextContentParam, + ResponseInputTextParam, + ToolParam, +) +from openai.types.responses.response_input_param import FunctionCallOutput -from hud.agents import tools as _agent_tools -from hud.agents.tools import AgentTool, AgentToolSpec, CallTool +from hud.agents.tools import AgentTool, AgentToolSpec +from hud.utils.strict_schema import 
ensure_strict_json_schema if TYPE_CHECKING: - from openai.types.responses import ToolParam + from openai.types.responses import ResponseInputItemParam from hud.types import MCPToolCall, MCPToolResult -else: - ToolParam = Any + +logger = logging.getLogger(__name__) OpenAIToolSpec = AgentToolSpec -call_tool = _agent_tools.call_tool -class OpenAITool(AgentTool["ToolParam"], ABC): +class OpenAITool(AgentTool[ToolParam], ABC): """Agent-side OpenAI provider tool backed by an environment tool.""" - def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, Any]: + def format_result( + self, call: MCPToolCall, result: MCPToolResult + ) -> ResponseInputItemParam | None: """Format a generic provider tool result for the OpenAI Responses API.""" - return { - "type": "function_call_output", - "call_id": call.id, - "output": result_text(result), - } + if not call.id: + logger.warning("Tool '%s' missing call_id; skipping output.", call.name) + return None + + output_items: ResponseFunctionCallOutputItemListParam = [] + if result.isError: + output_items.append( + ResponseInputTextContentParam(type="input_text", text="[tool_error] true") + ) + + if result.structuredContent is not None: + output_items.append( + ResponseInputTextContentParam( + type="input_text", + text=json.dumps(result.structuredContent, default=str), + ) + ) + + for block in result.content: + match block: + case types.TextContent(): + output_items.append( + ResponseInputTextContentParam(type="input_text", text=block.text) + ) + case types.ImageContent(): + mime_type = getattr(block, "mimeType", "image/png") + output_items.append( + ResponseInputImageContentParam( + type="input_image", + image_url=f"data:{mime_type};base64,{block.data}", + ) + ) + case types.ResourceLink(): + output_items.append( + ResponseInputFileContentParam(type="input_file", file_url=str(block.uri)) + ) + case types.EmbeddedResource(resource=types.TextResourceContents() as resource): + output_items.append( + 
ResponseInputTextContentParam(type="input_text", text=resource.text) + ) + case types.EmbeddedResource(resource=types.BlobResourceContents() as resource): + output_items.append( + ResponseInputFileContentParam(type="input_file", file_data=resource.blob) + ) + case types.EmbeddedResource(): + logger.warning("Unknown resource type: %s", type(block.resource)) + case _: + logger.warning("Unknown content block type: %s", type(block)) + + if not output_items: + output_items.append(ResponseInputTextParam(type="input_text", text="")) + + return FunctionCallOutput(type="function_call_output", call_id=call.id, output=output_items) + + +class OpenAIFunctionTool(OpenAITool): + """Generic OpenAI function tool backed by an MCP tool.""" + + name = "function" + capability = "function" + + def __init__( + self, + *, + env_tool_name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + super().__init__( + env_tool_name=env_tool_name, + spec=OpenAIToolSpec(api_type="function", api_name=env_tool_name), + ) + self.description = description + self.parameters = parameters + + @classmethod + def from_tool(cls, tool: types.Tool) -> OpenAIFunctionTool | None: + if tool.description is None: + raise ValueError( + cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema. + Add these by: + 1. Adding a docstring to your @mcp.tool decorated function for the description + 2. 
Using pydantic Field() annotations on function parameters for the schema + """) + ) + try: + parameters = ensure_strict_json_schema(copy.deepcopy(tool.inputSchema)) + except Exception as e: + logger.warning("Failed to convert tool '%s' schema to strict: %s", tool.name, e) + return None -def result_text(result: MCPToolResult) -> str: - """Return text content from an MCP tool result.""" - parts = [block.text for block in result.content if isinstance(block, TextContent)] - return "\n".join(part for part in parts if part) + return cls( + env_tool_name=tool.name, + description=tool.description, + parameters=parameters, + ) + @property + def provider_name(self) -> str: + return self.env_tool_name -__all__ = ["CallTool", "OpenAITool", "OpenAIToolSpec", "call_tool", "result_text"] + def to_params(self) -> ToolParam: + return cast( + "ToolParam", + FunctionToolParam( + type="function", + name=self.provider_name, + description=self.description, + parameters=self.parameters, + strict=True, + ), + ) diff --git a/hud/agents/openai/tools/coding.py b/hud/agents/openai/tools/coding.py index 0fa2f6176..6bb6efa4d 100644 --- a/hud/agents/openai/tools/coding.py +++ b/hud/agents/openai/tools/coding.py @@ -2,14 +2,17 @@ from __future__ import annotations -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from mcp.types import TextContent -from openai.types.responses import FunctionShellToolParam, ToolParam +from openai.types.responses import FunctionShellToolParam, ResponseInputItemParam, ToolParam from hud.types import MCPToolCall, MCPToolResult -from .base import CallTool, OpenAITool, OpenAIToolSpec, call_tool, result_text +from .base import OpenAITool, OpenAIToolSpec + +if TYPE_CHECKING: + from hud.agents.tools.base import CallTool OPENAI_SHELL_SPEC = OpenAIToolSpec( api_type="shell", @@ -45,21 +48,28 @@ def to_params(self) -> ToolParam: FunctionShellToolParam(type="shell", environment={"type": "local"}), ) - async def execute(self, caller: CallTool, 
arguments: dict[str, Any]) -> MCPToolResult: - commands = arguments.get("commands") - if isinstance(commands, str): - commands = [commands] - if not isinstance(commands, list) or not all(isinstance(cmd, str) for cmd in commands): - return _provider_result( - "shell", - "commands must be a list of strings", + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + def invalid_commands_result() -> MCPToolResult: + text = "commands must be a list of strings" + return _shell_result( + text, is_error=True, structured={ - "output": [_shell_output("", "commands must be a list of strings", 1)], + "output": [_shell_output("", text, 1)], "max_output_length": arguments.get("max_output_length"), }, ) + commands = arguments.get("commands") + if isinstance(commands, str): + commands = [commands] + if not isinstance(commands, list): + return invalid_commands_result() + raw_commands = cast("list[Any]", commands) + if not all(isinstance(cmd, str) for cmd in raw_commands): + return invalid_commands_result() + command_list = cast("list[str]", raw_commands) + outputs: list[dict[str, Any]] = [] text_parts: list[str] = [] is_error = False @@ -67,13 +77,12 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR timeout_ms = arguments.get("timeout_ms") if isinstance(timeout_ms, int): env_arguments["timeout_seconds"] = timeout_ms / 1000.0 - for command in commands: - result = await call_tool( - caller, - self.env_tool_name, + for command in command_list: + result = await super().execute( + call_tool, {"command": command, **env_arguments}, ) - text = result_text(result) + text = _result_text(result) if result.isError: outputs.append(_shell_output("", text, 1)) is_error = True @@ -82,8 +91,7 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR if text: text_parts.append(text) - return _provider_result( - "shell", + return _shell_result( "\n".join(text_parts), is_error=is_error, structured={ @@ 
-92,11 +100,11 @@ async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolR }, ) - def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, Any]: + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> ResponseInputItemParam: structured = result.structuredContent if isinstance(result.structuredContent, dict) else {} output = structured.get("output") if not isinstance(output, list): - output = [_shell_output("", result_text(result), 1 if result.isError else 0)] + output = [_shell_output("", _result_text(result), 1 if result.isError else 0)] response: dict[str, Any] = { "type": "shell_call_output", @@ -107,17 +115,16 @@ def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, A max_output_length = structured.get("max_output_length") if isinstance(max_output_length, int): response["max_output_length"] = max_output_length - return response + return cast("ResponseInputItemParam", response) -def _provider_result( - provider_tool: str, +def _shell_result( text: str, *, is_error: bool = False, structured: dict[str, Any] | None = None, ) -> MCPToolResult: - payload = {"provider_tool": provider_tool, **(structured or {})} + payload = {"provider_tool": "shell", **(structured or {})} return MCPToolResult( content=[TextContent(type="text", text=text)] if text else [], isError=is_error, @@ -125,15 +132,14 @@ def _provider_result( ) +def _result_text(result: MCPToolResult) -> str: + parts = [block.text for block in result.content if isinstance(block, TextContent)] + return "\n".join(part for part in parts if part) + + def _shell_output(stdout: str, stderr: str, exit_code: int) -> dict[str, Any]: return { "stdout": stdout, "stderr": stderr, "outcome": {"type": "exit", "exit_code": exit_code}, } - - -__all__ = [ - "OPENAI_SHELL_SPEC", - "OpenAIShellTool", -] diff --git a/hud/agents/openai/tools/computer.py b/hud/agents/openai/tools/computer.py index acfb39cbe..748a31601 100644 --- 
a/hud/agents/openai/tools/computer.py +++ b/hud/agents/openai/tools/computer.py @@ -4,14 +4,29 @@ from typing import TYPE_CHECKING, Any, cast -from mcp.types import ImageContent, TextContent +from mcp.types import TextContent +from openai.types.responses.response_input_param import ComputerCallOutput -from hud.types import MCPToolResult +from hud.agents.tools.computer import ( + computer_error_result, + execute_computer_calls, + last_image_data, +) +from hud.types import MCPToolCall, MCPToolResult -from .base import CallTool, OpenAITool, OpenAIToolSpec, call_tool +from .base import OpenAITool, OpenAIToolSpec if TYPE_CHECKING: - from openai.types.responses import ComputerToolParam + from openai.types.responses import ( + ComputerToolParam, + ResponseComputerToolCallOutputScreenshotParam, + ResponseInputItemParam, + ) + from openai.types.responses.response_input_param import ( + ComputerCallOutputAcknowledgedSafetyCheck, + ) + + from hud.agents.tools.base import CallTool else: ComputerToolParam = Any @@ -88,59 +103,92 @@ def __init__( def to_params(self) -> ComputerToolParam: return cast("ComputerToolParam", {"type": "computer"}) - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> ResponseInputItemParam: + screenshot = last_image_data(result) + if not screenshot: + raise ValueError( + "Computer tool result missing screenshot. " + "The tool must always return a screenshot for computer_call_output." 
+ ) + + output = ComputerCallOutput( + type="computer_call_output", + call_id=call.id, + output=cast( + "ResponseComputerToolCallOutputScreenshotParam", + { + "type": "computer_screenshot", + "image_url": f"data:image/png;base64,{screenshot}", + "detail": "original", + }, + ), + ) + + checks = (call.model_extra or {}).get("pending_safety_checks") + if isinstance(checks, list): + acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] = [] + for raw_check in cast("list[Any]", checks): + check: Any = raw_check + if hasattr(check, "model_dump"): + acknowledged.append( + cast("ComputerCallOutputAcknowledgedSafetyCheck", check.model_dump()) + ) + elif isinstance(check, dict): + acknowledged.append(cast("ComputerCallOutputAcknowledgedSafetyCheck", check)) + if acknowledged: + output["acknowledged_safety_checks"] = acknowledged + return cast("ResponseInputItemParam", output) + + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: actions = arguments.get("actions") if isinstance(actions, list): - if not actions: - return _error_result("actions list is empty") + action_list = cast("list[Any]", actions) + if not action_list: + return computer_error_result("actions list is empty") result = MCPToolResult(content=[], isError=False) - for index, action in enumerate(actions): - if not isinstance(action, dict): - return _error_result("actions must be objects") + for index, raw_action in enumerate(action_list): + action = cast("dict[str, Any]", raw_action) + if not isinstance(raw_action, dict): + return computer_error_result("actions must be objects") result = await self._execute_one( - caller, + call_tool, action, - ensure_screenshot=index == len(actions) - 1, + ensure_screenshot=index == len(action_list) - 1, ) if result.isError: return result return result - return await self._execute_one(caller, arguments, ensure_screenshot=True) + return await self._execute_one(call_tool, arguments, ensure_screenshot=True) async def _execute_one( 
self, - caller: CallTool, + call_tool: CallTool, arguments: dict[str, Any], *, ensure_screenshot: bool, ) -> MCPToolResult: action_type = arguments.get("type") if not isinstance(action_type, str): - return _error_result("type is required") + return computer_error_result("type is required") if action_type == "response": text = arguments.get("text") if not isinstance(text, str): - return _error_result("text is required for response") + return computer_error_result("text is required for response") return MCPToolResult(content=[TextContent(type="text", text=text)], isError=False) env_arguments = self._env_arguments(arguments) - result = await call_tool(caller, self.env_tool_name, env_arguments) - if ( - ensure_screenshot - and action_type in _SCREENSHOT_ACTIONS - and action_type != "screenshot" - and not _has_image(result) - and not result.isError - ): - screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"}) - if not screenshot.isError and screenshot.content: - result = MCPToolResult( - content=[*result.content, *screenshot.content], - isError=result.isError, - ) - return result + return await execute_computer_calls( + call_tool, + env_tool_name=self.env_tool_name, + calls=[env_arguments], + ensure_screenshot=( + ensure_screenshot + and action_type in _SCREENSHOT_ACTIONS + and action_type != "screenshot" + ), + ) def _env_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]: action_type = arguments.get("type") @@ -148,11 +196,18 @@ def _env_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]: if action_type == "screenshot": return {"action": "screenshot"} if action_type == "click": + button = arguments.get("button") + if button == "wheel": + button_name = "middle" + elif isinstance(button, str): + button_name = button + else: + button_name = "left" return { "action": "click", "x": arguments.get("x"), "y": arguments.get("y"), - "button": _map_button(arguments.get("button")), + "button": button_name, "hold_keys": 
_hold_keys(arguments.get("keys")), } if action_type == "double_click": @@ -187,7 +242,10 @@ def _env_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]: keys = arguments.get("keys") if not isinstance(keys, list): keys = [] - return {"action": "press", "keys": [_map_key(str(key)) for key in keys]} + return { + "action": "press", + "keys": [_map_key(str(key)) for key in cast("list[Any]", keys)], + } if action_type == "drag": return { "action": "drag", @@ -207,24 +265,4 @@ def _map_key(key: str) -> str: def _hold_keys(keys: Any) -> list[str] | None: if not isinstance(keys, list): return None - return [_map_key(str(key)) for key in keys] - - -def _map_button(button: Any) -> str: - if button == "wheel": - return "middle" - return button if isinstance(button, str) else "left" - - -def _has_image(result: MCPToolResult) -> bool: - return any(isinstance(block, ImageContent) for block in result.content) - - -def _error_result(message: str) -> MCPToolResult: - return MCPToolResult( - content=[TextContent(type="text", text=message)], - isError=True, - ) - - -__all__ = ["OPENAI_COMPUTER_SPEC", "OpenAIComputerTool"] + return [_map_key(str(key)) for key in cast("list[Any]", keys)] diff --git a/hud/agents/openai/tools/hosted.py b/hud/agents/openai/tools/hosted.py index 0f13be9ba..b182bd93d 100644 --- a/hud/agents/openai/tools/hosted.py +++ b/hud/agents/openai/tools/hosted.py @@ -45,10 +45,3 @@ class OpenAIToolSearchTool(OpenAIHostedTool): def to_params(self) -> ToolParam: return cast("ToolParam", {"type": "tool_search"}) - - -__all__ = [ - "OpenAICodeInterpreterTool", - "OpenAIHostedTool", - "OpenAIToolSearchTool", -] diff --git a/hud/agents/openai_compatible/__init__.py b/hud/agents/openai_compatible/__init__.py index 3cecd79d2..fc9746f1c 100644 --- a/hud/agents/openai_compatible/__init__.py +++ b/hud/agents/openai_compatible/__init__.py @@ -1,6 +1,5 @@ """OpenAI-compatible agent harness support.""" from .agent import OpenAIChatAgent -from .tools import 
openai_compatible_tools -__all__ = ["OpenAIChatAgent", "openai_compatible_tools"] +__all__ = ["OpenAIChatAgent"] diff --git a/hud/agents/openai_compatible/agent.py b/hud/agents/openai_compatible/agent.py index 74a464459..5c2351e50 100644 --- a/hud/agents/openai_compatible/agent.py +++ b/hud/agents/openai_compatible/agent.py @@ -18,52 +18,37 @@ import json import logging -from typing import TYPE_CHECKING, Any, ClassVar, cast +from functools import cached_property +from typing import Any, cast import mcp.types as types from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from hud.agents.base import MCPAgent -from hud.agents.tools import ( - AgentTool, - EnvironmentCapability, - call_agent_tools, - capabilities_metadata_from_context, - discover_environment_capabilities, -) -from hud.agents.types import OpenAIChatConfig, OpenAIChatCreateParams +from hud.agents.types import OpenAIChatConfig from hud.settings import settings -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult -from hud.utils.hud_console import HUDConsole +from hud.types import AgentResponse, MCPToolCall from hud.utils.types import with_signature -from .tools import OpenAICompatibleToolParam, openai_compatible_tools - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionToolParam - +from .tools import ( + OpenAICompatibleAgentTools, +) logger = logging.getLogger(__name__) -class OpenAIChatAgent(MCPAgent): +class OpenAIChatAgent(MCPAgent[ChatCompletionMessageParam]): """MCP-enabled agent that speaks the OpenAI *chat.completions* protocol.""" - metadata: ClassVar[dict[str, Any] | None] = None - config_cls: ClassVar[type[BaseAgentConfig]] = OpenAIChatConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for OpenAI-compatible agents.""" - return AgentType.OPENAI_COMPATIBLE - - @with_signature(OpenAIChatCreateParams) + @with_signature(OpenAIChatConfig) @classmethod def 
create(cls, **kwargs: Any) -> OpenAIChatAgent: # pyright: ignore[reportIncompatibleMethodOverride] - return MCPAgent.create.__func__(cls, **kwargs) # type: ignore[return-value] + return cls(OpenAIChatConfig(**kwargs)) - def __init__(self, params: OpenAIChatCreateParams | None = None, **kwargs: Any) -> None: - super().__init__(params, **kwargs) + def __init__(self, config: OpenAIChatConfig | None = None) -> None: + config = config or OpenAIChatConfig() + super().__init__(config) self.config: OpenAIChatConfig if ( @@ -100,76 +85,25 @@ def __init__(self, params: OpenAIChatCreateParams | None = None, **kwargs: Any) # If a specific checkpoint is requested, inject it into extra_body # so the HUD gateway routes to the exact checkpoint for inference. if self.config.checkpoint: - extra_body = self.completion_kwargs.get("extra_body") or {} + extra_body: dict[str, Any] = dict(self.completion_kwargs.get("extra_body") or {}) extra_body["checkpoint"] = self.config.checkpoint self.completion_kwargs["extra_body"] = extra_body - self.mcp_schemas: list[ChatCompletionToolParam] = [] - self.hud_console = HUDConsole(logger=logger) - self._openai_compatible_tool_params: list[OpenAICompatibleToolParam] = [] - self._openai_compatible_native_tools: dict[ - str, - AgentTool[OpenAICompatibleToolParam], - ] = {} - self._environment_capabilities: dict[str, EnvironmentCapability] = {} - self._openai_compatible_backing_tools: set[str] = set() - self._continuation_token_ids: list[int] | None = None self._continuation_message_count: int | None = None - def _on_tools_ready(self) -> None: - self._convert_tools_for_openai_compatible() - - def _discover_environment_capabilities( - self, tools: list[types.Tool] - ) -> dict[str, EnvironmentCapability]: - return discover_environment_capabilities( - tools, - env_metadata=capabilities_metadata_from_context(self.ctx), - name_fallbacks=openai_compatible_tools.name_fallbacks, - ) - - def _convert_tools_for_openai_compatible(self) -> None: - """Build 
OpenAI-compatible native tool mappings from environment capabilities.""" - self._openai_compatible_tool_params = [] - self._openai_compatible_native_tools = {} - self._openai_compatible_backing_tools = set() - - capabilities = self._discover_environment_capabilities(self.get_available_tools()) - self._environment_capabilities = capabilities - - for capability in capabilities.values(): - if capability.name not in openai_compatible_tools.capabilities: - continue - for tool in openai_compatible_tools.tools_for_capability(capability, self.model): - self._openai_compatible_backing_tools.add(tool.env_tool_name) - self._openai_compatible_native_tools[tool.name] = tool - self._openai_compatible_tool_params.append(tool.to_params()) - - def _oai_to_mcp(self, tool_call: Any) -> MCPToolCall: # type: ignore[valid-type] - """Convert an OpenAI ``tool_call`` to :class:`MCPToolCall`.""" - args = json.loads(tool_call.function.arguments or "{}") - if isinstance(args, list): - args = args[0] - if not isinstance(args, dict): - args = {} - return MCPToolCall( - id=tool_call.id, - name=tool_call.function.name, - arguments=args, - ) - - async def get_system_messages(self) -> list[dict[str, Any]]: - """Get system messages for OpenAI.""" - if self.system_prompt is not None: - return [{"role": "system", "content": self.system_prompt}] - else: - return [] - - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]: - """Format blocks for OpenAI.""" - content = [] - for block in blocks: + @cached_property + def tools(self) -> OpenAICompatibleAgentTools: + return OpenAICompatibleAgentTools() + + async def format_messages( + self, messages: list[types.PromptMessage] + ) -> list[ChatCompletionMessageParam]: + """Format MCP prompt messages for OpenAI-compatible chat.""" + formatted_messages: list[ChatCompletionMessageParam] = [] + for message in messages: + content: list[dict[str, Any]] = [] + block = message.content if isinstance(block, types.TextContent): 
content.append({"type": "text", "text": block.text}) elif isinstance(block, types.ImageContent): @@ -180,146 +114,54 @@ async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str } ) - return [{"role": "user", "content": content}] - - def _sanitize_schema_for_openai(self, schema: dict) -> dict: - """Convert MCP JSON Schema to OpenAI-compatible format. - - Handles unsupported features like anyOf and prefixItems. - """ - if not isinstance(schema, dict): - return schema - - sanitized = {} - - for key, value in schema.items(): - if key == "anyOf" and isinstance(value, list): - # Handle anyOf patterns (usually for nullable fields) - non_null_types = [ - v for v in value if not (isinstance(v, dict) and v.get("type") == "null") - ] - if non_null_types: - # Use the first non-null type - sanitized.update(self._sanitize_schema_for_openai(non_null_types[0])) - else: - sanitized["type"] = "string" # Fallback - - elif key == "prefixItems": - # Convert prefixItems to simple items - sanitized["type"] = "array" - if isinstance(value, list) and value: - # Use the type from the first item as the items schema - first_item = value[0] - if isinstance(first_item, dict): - sanitized["items"] = {"type": first_item.get("type", "string")} - else: - sanitized["items"] = {"type": "string"} - - elif key == "properties" and isinstance(value, dict): - # Recursively sanitize property schemas - sanitized[key] = { - prop_name: self._sanitize_schema_for_openai(prop_schema) - for prop_name, prop_schema in value.items() - } + formatted_messages.append( + cast( + "ChatCompletionMessageParam", + {"role": message.role, "content": content}, + ) + ) + return formatted_messages - elif key == "items" and isinstance(value, dict): - # Recursively sanitize items schema - sanitized[key] = self._sanitize_schema_for_openai(value) - - elif key in ( - "type", - "description", - "enum", - "required", - "default", - "minimum", - "maximum", - "minItems", - "maxItems", - ): - # These are supported 
by OpenAI - sanitized[key] = value - - return sanitized or {"type": "object"} - - def get_tool_schemas(self) -> list[OpenAICompatibleToolParam]: - tool_schemas = [ - schema - for schema in super().get_tool_schemas() - if schema["name"] not in self._openai_compatible_backing_tools - ] - openai_tools = list(self._openai_compatible_tool_params) - for schema in tool_schemas: - parameters = schema.get("parameters", {}) - - if parameters: - sanitized_params = self._sanitize_schema_for_openai(parameters) - else: - sanitized_params = {"type": "object", "properties": {}} - - openai_tool: ChatCompletionToolParam = { - "type": "function", - "function": { - "name": schema["name"], - "description": schema.get("description", ""), - "parameters": sanitized_params, - }, - } - openai_tools.append(openai_tool) - return openai_tools - - async def call_tools( - self, tool_call: MCPToolCall | list[MCPToolCall] | None = None - ) -> list[MCPToolResult]: - """Route OpenAI-compatible provider tools through agent-owned translators.""" - return await call_agent_tools(self, self._openai_compatible_native_tools, tool_call) - - async def _invoke_chat_completion( - self, - *, - messages: list[Any], - tools: list[dict] | None, - extra: dict[str, Any], - ) -> Any: - if self.oai is None: - raise ValueError("openai_client is required for OpenAIChatAgent") - # default transport = OpenAI SDK - return await self.oai.chat.completions.create( - model=self.config.model, - messages=messages, - tools=tools, # type: ignore ready ChatCompletionToolParam-shaped - **extra, - ) # type: ignore - - async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: + async def get_response(self, messages: list[ChatCompletionMessageParam]) -> AgentResponse: """Send chat request to OpenAI and convert the response.""" - # Convert MCP tool schemas to OpenAI format - tools = cast("list[ChatCompletionToolParam]", self.get_tool_schemas()) + reserved_kwargs = {"model", "messages", "stream", "tools"} + 
request_kwargs = { + key: value + for key, value in self.completion_kwargs.items() + if key not in reserved_kwargs + } + provider_body: dict[str, Any] = dict(request_kwargs.pop("extra_body", None) or {}) + return_token_ids = bool(provider_body.get("return_token_ids")) - protected_keys = {"model", "messages", "tools"} - extra = {k: v for k, v in (self.completion_kwargs or {}).items() if k not in protected_keys} - extra_body = extra.get("extra_body") or {} - return_token_ids = extra_body.get("return_token_ids") + if self.tools.params: + provider_body["tools"] = self.tools.params if return_token_ids and self._continuation_token_ids and self._continuation_message_count: - extra_body["prompt_token_ids"] = self._continuation_token_ids - extra_body["continuation_from"] = self._continuation_message_count - extra["extra_body"] = extra_body + provider_body["prompt_token_ids"] = self._continuation_token_ids + provider_body["continuation_from"] = self._continuation_message_count + + if provider_body: + request_kwargs["extra_body"] = provider_body try: - response = await self._invoke_chat_completion( - messages=messages, - tools=tools, # type: ignore - extra=extra, + response: ChatCompletion = await self.oai.chat.completions.create( + model=self.config.model, + messages=( + [{"role": "system", "content": self.system_prompt}, *messages] + if self.system_prompt is not None + else messages + ), + stream=False, + **request_kwargs, ) except Exception as e: error_content = f"Error getting response {e}" if "Invalid JSON" in str(e): error_content = "Invalid JSON, response was truncated" - self.hud_console.warning_log(error_content) + logger.warning(error_content) - return InferenceResult( + return AgentResponse( content=error_content, tool_calls=[], done=True, @@ -328,24 +170,33 @@ async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: ) choice = response.choices[0] - msg = choice.message - assistant_msg: dict[str, Any] = {"role": "assistant"} - - if 
msg.content: - assistant_msg["content"] = msg.content + message = choice.message + function_calls = [ + tool_call for tool_call in message.tool_calls or [] if tool_call.type == "function" + ] - if msg.tool_calls: - serialized_tool_calls = [] - for tc in msg.tool_calls: - serialized_tc = { - "id": tc.id, + assistant_message = message.model_dump(exclude_none=True) + reasoning_content = getattr(message, "reasoning_content", None) + reasoning = reasoning_content if isinstance(reasoning_content, str) else None + if not reasoning: + raw_reasoning = getattr(message, "reasoning", None) + reasoning = raw_reasoning if isinstance(raw_reasoning, str) else None + for field in ("reasoning_content", "reasoning", "reasoning_details"): + if value := getattr(message, field, None): + assistant_message[field] = value + if function_calls: + assistant_message["tool_calls"] = [ + { + "id": tool_call.id, "type": "function", - "function": {"name": tc.function.name, "arguments": tc.function.arguments}, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - serialized_tool_calls.append(serialized_tc) - assistant_msg["tool_calls"] = serialized_tool_calls - - messages.append(assistant_msg) + for tool_call in function_calls + ] + messages.append(cast("ChatCompletionMessageParam", assistant_message)) if return_token_ids: prompt_token_ids = getattr(choice, "prompt_token_ids", None) @@ -354,91 +205,23 @@ async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: self._continuation_token_ids = list(prompt_token_ids) + list(token_ids) self._continuation_message_count = len(messages) - tool_calls = [] - if msg.tool_calls: - for tc in msg.tool_calls: - if tc.function.name is not None: # type: ignore - # _oai_to_mcp returns a single MCPToolCall; append it - tool_calls.append(self._oai_to_mcp(tc)) # noqa: PERF401 - - # Only stop on length (token limit), never on "stop" - done = choice.finish_reason == "length" - if done: - 
self.hud_console.info_log(f"Done decision: finish_reason={choice.finish_reason}") - - return InferenceResult( - content=msg.content or "", - reasoning=getattr(msg, "reasoning_content", None), + tool_calls: list[MCPToolCall] = [] + for tool_call in function_calls: + raw_args = json.loads(tool_call.function.arguments or "{}") + arguments = cast("dict[str, Any]", raw_args) if isinstance(raw_args, dict) else {} + tool_calls.append( + MCPToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=arguments, + ) + ) + + return AgentResponse( + content=message.content or "", + reasoning=reasoning, + info={"finish_reason": choice.finish_reason}, tool_calls=tool_calls, - done=done, + done=not tool_calls, raw=response, ) - - async def format_tool_results( - self, - tool_calls: list[MCPToolCall], - tool_results: list[MCPToolResult], - ) -> list[dict[str, Any]]: - """Render MCP tool results as OpenAI messages. - - Note: OpenAI tool messages only support string content. - When images are present, we return both a tool message and a user message. 
- """ - rendered: list[dict[str, Any]] = [] - - # Separate text and image content - image_parts = [] - for call, res in zip(tool_calls, tool_results, strict=False): - # Use structuredContent.result if available, otherwise use content - text_parts = [] - items = res.content - if not res.content and res.structuredContent: - items = [res.structuredContent.get("result", res.content)] - - for item in items: - if isinstance(item, dict): - if item.get("type") == "text": - text_parts.append(item.get("text", "")) - elif item.get("type") == "image": - mime_type = item.get("mimeType", "image/png") - data = item.get("data", "") - image_parts.append( - { - "type": "image_url", - "image_url": {"url": f"data:{mime_type};base64,{data}"}, - } - ) - elif isinstance(item, types.TextContent): - text_parts.append(item.text) - elif isinstance(item, types.ImageContent): - image_parts.append( - { - "type": "image_url", - "image_url": {"url": f"data:{item.mimeType};base64,{item.data}"}, - } - ) - - text_content = "".join(text_parts) if text_parts else "Tool executed successfully" - rendered.append( - { - "role": "tool", - "tool_call_id": call.id, - "content": text_content, - } - ) - - # If there are images, add them as a separate user message - if image_parts: - # Add a user message with the images - content_with_images = [ - {"type": "text", "text": "Tool returned the following:"}, - image_parts[-1], - ] - rendered.append( - { - "role": "user", - "content": content_with_images, - } - ) - - return rendered diff --git a/hud/agents/openai_compatible/tools/__init__.py b/hud/agents/openai_compatible/tools/__init__.py index 94f800b76..1c408f184 100644 --- a/hud/agents/openai_compatible/tools/__init__.py +++ b/hud/agents/openai_compatible/tools/__init__.py @@ -2,31 +2,33 @@ from __future__ import annotations -from dataclasses import dataclass, field +from typing import TYPE_CHECKING, ClassVar -from hud.agents.tools import AgentTool, AgentToolRegistry +from hud.agents.tools import AgentTool, 
AgentTools -from .computer import ( - GLM_COMPUTER_SPEC, - QWEN_COMPUTER_SPEC, - GLMComputerTool, - QwenComputerTool, +from .base import ( + OpenAICompatibleFunctionTool, + OpenAICompatibleToolParam, ) from .filesystem import ( - FilesystemTool, GlobTool, GrepTool, ListTool, ReadTool, ) -from .types import OpenAICompatibleToolParam +from .glm_computer import GLMComputerTool +from .qwen_computer import QwenComputerTool +if TYPE_CHECKING: + from collections.abc import Mapping -@dataclass(frozen=True) -class OpenAICompatibleToolRegistry(AgentToolRegistry[AgentTool[OpenAICompatibleToolParam]]): - """Registry for OpenAI-compatible harness tools.""" - tool_classes: tuple[type[AgentTool[OpenAICompatibleToolParam]], ...] = ( +class OpenAICompatibleAgentTools( + AgentTools[AgentTool[OpenAICompatibleToolParam], OpenAICompatibleToolParam] +): + """Prepared OpenAI-compatible chat tool state for a run.""" + + native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = ( GLMComputerTool, QwenComputerTool, ReadTool, @@ -34,43 +36,19 @@ class OpenAICompatibleToolRegistry(AgentToolRegistry[AgentTool[OpenAICompatibleT GlobTool, ListTool, ) - name_fallbacks: dict[str, tuple[str, ...]] = field( - default_factory=lambda: { - "computer": ( - "computer", - "hud_computer", - "openai_computer", - "glm_computer", - "qwen_computer", - ), - "filesystem": ("read", "grep", "glob", "list"), - } - ) - - @property - def api_types(self) -> frozenset[str]: - api_types: set[str] = set() - for cls in self.tool_classes: - spec = cls.default_spec("unknown") - if spec is not None and spec.api_type != "function": - api_types.add(spec.api_type) - api_types.update(getattr(cls, "ignored_api_types", frozenset())) - return frozenset(api_types) - + function_tool_class = OpenAICompatibleFunctionTool + name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = { + "computer": ( + "computer", + "hud_computer", + "openai_computer", + "glm_computer", + "qwen_computer", + ), + "filesystem": ("read", "grep", 
"glob", "list"), + } -openai_compatible_tools = OpenAICompatibleToolRegistry() __all__ = [ - "GLM_COMPUTER_SPEC", - "QWEN_COMPUTER_SPEC", - "FilesystemTool", - "GLMComputerTool", - "GlobTool", - "GrepTool", - "ListTool", - "OpenAICompatibleToolParam", - "OpenAICompatibleToolRegistry", - "QwenComputerTool", - "ReadTool", - "openai_compatible_tools", + "OpenAICompatibleAgentTools", ] diff --git a/hud/agents/openai_compatible/tools/base.py b/hud/agents/openai_compatible/tools/base.py new file mode 100644 index 000000000..2d11866be --- /dev/null +++ b/hud/agents/openai_compatible/tools/base.py @@ -0,0 +1,180 @@ +"""OpenAI-compatible agent-owned tool setup.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeAlias, cast + +import mcp.types as mcp_types + +from hud.agents.tools import AgentTool, AgentToolSpec + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam + + from hud.types import MCPToolCall, MCPToolResult + + from .qwen_computer import QwenComputerUseToolParam + +OpenAICompatibleToolParam: TypeAlias = "ChatCompletionToolParam | QwenComputerUseToolParam" + + +class OpenAICompatibleTool(AgentTool[OpenAICompatibleToolParam]): + """Agent-side OpenAI-compatible tool backed by an environment tool.""" + + def format_result( + self, call: MCPToolCall, result: MCPToolResult + ) -> ChatCompletionMessageParam | list[ChatCompletionMessageParam]: + text_parts: list[str] = [] + image_parts: list[dict[str, Any]] = [] + items: list[Any] = list(result.content) + if not result.content and result.structuredContent: + items = [result.structuredContent.get("result", result.content)] + + for item in items: + if isinstance(item, dict): + item_dict = cast("dict[str, Any]", item) + if item_dict.get("type") == "text": + text_parts.append(str(item_dict.get("text", ""))) + elif item_dict.get("type") == "image": + mime_type = str(item_dict.get("mimeType", "image/png")) + data = str(item_dict.get("data", 
"")) + image_parts.append( + { + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{data}"}, + } + ) + elif isinstance(item, mcp_types.TextContent): + text_parts.append(item.text) + elif isinstance(item, mcp_types.ImageContent): + image_parts.append( + { + "type": "image_url", + "image_url": {"url": f"data:{item.mimeType};base64,{item.data}"}, + } + ) + + tool_message = cast( + "ChatCompletionMessageParam", + { + "role": "tool", + "tool_call_id": call.id, + "content": "".join(text_parts) if text_parts else "Tool executed successfully", + }, + ) + if not image_parts: + return tool_message + return [ + tool_message, + cast( + "ChatCompletionMessageParam", + { + "role": "user", + "content": [ + {"type": "text", "text": "Tool returned the following:"}, + image_parts[-1], + ], + }, + ), + ] + + +class OpenAICompatibleFunctionTool(OpenAICompatibleTool): + """Regular environment tool exposed as an OpenAI-compatible function.""" + + name = "function" + capability = "function" + + def __init__(self, *, env_tool_name: str, params: OpenAICompatibleToolParam) -> None: + super().__init__( + env_tool_name=env_tool_name, + spec=AgentToolSpec(api_type="function", api_name=env_tool_name), + ) + self.params = params + + @classmethod + def from_tool(cls, tool: mcp_types.Tool) -> OpenAICompatibleFunctionTool: + return cls(env_tool_name=tool.name, params=openai_compatible_tool_param(tool)) + + @property + def provider_name(self) -> str: + return self.env_tool_name + + def to_params(self) -> OpenAICompatibleToolParam: + return self.params + + +def openai_compatible_tool_param(tool: mcp_types.Tool) -> OpenAICompatibleToolParam: + parameters = tool.inputSchema + sanitized_params: dict[str, Any] = ( + _sanitize_schema_for_openai(parameters) + if parameters + else {"type": "object", "properties": {}} + ) + + return cast( + "OpenAICompatibleToolParam", + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or f"Call {tool.name}", 
+ "parameters": sanitized_params, + }, + }, + ) + + +def _sanitize_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]: + """Convert MCP JSON Schema to OpenAI-compatible format.""" + sanitized: dict[str, Any] = {} + + for key, value in schema.items(): + if key == "anyOf" and isinstance(value, list): + any_of_items = cast("list[Any]", value) + non_null_types: list[dict[str, Any]] = [ + cast("dict[str, Any]", item) + for item in any_of_items + if isinstance(item, dict) and cast("dict[str, Any]", item).get("type") != "null" + ] + if non_null_types: + sanitized.update(_sanitize_schema_for_openai(non_null_types[0])) + else: + sanitized["type"] = "string" + + elif key == "prefixItems" and isinstance(value, list): + sanitized["type"] = "array" + prefix_items = cast("list[Any]", value) + if prefix_items: + first_item: Any = prefix_items[0] + if isinstance(first_item, dict): + first_schema = cast("dict[str, Any]", first_item) + sanitized["items"] = {"type": first_schema.get("type", "string")} + else: + sanitized["items"] = {"type": "string"} + + elif key == "properties" and isinstance(value, dict): + properties = cast("dict[str, Any]", value) + sanitized[key] = { + prop_name: _sanitize_schema_for_openai(cast("dict[str, Any]", prop_schema)) + for prop_name, prop_schema in properties.items() + if isinstance(prop_schema, dict) + } + + elif key == "items" and isinstance(value, dict): + sanitized[key] = _sanitize_schema_for_openai(cast("dict[str, Any]", value)) + + elif key in ( + "type", + "description", + "enum", + "required", + "default", + "minimum", + "maximum", + "minItems", + "maxItems", + ): + sanitized[key] = value + + return sanitized or {"type": "object"} diff --git a/hud/agents/openai_compatible/tools/computer.py b/hud/agents/openai_compatible/tools/computer.py deleted file mode 100644 index d7e450c89..000000000 --- a/hud/agents/openai_compatible/tools/computer.py +++ /dev/null @@ -1,566 +0,0 @@ -"""Agent-side OpenAI-compatible computer tools.""" - -from 
__future__ import annotations - -import logging -import re -from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args - -from mcp.types import ImageContent, TextContent - -from hud.agents.tools import AgentTool, AgentToolSpec, CallTool, call_tool -from hud.tools.computer import computer_settings -from hud.types import MCPToolResult - -from .types import OpenAICompatibleToolParam, QwenComputerUseToolParam - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionToolParam - from openai.types.shared_params.function_parameters import FunctionParameters - - from hud.agents.tools import EnvironmentCapability - -logger = logging.getLogger(__name__) - -GLM_COORDINATE_SPACE = 999 - -GLMAction = Literal[ - "left_click", - "click", - "right_click", - "middle_click", - "hover", - "left_double_click", - "left_drag", - "key", - "type", - "scroll", - "screenshot", - "WAIT", -] - -VALID_GLM_ACTIONS: set[str] = set(get_args(GLMAction)) - -GLM_COMPUTER_SPEC = AgentToolSpec( - api_type="function", - api_name="computer", - supported_models=("glm-*",), -) - -QWEN_COMPUTER_SPEC = AgentToolSpec( - api_type="computer_use", - api_name="computer_use", - supported_models=("qwen*",), -) - -GLM_SYSTEM_INSTRUCTIONS = ( - "You are a GUI Agent. Your task is to respond accurately to user requests by using " - "tools or performing GUI operations until the task is fulfilled. Coordinates are in " - "thousandths (0-999). Complete tasks autonomously without asking for confirmation. " - "If a task cannot be completed, explain the failure in your final response." -) - -GLM_COMPUTER_DESCRIPTION = """\ -Use this tool to interact with the computer via GLM's PC action space. -* Coordinates use a 0-999 normalized scale (thousandths of screen dimensions). -* Always use valid JSON for function arguments. Do NOT use XML tags. 
- Correct: {"action": "left_click", "start_box": "[500, 300]"} - Wrong: {"action": "left_clickstart_box..."} -* Available actions: - - left_click/right_click/middle_click(start_box='[x,y]') - - hover(start_box='[x,y]'), left_double_click(start_box='[x,y]') - - left_drag(start_box='[x,y]', end_box='[x,y]') - - key(keys='ctrl+c'), type(content='text') - - scroll(start_box='[x,y]', direction='up|down', step=5) - - screenshot(), WAIT() -* If a task cannot be completed, explain the failure in your final response.\ -""".strip() - -GLM_COMPUTER_PARAMETERS: FunctionParameters = { - "type": "object", - "properties": { - "action": { - "type": "string", - "description": ( - "REQUIRED. Action to perform: left_click, right_click, middle_click, " - "hover, left_double_click, left_drag, key, type, scroll, screenshot, " - "WAIT" - ), - "enum": sorted(VALID_GLM_ACTIONS), - }, - "start_box": { - "description": ( - "Position as '[x,y]' string or [x,y] array, coordinates 0-999 normalized" - ), - }, - "end_box": { - "description": "End position for drag as '[x,y]' string or [x,y] array", - }, - "content": {"type": "string", "description": "Text content to type"}, - "keys": {"description": "Key(s) to press, e.g. 
'enter', 'ctrl+c', 'alt+tab'"}, - "direction": {"type": "string", "description": "Scroll direction: 'up' or 'down'"}, - "step": {"type": "integer", "description": "Scroll steps", "default": 5}, - "element_info": {"type": "string", "description": "Optional UI element description"}, - }, - "required": ["action"], -} - - -class GLMComputerTool(AgentTool[OpenAICompatibleToolParam]): - """Translate GLM native GUI calls into generic environment computer calls.""" - - name = "computer" - capability = "computer" - ignored_api_types: ClassVar[frozenset[str]] = frozenset({"gui_agent_glm45v"}) - - @classmethod - def default_spec(cls, model: str) -> AgentToolSpec | None: - if GLM_COMPUTER_SPEC.supports_model(model): - return GLM_COMPUTER_SPEC - return None - - def __init__( - self, - *, - env_tool_name: str, - spec: AgentToolSpec, - display_width: int, - display_height: int, - ) -> None: - super().__init__(env_tool_name=env_tool_name, spec=spec) - self.display_width = display_width - self.display_height = display_height - - @classmethod - def from_capability( - cls, - capability: EnvironmentCapability, - spec: AgentToolSpec, - model: str, - ) -> GLMComputerTool: - del model - width, height = _resolution_from_capability( - capability, - default_width=computer_settings.GLM_COMPUTER_WIDTH, - default_height=computer_settings.GLM_COMPUTER_HEIGHT, - ) - return cls( - env_tool_name=capability.tool_name, - spec=spec, - display_width=width, - display_height=height, - ) - - def to_params(self) -> ChatCompletionToolParam: - return { - "type": "function", - "function": { - "name": self.name, - "description": ( - f"{GLM_COMPUTER_DESCRIPTION}\n* The screen's resolution is " - f"{self.display_width}x{self.display_height}." 
- ), - "parameters": GLM_COMPUTER_PARAMETERS, - }, - } - - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - arguments = _fix_glm_xml_args(arguments) - action = arguments.get("action") - if not isinstance(action, str): - return _error_result("'action' is required") - - result = MCPToolResult(content=[], isError=False) - for call in self._env_calls(action, arguments): - result = await call_tool(caller, self.env_tool_name, call) - if result.isError: - return result - - if action not in {"screenshot", "WAIT"} and not _has_image(result): - screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"}) - if not screenshot.isError and screenshot.content: - result = MCPToolResult( - content=[*result.content, *screenshot.content], - isError=result.isError, - ) - return result - - def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: - start = _parse_glm_box(arguments.get("start_box")) - end = _parse_glm_box(arguments.get("end_box")) - - if action == "screenshot": - return [{"action": "screenshot"}] - if action == "WAIT": - return [{"action": "wait", "time": 5000}] - if action in ("left_click", "click", "right_click", "middle_click"): - x, y = self._point(start, f"start_box required for {action}") - button = { - "left_click": "left", - "click": "left", - "right_click": "right", - "middle_click": "middle", - }[action] - return [{"action": "click", "x": x, "y": y, "button": button}] - if action == "hover": - x, y = self._point(start, "start_box required for hover") - return [{"action": "move", "x": x, "y": y}] - if action == "left_double_click": - x, y = self._point(start, "start_box required for left_double_click") - return [{"action": "click", "x": x, "y": y, "button": "left", "pattern": [100]}] - if action == "left_drag": - start_x, start_y = self._point(start, "start_box required for left_drag") - end_x, end_y = self._point(end, "end_box required for left_drag") - return [ - { - 
"action": "drag", - "path": [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}], - } - ] - if action == "key": - keys = _parse_glm_keys(arguments.get("keys")) - if not keys: - raise ValueError("keys required for key action") - return [{"action": "press", "keys": keys}] - if action == "type": - content = arguments.get("content") - if not isinstance(content, str) or not content: - raise ValueError("content required for type") - return [{"action": "write", "text": content, "enter_after": False}] - if action == "scroll": - direction = arguments.get("direction") - if direction not in {"up", "down"}: - raise ValueError("direction must be 'up' or 'down'") - point = start or (GLM_COORDINATE_SPACE // 2, GLM_COORDINATE_SPACE // 2) - x, y = self._scale_normalized_point(point) - step = arguments.get("step") or 5 - scroll_y = int(step) * 100 if direction == "down" else -int(step) * 100 - return [{"action": "scroll", "x": x, "y": y, "scroll_y": scroll_y}] - raise ValueError(f"Unknown action: {action}") - - def _point(self, point: tuple[int, int] | None, message: str) -> tuple[int, int]: - if point is None: - raise ValueError(message) - return self._scale_normalized_point(point) - - def _scale_normalized_point(self, point: tuple[int, int]) -> tuple[int, int]: - x, y = point - scaled_x = round(x / GLM_COORDINATE_SPACE * (self.display_width - 1)) - scaled_y = round(y / GLM_COORDINATE_SPACE * (self.display_height - 1)) - return scaled_x, scaled_y - - -class QwenComputerTool(AgentTool[OpenAICompatibleToolParam]): - """Translate Qwen computer_use calls into generic environment computer calls.""" - - name = "computer_use" - capability = "computer" - - @classmethod - def default_spec(cls, model: str) -> AgentToolSpec | None: - if QWEN_COMPUTER_SPEC.supports_model(model): - return QWEN_COMPUTER_SPEC - return None - - def __init__( - self, - *, - env_tool_name: str, - spec: AgentToolSpec, - display_width: int, - display_height: int, - description: str, - ) -> None: - 
super().__init__(env_tool_name=env_tool_name, spec=spec) - self.display_width = display_width - self.display_height = display_height - self.description = description - - @classmethod - def from_capability( - cls, - capability: EnvironmentCapability, - spec: AgentToolSpec, - model: str, - ) -> QwenComputerTool: - del model - width, height = _resolution_from_capability( - capability, - default_width=computer_settings.QWEN_COMPUTER_WIDTH, - default_height=computer_settings.QWEN_COMPUTER_HEIGHT, - ) - return cls( - env_tool_name=capability.tool_name, - spec=spec, - display_width=width, - display_height=height, - description=_qwen_description(width, height), - ) - - def to_params(self) -> QwenComputerUseToolParam: - tool: QwenComputerUseToolParam = { - "type": "computer_use", - "name": self.name, - "display_width_px": self.display_width, - "display_height_px": self.display_height, - "description": self.description, - "parameters": QWEN_COMPUTER_PARAMETERS, - } - return tool - - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - action = arguments.get("action") - if not isinstance(action, str): - return _error_result("action is required") - if action == "terminate": - return _error_result("terminate action is not supported for computer control.") - if action == "answer": - return _error_result("answer action is not supported for computer control.") - - result = MCPToolResult(content=[], isError=False) - for call in self._env_calls(action, arguments): - result = await call_tool(caller, self.env_tool_name, call) - if result.isError: - return result - - if action not in {"screenshot", "wait"} and not _has_image(result): - screenshot = await call_tool(caller, self.env_tool_name, {"action": "screenshot"}) - if not screenshot.isError and screenshot.content: - result = MCPToolResult( - content=[*result.content, *screenshot.content], - isError=result.isError, - ) - return result - - def _env_calls(self, action: str, arguments: dict[str, 
Any]) -> list[dict[str, Any]]: - coordinate = _parse_qwen_coordinate(arguments.get("coordinate")) - if action == "screenshot": - return [{"action": "screenshot"}] - if action in {"left_click", "right_click", "middle_click"}: - x, y = _required_coordinate(coordinate, action) - button = {"left_click": "left", "right_click": "right", "middle_click": "middle"}[ - action - ] - return [{"action": "click", "x": x, "y": y, "button": button}] - if action == "double_click": - x, y = _required_coordinate(coordinate, action) - return [{"action": "click", "x": x, "y": y, "pattern": [100]}] - if action == "triple_click": - x, y = _required_coordinate(coordinate, action) - return [{"action": "click", "x": x, "y": y, "pattern": [100, 100]}] - if action == "mouse_move": - x, y = _required_coordinate(coordinate, action) - return [{"action": "move", "x": x, "y": y}] - if action == "type": - text = arguments.get("text") - if not isinstance(text, str): - raise ValueError("text is required for type") - return [{"action": "write", "text": text}] - if action == "key": - keys = arguments.get("keys") - if not isinstance(keys, list): - raise ValueError("keys is required for key") - return [{"action": "press", "keys": keys}] - if action in {"scroll", "hscroll"}: - pixels = arguments.get("pixels") - if not isinstance(pixels, int | float): - raise ValueError("pixels is required for scroll") - call: dict[str, Any] = {"action": "scroll"} - if coordinate is not None: - call.update({"x": coordinate[0], "y": coordinate[1]}) - if action == "scroll": - call["scroll_y"] = -int(pixels) - else: - call["scroll_x"] = int(pixels) - return [call] - if action == "left_click_drag": - x, y = _required_coordinate(coordinate, action) - return [ - {"action": "mouse_down", "button": "left"}, - {"action": "move", "x": x, "y": y}, - {"action": "mouse_up", "button": "left"}, - ] - if action == "wait": - time = arguments.get("time") - if not isinstance(time, int | float): - raise ValueError("time is required for wait") 
- if time < 0: - raise ValueError("time must be non-negative") - return [{"action": "wait", "time": int(time * 1000)}] - raise ValueError(f"Invalid action: {action}") - - -QWEN_COMPUTER_PARAMETERS: FunctionParameters = { - "properties": { - "action": { - "description": """ -The action to perform. The available actions are: -* `key`: Performs key down presses on the arguments passed in order, then performs -key releases in reverse order. -* `type`: Type a string of text on the keyboard. -* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen. -* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate. -* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate. -* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate. -* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate. -* `double_click`: Double-click the left mouse button. -* `triple_click`: Triple-click the left mouse button. -* `scroll`: Performs a vertical scroll. -* `hscroll`: Performs a horizontal scroll. -* `wait`: Wait specified seconds for the change to happen. -""".strip(), - "enum": [ - "key", - "type", - "mouse_move", - "left_click", - "left_click_drag", - "right_click", - "middle_click", - "double_click", - "triple_click", - "scroll", - "hscroll", - "wait", - ], - "type": "string", - }, - "keys": {"description": "Required only by `action=key`.", "type": "array"}, - "text": { - "description": "Required only by `action=type`.", - "type": "string", - }, - "coordinate": { - "description": "(x, y) pixel coordinate to interact with.", - "type": "array", - }, - "pixels": { - "description": "Scroll amount. Positive vertical values scroll up.", - "type": "number", - }, - "time": { - "description": "Seconds to wait. 
Required only by `action=wait`.", - "type": "number", - }, - }, - "required": ["action"], - "type": "object", -} - - -def _resolution_from_capability( - capability: EnvironmentCapability, - *, - default_width: int, - default_height: int, -) -> tuple[int, int]: - metadata_resolution = capability.metadata.get("resolution", {}) - if not isinstance(metadata_resolution, dict): - metadata_resolution = {} - tool_resolution = (capability.tool.meta or {}).get("resolution", {}) - if not isinstance(tool_resolution, dict): - tool_resolution = {} - width = int(metadata_resolution.get("width") or tool_resolution.get("width") or default_width) - height = int( - metadata_resolution.get("height") or tool_resolution.get("height") or default_height - ) - return width, height - - -def _qwen_description(width: int, height: int) -> str: - return f""" -Use a mouse and keyboard to interact with a computer, and take screenshots. -* This is an interface to a desktop GUI. You do not have access to a terminal or -applications menu. You must click on desktop icons to start applications. -* Some applications may take time to start or process actions, so you may need to -wait and take successive screenshots to see the results of your actions. -* The screen's resolution is {width}x{height}. -* Whenever you intend to move the cursor to click on an element like an icon, you -should consult a screenshot to determine the coordinates of the element before -moving the cursor. -* Make sure to click buttons, links, and icons with the cursor tip in the center. 
-""".strip() - - -def _parse_glm_box(box: Any) -> tuple[int, int] | None: - if box is None: - return None - if isinstance(box, str): - match = re.match(r"\[?\s*(\d+)\s*,\s*(\d+)\s*\]?", box.strip()) - if match: - return int(match.group(1)), int(match.group(2)) - return None - if isinstance(box, list): - if len(box) == 1 and isinstance(box[0], list): - box = box[0] - if len(box) >= 2: - try: - return int(box[0]), int(box[1]) - except (TypeError, ValueError): - return None - return None - - -def _parse_glm_keys(keys: Any) -> list[str]: - if not keys: - return [] - if isinstance(keys, list): - return [str(key).strip().lower() for key in keys] - return [key.strip().lower() for key in str(keys).split("+") if key.strip()] - - -def _fix_glm_xml_args(args: dict[str, Any]) -> dict[str, Any]: - fixed: dict[str, Any] = {} - for key, value in args.items(): - if not isinstance(value, str) or not re.search(r"(\w+)\s*([^\"<]+)", value) - for arg_name, arg_val in matches: - if arg_name and arg_val: - fixed[arg_name.strip()] = arg_val.strip() - - if not main_value and not matches: - fixed[key] = value - logger.warning("Fixed GLM XML args: %s -> %s", args, fixed) - return fixed - - -def _parse_qwen_coordinate(coordinate: Any) -> tuple[int, int] | None: - if isinstance(coordinate, list | tuple) and len(coordinate) >= 2: - try: - return int(coordinate[0]), int(coordinate[1]) - except (TypeError, ValueError): - return None - return None - - -def _required_coordinate(coordinate: tuple[int, int] | None, action: str) -> tuple[int, int]: - if coordinate is None: - raise ValueError(f"coordinate is required for {action}") - return coordinate - - -def _has_image(result: MCPToolResult) -> bool: - return any(isinstance(block, ImageContent) for block in result.content) - - -def _error_result(message: str) -> MCPToolResult: - return MCPToolResult(content=[TextContent(type="text", text=message)], isError=True) - - -__all__ = [ - "GLM_COMPUTER_SPEC", - "GLM_COORDINATE_SPACE", - 
"QWEN_COMPUTER_SPEC", - "VALID_GLM_ACTIONS", - "GLMComputerTool", - "QwenComputerTool", - "_fix_glm_xml_args", - "_parse_glm_box", -] diff --git a/hud/agents/openai_compatible/tools/filesystem.py b/hud/agents/openai_compatible/tools/filesystem.py index 4f5ba57f2..a09ed988c 100644 --- a/hud/agents/openai_compatible/tools/filesystem.py +++ b/hud/agents/openai_compatible/tools/filesystem.py @@ -4,84 +4,16 @@ from typing import TYPE_CHECKING, ClassVar -from hud.agents.tools import AgentTool, AgentToolSpec, GroupedCapabilityMixin +from hud.agents.tools import AgentToolSpec, GroupedCapabilityMixin -from .types import OpenAICompatibleToolParam +from .base import OpenAICompatibleTool if TYPE_CHECKING: from openai.types.chat import ChatCompletionToolParam from openai.types.shared_params.function_parameters import FunctionParameters -READ_PARAMETERS: FunctionParameters = { - "type": "object", - "properties": { - "filePath": { - "type": "string", - "description": "Absolute path to the file to read.", - }, - "offset": { - "type": "integer", - "description": "0-based line offset to start reading from.", - }, - "limit": { - "type": "integer", - "description": "Maximum number of lines to read.", - }, - }, - "required": ["filePath"], -} - -GREP_PARAMETERS: FunctionParameters = { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Regular expression pattern to search for.", - }, - "path": { - "type": "string", - "description": "Directory to search in.", - }, - "include": { - "type": "string", - "description": "Glob pattern for files to include.", - }, - }, - "required": ["pattern"], -} - -GLOB_PARAMETERS: FunctionParameters = { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Glob pattern to match.", - }, - "path": { - "type": "string", - "description": "Directory to search from.", - }, - }, - "required": ["pattern"], -} - -LIST_PARAMETERS: FunctionParameters = { - "type": "object", - "properties": { - 
"path": { - "type": "string", - "description": "Directory to list.", - }, - "ignore": { - "type": "array", - "items": {"type": "string"}, - "description": "Glob patterns to ignore.", - }, - }, -} - -class FilesystemTool(GroupedCapabilityMixin, AgentTool[OpenAICompatibleToolParam]): +class _FilesystemTool(GroupedCapabilityMixin, OpenAICompatibleTool): """Function tool backed by a HUD filesystem environment tool.""" description: ClassVar[str] @@ -104,54 +36,101 @@ def to_params(self) -> ChatCompletionToolParam: } -class ReadTool(FilesystemTool): +class ReadTool(_FilesystemTool): """Expose a read function over the environment read tool.""" name = "read" capability = "filesystem" env_tool_names = ("read",) description = "Reads a file from the local filesystem. Use offset and limit for pagination." - parameters: ClassVar[FunctionParameters] = READ_PARAMETERS + parameters: ClassVar[FunctionParameters] = { + "type": "object", + "properties": { + "filePath": { + "type": "string", + "description": "Absolute path to the file to read.", + }, + "offset": { + "type": "integer", + "description": "0-based line offset to start reading from.", + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to read.", + }, + }, + "required": ["filePath"], + } -class GrepTool(FilesystemTool): +class GrepTool(_FilesystemTool): """Expose a grep function over the environment grep tool.""" name = "grep" capability = "filesystem" env_tool_names = ("grep",) description = "Searches file contents using a regular expression and returns matching lines." 
- parameters: ClassVar[FunctionParameters] = GREP_PARAMETERS + parameters: ClassVar[FunctionParameters] = { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Regular expression pattern to search for.", + }, + "path": { + "type": "string", + "description": "Directory to search in.", + }, + "include": { + "type": "string", + "description": "Glob pattern for files to include.", + }, + }, + "required": ["pattern"], + } -class GlobTool(FilesystemTool): +class GlobTool(_FilesystemTool): """Expose a glob function over the environment glob tool.""" name = "glob" capability = "filesystem" env_tool_names = ("glob",) description = "Finds files matching a glob pattern." - parameters: ClassVar[FunctionParameters] = GLOB_PARAMETERS + parameters: ClassVar[FunctionParameters] = { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Glob pattern to match.", + }, + "path": { + "type": "string", + "description": "Directory to search from.", + }, + }, + "required": ["pattern"], + } -class ListTool(FilesystemTool): +class ListTool(_FilesystemTool): """Expose a list function over the environment list tool.""" name = "list" capability = "filesystem" env_tool_names = ("list",) description = "Lists files and directories in a given path." 
- parameters: ClassVar[FunctionParameters] = LIST_PARAMETERS - - -__all__ = [ - "GLOB_PARAMETERS", - "GREP_PARAMETERS", - "LIST_PARAMETERS", - "READ_PARAMETERS", - "FilesystemTool", - "GlobTool", - "GrepTool", - "ListTool", - "ReadTool", -] + parameters: ClassVar[FunctionParameters] = { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory to list.", + }, + "ignore": { + "type": "array", + "items": {"type": "string"}, + "description": "Glob patterns to ignore.", + }, + }, + } diff --git a/hud/agents/openai_compatible/tools/glm_computer.py b/hud/agents/openai_compatible/tools/glm_computer.py new file mode 100644 index 000000000..463860a19 --- /dev/null +++ b/hud/agents/openai_compatible/tools/glm_computer.py @@ -0,0 +1,294 @@ +"""Agent-side GLM computer tool for OpenAI-compatible chat models.""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING, Any, Literal, cast, get_args + +from hud.agents.tools import AgentToolSpec +from hud.agents.tools.computer import ( + computer_error_result, + computer_tool_info, + execute_computer_calls, +) + +from .base import OpenAICompatibleTool +from .settings import openai_compatible_tool_settings + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionToolParam + from openai.types.shared_params.function_parameters import FunctionParameters + + from hud.agents.tools import EnvironmentCapability + from hud.agents.tools.base import CallTool + from hud.types import MCPToolResult + +logger = logging.getLogger(__name__) + +GLM_COORDINATE_SPACE = 999 + +GLMAction = Literal[ + "left_click", + "click", + "right_click", + "middle_click", + "hover", + "left_double_click", + "left_drag", + "key", + "type", + "scroll", + "screenshot", + "WAIT", +] + +VALID_GLM_ACTIONS: set[str] = set(get_args(GLMAction)) + +GLM_COMPUTER_SPEC = AgentToolSpec( + api_type="function", + api_name="computer", + supported_models=("glm-*",), +) + 
+GLM_SYSTEM_INSTRUCTIONS = ( + "You are a GUI Agent. Your task is to respond accurately to user requests by using " + "tools or performing GUI operations until the task is fulfilled. Coordinates are in " + "thousandths (0-999). Complete tasks autonomously without asking for confirmation. " + "If a task cannot be completed, explain the failure in your final response." +) + +GLM_COMPUTER_DESCRIPTION = """\ +Use this tool to interact with the computer via GLM's PC action space. +* Coordinates use a 0-999 normalized scale (thousandths of screen dimensions). +* Always use valid JSON for function arguments. Do NOT use XML tags. + Correct: {"action": "left_click", "start_box": "[500, 300]"} + Wrong: {"action": "left_clickstart_box..."} +* Available actions: + - left_click/right_click/middle_click(start_box='[x,y]') + - hover(start_box='[x,y]'), left_double_click(start_box='[x,y]') + - left_drag(start_box='[x,y]', end_box='[x,y]') + - key(keys='ctrl+c'), type(content='text') + - scroll(start_box='[x,y]', direction='up|down', step=5) + - screenshot(), WAIT() +* If a task cannot be completed, explain the failure in your final response.\ +""".strip() + +GLM_COMPUTER_PARAMETERS: FunctionParameters = { + "type": "object", + "properties": { + "action": { + "type": "string", + "description": ( + "REQUIRED. Action to perform: left_click, right_click, middle_click, " + "hover, left_double_click, left_drag, key, type, scroll, screenshot, " + "WAIT" + ), + "enum": sorted(VALID_GLM_ACTIONS), + }, + "start_box": { + "description": ( + "Position as '[x,y]' string or [x,y] array, coordinates 0-999 normalized" + ), + }, + "end_box": { + "description": "End position for drag as '[x,y]' string or [x,y] array", + }, + "content": {"type": "string", "description": "Text content to type"}, + "keys": {"description": "Key(s) to press, e.g. 
'enter', 'ctrl+c', 'alt+tab'"}, + "direction": {"type": "string", "description": "Scroll direction: 'up' or 'down'"}, + "step": {"type": "integer", "description": "Scroll steps", "default": 5}, + "element_info": {"type": "string", "description": "Optional UI element description"}, + }, + "required": ["action"], +} + + +class GLMComputerTool(OpenAICompatibleTool): + """Translate GLM native GUI calls into generic environment computer calls.""" + + name = "computer" + capability = "computer" + + @classmethod + def default_spec(cls, model: str) -> AgentToolSpec | None: + if GLM_COMPUTER_SPEC.supports_model(model): + return GLM_COMPUTER_SPEC + return None + + def __init__( + self, + *, + env_tool_name: str, + spec: AgentToolSpec, + display_width: int, + display_height: int, + coordinate_space: int | None, + ) -> None: + super().__init__(env_tool_name=env_tool_name, spec=spec) + self.display_width = display_width + self.display_height = display_height + self.coordinate_space = coordinate_space + + @classmethod + def from_capability( + cls, + capability: EnvironmentCapability, + model: str, + ) -> GLMComputerTool | None: + spec = cls.default_spec(model) + if spec is None: + return None + + computer_info = computer_tool_info( + capability.tool, + default_width=openai_compatible_tool_settings.GLM_COMPUTER_WIDTH, + default_height=openai_compatible_tool_settings.GLM_COMPUTER_HEIGHT, + ) + return cls( + env_tool_name=capability.tool_name, + spec=spec, + display_width=computer_info.display_width, + display_height=computer_info.display_height, + coordinate_space=computer_info.coordinate_space, + ) + + def to_params(self) -> ChatCompletionToolParam: + return { + "type": "function", + "function": { + "name": self.name, + "description": ( + f"{GLM_COMPUTER_DESCRIPTION}\n* The screen's resolution is " + f"{self.display_width}x{self.display_height}." 
+ ), + "parameters": GLM_COMPUTER_PARAMETERS, + }, + } + + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + arguments = _normalize_glm_args(arguments) + action = arguments.get("action") + if not isinstance(action, str): + return computer_error_result("'action' is required") + + return await execute_computer_calls( + call_tool, + env_tool_name=self.env_tool_name, + calls=self._env_calls(action, arguments), + ensure_screenshot=action not in {"screenshot", "WAIT"}, + ) + + def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: + start = _parse_glm_box(arguments.get("start_box")) + end = _parse_glm_box(arguments.get("end_box")) + + if action == "screenshot": + return [{"action": "screenshot"}] + if action == "WAIT": + return [{"action": "wait", "time": 5000}] + if action in ("left_click", "click", "right_click", "middle_click"): + x, y = self._point(start, f"start_box required for {action}") + button = { + "left_click": "left", + "click": "left", + "right_click": "right", + "middle_click": "middle", + }[action] + return [{"action": "click", "x": x, "y": y, "button": button}] + if action == "hover": + x, y = self._point(start, "start_box required for hover") + return [{"action": "move", "x": x, "y": y}] + if action == "left_double_click": + x, y = self._point(start, "start_box required for left_double_click") + return [{"action": "click", "x": x, "y": y, "button": "left", "pattern": [100]}] + if action == "left_drag": + start_x, start_y = self._point(start, "start_box required for left_drag") + end_x, end_y = self._point(end, "end_box required for left_drag") + return [ + { + "action": "drag", + "path": [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}], + } + ] + if action == "key": + raw_keys = arguments.get("keys") + if isinstance(raw_keys, list): + keys = [str(key).strip().lower() for key in cast("list[Any]", raw_keys)] + else: + keys = [ + key.strip().lower() for key in str(raw_keys or 
"").split("+") if key.strip() + ] + if not keys: + raise ValueError("keys required for key action") + return [{"action": "press", "keys": keys}] + if action == "type": + content = arguments.get("content") + if not isinstance(content, str) or not content: + raise ValueError("content required for type") + return [{"action": "write", "text": content, "enter_after": False}] + if action == "scroll": + direction = arguments.get("direction") + if direction not in {"up", "down"}: + raise ValueError("direction must be 'up' or 'down'") + point = start or (GLM_COORDINATE_SPACE // 2, GLM_COORDINATE_SPACE // 2) + x, y = self._scale_normalized_point(point) + step = arguments.get("step") or 5 + scroll_y = int(step) * 100 if direction == "down" else -int(step) * 100 + return [{"action": "scroll", "x": x, "y": y, "scroll_y": scroll_y}] + raise ValueError(f"Unknown action: {action}") + + def _point(self, point: tuple[int, int] | None, message: str) -> tuple[int, int]: + if point is None: + raise ValueError(message) + return self._scale_normalized_point(point) + + def _scale_normalized_point(self, point: tuple[int, int]) -> tuple[int, int]: + if self.coordinate_space == GLM_COORDINATE_SPACE: + return point + x, y = point + scaled_x = round(x / GLM_COORDINATE_SPACE * (self.display_width - 1)) + scaled_y = round(y / GLM_COORDINATE_SPACE * (self.display_height - 1)) + return scaled_x, scaled_y + + +def _parse_glm_box(box: Any) -> tuple[int, int] | None: + if box is None: + return None + if isinstance(box, str): + match = re.match(r"\[?\s*(\d+)\s*,\s*(\d+)\s*\]?", box.strip()) + if match: + return int(match.group(1)), int(match.group(2)) + return None + if isinstance(box, list): + nested = cast("list[Any]", box) + if len(nested) == 1 and isinstance(nested[0], list): + nested = cast("list[Any]", nested[0]) + if len(nested) >= 2: + try: + return int(nested[0]), int(nested[1]) + except (TypeError, ValueError): + return None + return None + + +def _normalize_glm_args(args: dict[str, Any]) -> 
dict[str, Any]: + fixed: dict[str, Any] = {} + for key, value in args.items(): + if not isinstance(value, str) or not re.search(r"(\w+)\s*([^\"<]+)", value) + for arg_name, arg_val in matches: + if arg_name and arg_val: + fixed[arg_name.strip()] = arg_val.strip() + + if not main_value and not matches: + fixed[key] = value + logger.warning("Fixed GLM XML args: %s -> %s", args, fixed) + return fixed diff --git a/hud/agents/openai_compatible/tools/qwen_computer.py b/hud/agents/openai_compatible/tools/qwen_computer.py new file mode 100644 index 000000000..425e5f844 --- /dev/null +++ b/hud/agents/openai_compatible/tools/qwen_computer.py @@ -0,0 +1,266 @@ +"""Agent-side Qwen computer tool for OpenAI-compatible chat models.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast + +from hud.agents.tools import AgentToolSpec +from hud.agents.tools.computer import ( + computer_error_result, + computer_tool_info, + execute_computer_calls, +) + +from .base import OpenAICompatibleTool +from .settings import openai_compatible_tool_settings + +if TYPE_CHECKING: + from openai.types.shared_params.function_parameters import FunctionParameters + + from hud.agents.tools import EnvironmentCapability + from hud.agents.tools.base import CallTool + from hud.types import MCPToolResult + +QWEN_COMPUTER_SPEC = AgentToolSpec( + api_type="computer_use", + api_name="computer_use", + supported_models=("qwen*",), +) + + +class QwenComputerUseToolParam(TypedDict): + """Qwen's OpenAI-compatible computer_use extension.""" + + type: Literal["computer_use"] + name: str + display_width_px: int + display_height_px: int + description: str + parameters: FunctionParameters + + +class QwenComputerTool(OpenAICompatibleTool): + """Translate Qwen computer_use calls into generic environment computer calls.""" + + name = "computer_use" + capability = "computer" + + @classmethod + def default_spec(cls, model: str) -> AgentToolSpec | None: + if 
QWEN_COMPUTER_SPEC.supports_model(model): + return QWEN_COMPUTER_SPEC + return None + + def __init__( + self, + *, + env_tool_name: str, + spec: AgentToolSpec, + display_width: int, + display_height: int, + description: str, + ) -> None: + super().__init__(env_tool_name=env_tool_name, spec=spec) + self.display_width = display_width + self.display_height = display_height + self.description = description + + @classmethod + def from_capability( + cls, + capability: EnvironmentCapability, + model: str, + ) -> QwenComputerTool | None: + spec = cls.default_spec(model) + if spec is None: + return None + + computer_info = computer_tool_info( + capability.tool, + default_width=openai_compatible_tool_settings.QWEN_COMPUTER_WIDTH, + default_height=openai_compatible_tool_settings.QWEN_COMPUTER_HEIGHT, + ) + return cls( + env_tool_name=capability.tool_name, + spec=spec, + display_width=computer_info.display_width, + display_height=computer_info.display_height, + description=_qwen_description( + computer_info.display_width, computer_info.display_height + ), + ) + + def to_params(self) -> QwenComputerUseToolParam: + tool: QwenComputerUseToolParam = { + "type": "computer_use", + "name": self.name, + "display_width_px": self.display_width, + "display_height_px": self.display_height, + "description": self.description, + "parameters": QWEN_COMPUTER_PARAMETERS, + } + return tool + + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + action = arguments.get("action") + if not isinstance(action, str): + return computer_error_result("action is required") + if action == "terminate": + return computer_error_result("terminate action is not supported for computer control.") + if action == "answer": + return computer_error_result("answer action is not supported for computer control.") + + return await execute_computer_calls( + call_tool, + env_tool_name=self.env_tool_name, + calls=self._env_calls(action, arguments), + ensure_screenshot=action not in 
{"screenshot", "wait"}, + ) + + def _env_calls(self, action: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: + coordinate = _parse_qwen_coordinate(arguments.get("coordinate")) + if action == "screenshot": + return [{"action": "screenshot"}] + if action in {"left_click", "right_click", "middle_click"}: + x, y = _required_coordinate(coordinate, action) + button = {"left_click": "left", "right_click": "right", "middle_click": "middle"}[ + action + ] + return [{"action": "click", "x": x, "y": y, "button": button}] + if action == "double_click": + x, y = _required_coordinate(coordinate, action) + return [{"action": "click", "x": x, "y": y, "pattern": [100]}] + if action == "triple_click": + x, y = _required_coordinate(coordinate, action) + return [{"action": "click", "x": x, "y": y, "pattern": [100, 100]}] + if action == "mouse_move": + x, y = _required_coordinate(coordinate, action) + return [{"action": "move", "x": x, "y": y}] + if action == "type": + text = arguments.get("text") + if not isinstance(text, str): + raise ValueError("text is required for type") + return [{"action": "write", "text": text}] + if action == "key": + keys = arguments.get("keys") + if not isinstance(keys, list): + raise ValueError("keys is required for key") + return [{"action": "press", "keys": keys}] + if action in {"scroll", "hscroll"}: + pixels = arguments.get("pixels") + if not isinstance(pixels, int | float): + raise ValueError("pixels is required for scroll") + call: dict[str, Any] = {"action": "scroll"} + if coordinate is not None: + call.update({"x": coordinate[0], "y": coordinate[1]}) + if action == "scroll": + call["scroll_y"] = -int(pixels) + else: + call["scroll_x"] = int(pixels) + return [call] + if action == "left_click_drag": + x, y = _required_coordinate(coordinate, action) + return [ + {"action": "mouse_down", "button": "left"}, + {"action": "move", "x": x, "y": y}, + {"action": "mouse_up", "button": "left"}, + ] + if action == "wait": + time = arguments.get("time") 
+ if not isinstance(time, int | float): + raise ValueError("time is required for wait") + if time < 0: + raise ValueError("time must be non-negative") + return [{"action": "wait", "time": int(time * 1000)}] + raise ValueError(f"Invalid action: {action}") + + +QWEN_COMPUTER_PARAMETERS: FunctionParameters = { + "properties": { + "action": { + "description": """ +The action to perform. The available actions are: +* `key`: Performs key down presses on the arguments passed in order, then performs +key releases in reverse order. +* `type`: Type a string of text on the keyboard. +* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen. +* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate. +* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate. +* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate. +* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate. +* `double_click`: Double-click the left mouse button. +* `triple_click`: Triple-click the left mouse button. +* `scroll`: Performs a vertical scroll. +* `hscroll`: Performs a horizontal scroll. +* `wait`: Wait specified seconds for the change to happen. +""".strip(), + "enum": [ + "key", + "type", + "mouse_move", + "left_click", + "left_click_drag", + "right_click", + "middle_click", + "double_click", + "triple_click", + "scroll", + "hscroll", + "wait", + ], + "type": "string", + }, + "keys": {"description": "Required only by `action=key`.", "type": "array"}, + "text": { + "description": "Required only by `action=type`.", + "type": "string", + }, + "coordinate": { + "description": "(x, y) pixel coordinate to interact with.", + "type": "array", + }, + "pixels": { + "description": "Scroll amount. Positive vertical values scroll up.", + "type": "number", + }, + "time": { + "description": "Seconds to wait. 
Required only by `action=wait`.", + "type": "number", + }, + }, + "required": ["action"], + "type": "object", +} + + +def _qwen_description(width: int, height: int) -> str: + return f""" +Use a mouse and keyboard to interact with a computer, and take screenshots. +* This is an interface to a desktop GUI. You do not have access to a terminal or +applications menu. You must click on desktop icons to start applications. +* Some applications may take time to start or process actions, so you may need to +wait and take successive screenshots to see the results of your actions. +* The screen's resolution is {width}x{height}. +* Whenever you intend to move the cursor to click on an element like an icon, you +should consult a screenshot to determine the coordinates of the element before +moving the cursor. +* Make sure to click buttons, links, and icons with the cursor tip in the center. +""".strip() + + +def _parse_qwen_coordinate(coordinate: Any) -> tuple[int, int] | None: + if not isinstance(coordinate, list | tuple): + return None + coord = cast("list[Any] | tuple[Any, ...]", coordinate) + if len(coord) < 2: + return None + try: + return int(coord[0]), int(coord[1]) + except (TypeError, ValueError): + return None + + +def _required_coordinate(coordinate: tuple[int, int] | None, action: str) -> tuple[int, int]: + if coordinate is None: + raise ValueError(f"coordinate is required for {action}") + return coordinate diff --git a/hud/agents/openai_compatible/tools/settings.py b/hud/agents/openai_compatible/tools/settings.py new file mode 100644 index 000000000..8ec3dbe71 --- /dev/null +++ b/hud/agents/openai_compatible/tools/settings.py @@ -0,0 +1,36 @@ +"""OpenAI-compatible native tool settings owned by the agent.""" + +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class OpenAICompatibleToolSettings(BaseSettings): + """Provider defaults for OpenAI-compatible agent-owned native tools.""" 
+ + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="allow") + + GLM_COMPUTER_WIDTH: int = Field( + default=1024, + description="Default GLM computer-use display width", + validation_alias="GLM_COMPUTER_WIDTH", + ) + GLM_COMPUTER_HEIGHT: int = Field( + default=768, + description="Default GLM computer-use display height", + validation_alias="GLM_COMPUTER_HEIGHT", + ) + QWEN_COMPUTER_WIDTH: int = Field( + default=700, + description="Default Qwen computer-use display width", + validation_alias="QWEN_COMPUTER_WIDTH", + ) + QWEN_COMPUTER_HEIGHT: int = Field( + default=448, + description="Default Qwen computer-use display height", + validation_alias="QWEN_COMPUTER_HEIGHT", + ) + + +openai_compatible_tool_settings = OpenAICompatibleToolSettings() diff --git a/hud/agents/openai_compatible/tools/types.py b/hud/agents/openai_compatible/tools/types.py deleted file mode 100644 index 2bded858a..000000000 --- a/hud/agents/openai_compatible/tools/types.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Type definitions for OpenAI-compatible chat tools.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, TypeAlias, TypedDict - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionToolParam - from openai.types.shared_params.function_parameters import FunctionParameters - - -class QwenComputerUseToolParam(TypedDict): - """Qwen's OpenAI-compatible computer_use extension.""" - - type: Literal["computer_use"] - name: str - display_width_px: int - display_height_px: int - description: str - parameters: FunctionParameters - - -OpenAICompatibleToolParam: TypeAlias = "ChatCompletionToolParam | QwenComputerUseToolParam" - - -__all__ = ["OpenAICompatibleToolParam", "QwenComputerUseToolParam"] diff --git a/hud/agents/resolver.py b/hud/agents/resolver.py deleted file mode 100644 index ae9bd8b89..000000000 --- a/hud/agents/resolver.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Model resolution - maps model strings to agent classes.""" 
- -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from hud.agents.base import MCPAgent - -__all__ = ["resolve_cls"] - -_models_cache: list[dict[str, Any]] | None = None - - -def _fetch_gateway_models() -> list[dict[str, Any]]: - """Fetch available models from HUD API (cached).""" - global _models_cache - if _models_cache is not None: - return _models_cache - - import httpx - - from hud.settings import settings - - if not settings.api_key: - return [] - - try: - resp = httpx.get( - f"{settings.hud_api_url}/models/", - headers={"Authorization": f"Bearer {settings.api_key}"}, - timeout=10.0, - ) - resp.raise_for_status() - data = resp.json() - models = data.get("models") or [] - _models_cache = models - return models - except Exception: - return [] - - -def resolve_cls(model: str) -> tuple[type[MCPAgent], dict[str, Any] | None]: - """Resolve model string to (agent_class, gateway_info). - - Returns: - (agent_class, None) for known AgentTypes - (agent_class, gateway_model_info) for gateway models - """ - from hud.types import AgentType - - # Known AgentType → no gateway info - try: - return AgentType(model).cls, None - except ValueError: - pass - - # Gateway lookup - for m in _fetch_gateway_models(): - if model in (m.get("id"), m.get("name"), m.get("model_name")): - agent_str = m.get("sdk_agent_type") or m["provider"]["default_sdk_agent_type"] - if agent_str == "operator": - raise ValueError( - "Operator agent is no longer supported; use openai with a supported " - "OpenAI computer model." - ) - if agent_str == "gemini_cua": - raise ValueError( - "Gemini CUA agent is no longer supported; use gemini with a supported " - "Gemini computer-use model." 
- ) - return AgentType(agent_str).cls, m - - raise ValueError(f"Model '{model}' not found") diff --git a/hud/agents/tests/conftest.py b/hud/agents/tests/conftest.py index eb4880f4b..2bfd37b0b 100644 --- a/hud/agents/tests/conftest.py +++ b/hud/agents/tests/conftest.py @@ -1,42 +1,218 @@ -"""Shared test fixtures for agent tests.""" +# pyright: reportPrivateUsage=false +"""Shared behavioral harness for agent tests.""" from __future__ import annotations -from typing import Any +from functools import cached_property +from typing import TYPE_CHECKING, Any, ClassVar, cast import pytest from mcp import types +from hud.agents.base import MCPAgent +from hud.agents.tools import ( + AgentTool, + AgentTools, + AgentToolSpec, + GroupedCapabilityMixin, + ToolMetadata, +) +from hud.agents.tools.base import ToolClient +from hud.agents.types import AgentConfig from hud.environment.router import ToolRouter +from hud.environment.scenarios import ScenarioSession from hud.eval.context import EvalContext -from hud.types import MCPToolCall, MCPToolResult +from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace +if TYPE_CHECKING: + from collections.abc import Callable, Mapping -class MockEvalContext(EvalContext): - """Mock EvalContext for testing agents. - This provides a minimal EvalContext implementation that can be used - to test agent initialization and tool calling without a real environment. 
- """ +class HarnessConfig(AgentConfig): + model_name: str = "HarnessAgent" + model: str = "harness-model" + + +def mcp_tool(name: str, *, description: str | None = None) -> types.Tool: + return types.Tool( + name=name, + description=description or f"{name} tool", + inputSchema={"type": "object", "properties": {}}, + ) + + +def text_prompt(text: str, *, role: types.Role = "user") -> types.PromptMessage: + return types.PromptMessage( + role=role, + content=types.TextContent(type="text", text=text), + ) + + +def text_result(text: str, *, is_error: bool = False) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text=text)], + isError=is_error, + ) + + +def result_text(result: MCPToolResult) -> str: + return "\n".join(block.text for block in result.content if isinstance(block, types.TextContent)) + + +class HarnessTool(AgentTool[dict[str, Any]]): + name = "function" + capability = "function" + + @classmethod + def from_tool(cls, tool: types.Tool) -> HarnessTool: + return cls( + env_tool_name=tool.name, + spec=AgentToolSpec(api_type="function", api_name=tool.name), + ) + + @property + def provider_name(self) -> str: + return self.env_tool_name + + def to_params(self) -> dict[str, Any]: + return {"name": self.provider_name} + + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> dict[str, Any]: + return { + "role": "tool", + "name": call.name, + "content": result_text(result), + "is_error": result.isError, + } + + +class HarnessTools(AgentTools[HarnessTool, dict[str, Any]]): + function_tool_class = HarnessTool + + +class HarnessNativeShellTool(HarnessTool): + name = "shell" + capability = "shell" + + @property + def provider_name(self) -> str: + return self.name + + @classmethod + def default_spec(cls, model: str) -> AgentToolSpec: + del model + return AgentToolSpec(api_type="shell", api_name="shell") + + +class HarnessFilesystemReadTool(GroupedCapabilityMixin, HarnessTool): + name = "read_file" + capability = 
"filesystem" + env_tool_names: ClassVar[tuple[str, ...]] = ("read", "read_file") + + @property + def provider_name(self) -> str: + return self.name + + @classmethod + def default_spec(cls, model: str) -> AgentToolSpec: + del model + return AgentToolSpec(api_type="function", api_name="read_file") + + +class RoutingHarnessTools(AgentTools[HarnessTool, dict[str, Any]]): + native_tool_classes = (HarnessNativeShellTool, HarnessFilesystemReadTool) + function_tool_class = HarnessTool + name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {"shell": ("bash",)} + + +class ScriptedAgent(MCPAgent[dict[str, Any]]): + """Agent fake that exercises the real `MCPAgent.run` loop.""" + + def __init__( + self, + responses: list[AgentResponse | BaseException], + *, + config: HarnessConfig | None = None, + tools_factory: Callable[[], AgentTools[Any, Any]] | None = None, + ) -> None: + super().__init__(config or HarnessConfig()) + self.config: HarnessConfig + self.responses = list(responses) + self.seen_messages: list[list[dict[str, Any]]] = [] + self._tools_factory = tools_factory or HarnessTools + + @cached_property + def tools(self) -> AgentTools[Any, Any]: + return self._tools_factory() + + async def format_messages(self, messages: list[types.PromptMessage]) -> list[dict[str, Any]]: + formatted: list[dict[str, Any]] = [] + for message in messages: + content = message.content + formatted.append( + { + "role": message.role, + "content": content.text if isinstance(content, types.TextContent) else "", + } + ) + return formatted + + async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: + self.seen_messages.append([dict(message) for message in messages]) + response = self.responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + +class RecordingToolEnvironment: + """Records the environment-facing MCP calls made by an agent run.""" + + def __init__( + self, + tools: list[types.Tool] | None = None, + *, + results: 
Mapping[str, MCPToolResult | Exception] | None = None, + tool_metadata: ToolMetadata | None = None, + ) -> None: + self.tools = tools or [] + self.results = dict(results or {}) + self.tool_metadata = tool_metadata + self.calls: list[MCPToolCall] = [] + + @property + def client(self) -> ToolClient: + return ToolClient( + tools=self.tools, + tool_handler=self.call_tool, + tool_metadata=self.tool_metadata, + ) + + async def call_tool(self, call: MCPToolCall) -> MCPToolResult: + self.calls.append(call) + result = self.results.get(call.name, text_result(f"result from {call.name}")) + if isinstance(result, Exception): + raise result + return result + + +class HarnessEvalContext(EvalContext): + """Small EvalContext double that keeps the real `_run` and prompt behavior.""" def __init__( self, prompt: str = "Test prompt", + *, tools: list[types.Tool] | None = None, - call_tool_handler: Any = None, + tool_results: Mapping[str, MCPToolResult | Exception] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: - # Core attributes self.prompt = prompt - self._tools = tools or [] + self.environment = RecordingToolEnvironment(tools or [], results=tool_results) self._submitted: str | dict[str, Any] | None = None self.reward: float | None = None - self._call_tool_handler = call_tool_handler - self.tool_calls: list[tuple[str, dict[str, Any]]] = [] - - # Environment attributes self._router = ToolRouter() - - # EvalContext attributes + self._scenario_sessions = {} self._task = None self.trace_id = "test-trace-id" self.eval_name = "test-eval" @@ -47,85 +223,61 @@ def __init__( self.answer: str | dict[str, Any] | None = None self.system_prompt: str | None = None self.error: BaseException | None = None - self.metadata: dict[str, Any] = {} + self.metadata = metadata or {} self.results: list[Any] = [] self._is_summary = False + self._eval_api_key: str | None = None + self._trace_enabled = False def as_tools(self) -> list[types.Tool]: - return self._tools + return 
self.environment.tools @property - def has_scenario(self) -> bool: - return False + def submitted(self) -> str | dict[str, Any] | None: + return self._submitted - async def list_tools(self) -> list[types.Tool]: - return self._tools + def set_scenario_messages(self, messages: list[types.PromptMessage]) -> None: + self._scenario_sessions["__client__"] = ScenarioSession( + local_name="chat", + full_name="test-env:chat", + is_local=True, + connection_name=None, + resource_uri="test-env:chat", + prompt_messages=messages, + ) - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - # Parse the call - if isinstance(call, tuple): - name, args = call[0], call[1] if len(call) > 1 else {} - elif hasattr(call, "name"): - name, args = call.name, getattr(call, "arguments", {}) or {} - else: - name, args = str(call), kwargs + def tool_metadata_for_run(self) -> ToolMetadata | None: + return self._tool_metadata() - self.tool_calls.append((name, args)) + async def run_agent(self, agent: Any, *, max_steps: int = 10) -> Trace: + return await self._run(agent, max_steps=max_steps) - if self._call_tool_handler: - tc = MCPToolCall(name=name, arguments=args) - return self._call_tool_handler(tc) + async def list_tools(self, **kwargs: Any) -> list[types.Tool]: + del kwargs + return self.environment.tools - return MCPToolResult( - content=[types.TextContent(type="text", text=f"Result from {name}")], - isError=False, - ) + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + if isinstance(call, MCPToolCall): + tool_call = call + elif isinstance(call, tuple): + call_tuple = cast("tuple[Any, ...]", call) + tool_call = MCPToolCall( + name=str(call_tuple[0]), + arguments=cast("dict[str, Any]", call_tuple[1] if len(call_tuple) > 1 else {}), + ) + else: + tool_call = MCPToolCall(name=str(call), arguments=kwargs) + return await self.environment.call_tool(tool_call) async def submit(self, answer: str | dict[str, Any]) -> None: self._submitted = answer 
@pytest.fixture -def mock_eval_context() -> MockEvalContext: - """Create a basic mock EvalContext.""" - return MockEvalContext() +def basic_tool() -> types.Tool: + return mcp_tool("lookup") @pytest.fixture -def mock_eval_context_with_tools() -> MockEvalContext: - """Create a mock EvalContext with test tools.""" - return MockEvalContext( - tools=[ - types.Tool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - ) - ] - ) - - -@pytest.fixture -def mock_eval_context_computer() -> MockEvalContext: - """Create a mock EvalContext with computer tool.""" - return MockEvalContext( - tools=[ - types.Tool( - name="computer", - description="Computer use tool", - inputSchema={"type": "object"}, - ) - ] - ) - - -@pytest.fixture -def mock_eval_context_browser_tools() -> MockEvalContext: - """Create a mock EvalContext with browser-like tools.""" - return MockEvalContext( - tools=[ - types.Tool(name="screenshot", description="Take screenshot", inputSchema={}), - types.Tool(name="click", description="Click at coordinates", inputSchema={}), - types.Tool(name="type", description="Type text", inputSchema={}), - ] - ) +def recording_environment(basic_tool: types.Tool) -> RecordingToolEnvironment: + return RecordingToolEnvironment([basic_tool]) diff --git a/hud/agents/tests/test_base.py b/hud/agents/tests/test_base.py deleted file mode 100644 index ef6fa7d0f..000000000 --- a/hud/agents/tests/test_base.py +++ /dev/null @@ -1,537 +0,0 @@ -"""Tests for MCPAgent base class with the EvalContext pattern.""" - -from __future__ import annotations - -from typing import Any, ClassVar - -import pytest -from mcp import types - -from hud.agents import MCPAgent -from hud.agents.base import BaseCreateParams -from hud.environment.router import ToolRouter -from hud.eval.context import EvalContext -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult - - -class MockConfig(BaseAgentConfig): - model_name: str = 
"MockAgent" - model: str = "mock-model" - - -class MockCreateParams(BaseCreateParams, MockConfig): - pass - - -class MockEvalContext(EvalContext): - """Mock EvalContext for testing.""" - - def __init__( - self, - prompt: str = "Test prompt", - tools: list[types.Tool] | None = None, - ) -> None: - # Core attributes - self.prompt = prompt - self._tools = tools or [ - types.Tool(name="test_tool", description="A test tool", inputSchema={}), - types.Tool(name="another_tool", description="Another tool", inputSchema={}), - ] - self._submitted: str | dict[str, Any] | None = None - self.reward: float | None = None - self._tool_calls: list[tuple[str, dict[str, Any]]] = [] - - # Environment attributes - self._router = ToolRouter() - - # EvalContext attributes - self._task = None - self.trace_id = "test-trace-id" - self.eval_name = "test-eval" - self.job_id: str | None = None - self.group_id: str | None = None - self.index = 0 - self.variants: dict[str, Any] = {} - self.answer: str | dict[str, Any] | None = None - self.system_prompt: str | None = None - self.error: BaseException | None = None - self.metadata: dict[str, Any] = {} - self.results: list[Any] = [] - self._is_summary = False - - def as_tools(self) -> list[types.Tool]: - return self._tools - - @property - def has_scenario(self) -> bool: - return True - - async def list_tools(self) -> list[types.Tool]: - return self._tools - - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - # Parse the call - if isinstance(call, tuple): - name, args = call[0], call[1] if len(call) > 1 else {} - elif hasattr(call, "name"): - name, args = call.name, getattr(call, "arguments", {}) or {} - else: - name, args = str(call), kwargs - self._tool_calls.append((name, args)) - return MCPToolResult( - content=[types.TextContent(type="text", text=f"Result from {name}")], - isError=False, - ) - - async def submit(self, answer: str | dict[str, Any]) -> None: - self._submitted = answer - - -class MockMCPAgent(MCPAgent): - 
"""Concrete implementation of MCPAgent for testing.""" - - metadata: ClassVar[dict[str, Any] | None] = {} - config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for the mock agent.""" - return AgentType.OPENAI - - def __init__(self, **kwargs: Any) -> None: - params = MockCreateParams(**kwargs) - super().__init__(params) - self._response = InferenceResult(content="Mock response", tool_calls=[], done=True) - - def set_response(self, response: InferenceResult) -> None: - self._response = response - - async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: - return self._response - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[dict[str, Any]]: - formatted = [] - for tool_call, result in zip(tool_calls, tool_results, strict=True): - formatted.append({"role": "tool", "name": tool_call.name, "content": str(result)}) - return formatted - - async def get_system_messages(self) -> list[Any]: - return [] - - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - return [{"type": "text", "text": getattr(b, "text", "")} for b in blocks] - - -class TestMCPAgentInit: - """Tests for MCPAgent initialization.""" - - def test_init_defaults(self) -> None: - """Test agent initializes with default config.""" - agent = MockMCPAgent() - assert agent.ctx is None - assert agent._initialized is False - assert agent.system_prompt is None - - def test_init_with_system_prompt(self) -> None: - """Test agent with custom system prompt.""" - agent = MockMCPAgent(system_prompt="Custom prompt") - assert agent.system_prompt == "Custom prompt" - - -class TestMCPAgentRun: - """Tests for MCPAgent.run() with EvalContext.""" - - @pytest.mark.asyncio - async def test_run_basic(self) -> None: - """Test basic run flow with EvalContext.""" - ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent() 
- - result = await agent.run(ctx) - - assert result.done is True - assert result.content == "Mock response" - assert ctx._submitted == "Mock response" - - @pytest.mark.asyncio - async def test_run_initializes_agent(self) -> None: - """Test run() initializes the agent with context.""" - ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent() - - assert not agent._initialized - await agent.run(ctx) - assert agent._initialized - - @pytest.mark.asyncio - async def test_run_discovers_tools(self) -> None: - """Test run() discovers tools from context.""" - tools = [ - types.Tool(name="tool1", description="Tool 1", inputSchema={}), - types.Tool(name="tool2", description="Tool 2", inputSchema={}), - ] - ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = MockMCPAgent() - - # We need to check tools before cleanup - # Store a reference to check - discovered_tools = [] - - original_run = agent._run_context - - async def capture_tools(*args: Any, **kwargs: Any) -> Any: - discovered_tools.extend(agent.get_available_tools()) - return await original_run(*args, **kwargs) - - agent._run_context = capture_tools # type: ignore - await agent.run(ctx) - - assert len(discovered_tools) == 2 - assert discovered_tools[0].name == "tool1" - assert discovered_tools[1].name == "tool2" - - @pytest.mark.asyncio - async def test_run_requires_eval_context(self) -> None: - """Test run() raises TypeError for non-EvalContext.""" - agent = MockMCPAgent() - - with pytest.raises(TypeError, match="must be EvalContext"): - await agent.run("not a context") # type: ignore - - @pytest.mark.asyncio - async def test_run_requires_prompt(self) -> None: - """Test run() raises ValueError when prompt is empty.""" - ctx = MockEvalContext(prompt="") - agent = MockMCPAgent() - - with pytest.raises(ValueError, match="prompt is not set"): - await agent.run(ctx) - - @pytest.mark.asyncio - async def test_run_clears_context_after(self) -> None: - """Test run() clears ctx after completion.""" - 
ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent() - - await agent.run(ctx) - assert agent.ctx is None - - @pytest.mark.asyncio - async def test_run_no_submit_on_empty_content(self) -> None: - """Test run() doesn't submit when content is empty.""" - ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent() - agent.set_response(InferenceResult(content="", tool_calls=[], done=True)) - - await agent.run(ctx) - assert ctx._submitted is None - - -class TestMCPAgentToolCalling: - """Tests for tool calling through context.""" - - @pytest.mark.asyncio - async def test_call_tools_uses_context(self) -> None: - """Test call_tools routes through ctx.call_tool.""" - ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent() - - # Bind context manually - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - # Call a tool - results = await agent.call_tools(MCPToolCall(name="test_tool", arguments={"arg": "value"})) - - assert len(results) == 1 - assert not results[0].isError - assert ("test_tool", {"arg": "value"}) in ctx._tool_calls - - @pytest.mark.asyncio - async def test_call_tools_without_context_raises(self) -> None: - """Test call_tools raises when no context bound.""" - agent = MockMCPAgent() - - with pytest.raises(ValueError, match="not bound to context"): - await agent.call_tools(MCPToolCall(name="test_tool", arguments={})) - - -class TestMCPAgentRequiredTools: - """Tests for required_tools validation.""" - - @pytest.mark.asyncio - async def test_missing_required_tools_raises(self) -> None: - """Test run() raises when required tools are missing.""" - - class AgentWithRequiredTools(MockMCPAgent): - required_tools: ClassVar[list[str]] = ["must_have_tool"] - - ctx = MockEvalContext(prompt="Do something", tools=[]) - agent = AgentWithRequiredTools() - - with pytest.raises(ValueError, match="Required tools are missing"): - await agent.run(ctx) - - @pytest.mark.asyncio - async def test_required_tools_present_succeeds(self) -> 
None: - """Test run() succeeds when required tools are present.""" - - class AgentWithRequiredTools(MockMCPAgent): - required_tools: ClassVar[list[str]] = ["required_tool"] - - tools = [types.Tool(name="required_tool", description="Required", inputSchema={})] - ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = AgentWithRequiredTools() - - result = await agent.run(ctx) - assert result.done - - -class TestMCPAgentOnToolsReady: - """Tests for _on_tools_ready hook.""" - - @pytest.mark.asyncio - async def test_on_tools_ready_called(self) -> None: - """Test _on_tools_ready is called during initialization.""" - hook_called = [False] - - class AgentWithHook(MockMCPAgent): - def _on_tools_ready(self) -> None: - hook_called[0] = True - - ctx = MockEvalContext(prompt="Do something") - agent = AgentWithHook() - - await agent.run(ctx) - assert hook_called[0] - - @pytest.mark.asyncio - async def test_on_tools_ready_has_access_to_tools(self) -> None: - """Test _on_tools_ready can access discovered tools.""" - captured_tools: list[types.Tool] = [] - - class AgentWithHook(MockMCPAgent): - def _on_tools_ready(self) -> None: - captured_tools.extend(self.get_available_tools()) - - tools = [ - types.Tool(name="tool1", description="Tool 1", inputSchema={}), - types.Tool(name="tool2", description="Tool 2", inputSchema={}), - ] - ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = AgentWithHook() - - await agent.run(ctx) - - assert len(captured_tools) == 2 - assert captured_tools[0].name == "tool1" - - -class TestMCPAgentToolSchemas: - """Tests for tool schema generation.""" - - @pytest.mark.asyncio - async def test_get_tool_schemas(self) -> None: - """Test get_tool_schemas returns correct format.""" - tools = [ - types.Tool( - name="my_tool", - description="My tool description", - inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = MockMCPAgent() - - # 
Initialize agent - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - schemas = agent.get_tool_schemas() - assert len(schemas) == 1 - assert schemas[0]["name"] == "my_tool" - assert schemas[0]["description"] == "My tool description" - - -class TestMCPAgentErrorPropagation: - """Tests for error propagation to EvalContext.""" - - @pytest.mark.asyncio - async def test_exception_propagates_to_ctx_error(self) -> None: - """Test that exceptions during run() set ctx.error for platform visibility.""" - - class FailingAgent(MockMCPAgent): - async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: - raise RuntimeError("Agent crashed") - - ctx = MockEvalContext(prompt="Do something") - agent = FailingAgent() - - result = await agent.run(ctx) - - # Should return error trace - assert result.isError is True - assert result.content is not None - assert "Agent crashed" in result.content - - assert ctx.error is not None - assert isinstance(ctx.error, BaseException) - assert "Agent crashed" in str(ctx.error) - - @pytest.mark.asyncio - async def test_step_error_propagates_to_ctx_error(self) -> None: - """Test that step-level errors (caught internally) set ctx.error.""" - step_count = [0] - - class FailOnSecondStepAgent(MockMCPAgent): - async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: - step_count[0] += 1 - if step_count[0] == 1: - return InferenceResult( - content="", - tool_calls=[MCPToolCall(name="test_tool", arguments={})], - done=False, - ) - else: - raise ValueError("Step 2 failed") - - ctx = MockEvalContext(prompt="Do something") - agent = FailOnSecondStepAgent() - - result = await agent.run(ctx) - - # Should return error trace - assert result.isError is True - assert ctx.error is not None - assert "Step 2 failed" in str(ctx.error) - - @pytest.mark.asyncio - async def test_no_error_when_successful(self) -> None: - """Test that ctx.error remains None on successful run.""" - ctx = MockEvalContext(prompt="Do 
something") - agent = MockMCPAgent() - - result = await agent.run(ctx) - - assert result.isError is False - assert ctx.error is None - - -class TestMCPAgentCategorizeTools: - """Tests for the categorize_tools method.""" - - @pytest.mark.asyncio - async def test_categorize_generic_tools(self) -> None: - """All MCP tools are generic unless a provider agent filters them.""" - tools = [ - types.Tool(name="tool1", description="Tool 1", inputSchema={}), - types.Tool(name="tool2", description="Tool 2", inputSchema={}), - ] - ctx = MockEvalContext(prompt="Test", tools=tools) - agent = MockMCPAgent() - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - categorized = agent.categorize_tools() - - assert len(categorized.generic) == 2 - assert len(categorized.skipped) == 0 - - @pytest.mark.asyncio - async def test_ignores_legacy_native_tool_metadata(self) -> None: - """Legacy native metadata no longer affects base categorization.""" - tool_with_metadata = types.Tool( - name="tool_with_metadata", - description="Tool with ignored metadata", - inputSchema={}, - _meta={ - "native_tools": { - "openai": { - "api_type": "test_type", - "role": "test_role", - } - } - }, - ) - tools = [tool_with_metadata] - - ctx = MockEvalContext(prompt="Test", tools=tools) - agent = MockMCPAgent() - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - categorized = agent.categorize_tools() - - assert len(categorized.generic) == 1 - assert categorized.generic[0].name == "tool_with_metadata" - assert len(categorized.skipped) == 0 - - @pytest.mark.asyncio - async def test_no_role_exclusion_from_legacy_metadata(self) -> None: - """Tool role metadata is not a control plane anymore.""" - first_tool = types.Tool( - name="claude_computer", - description="Claude computer", - inputSchema={}, - _meta={ - "native_tools": { - "openai": { - "api_type": "computer_test", - "role": "computer", - } - } - }, - ) - second_tool = types.Tool( - name="gemini_computer", - description="Gemini computer", - 
inputSchema={}, - _meta={ - "native_tools": { - "gemini": { - "api_type": "computer_use", - "role": "computer", - } - } - }, - ) - tools = [first_tool, second_tool] - - ctx = MockEvalContext(prompt="Test", tools=tools) - agent = MockMCPAgent() - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - categorized = agent.categorize_tools() - - assert [tool.name for tool in categorized.generic] == ["claude_computer", "gemini_computer"] - assert len(categorized.skipped) == 0 - - @pytest.mark.asyncio - async def test_hosted_metadata_stays_generic(self) -> None: - """Hosted tools are configured on agents, not environment metadata.""" - hosted_tool = types.Tool( - name="google_search", - description="Google Search", - inputSchema={}, - _meta={ - "native_tools": { - "openai": { - "api_type": "google_search", - "hosted": True, - } - } - }, - ) - tools = [hosted_tool] - - ctx = MockEvalContext(prompt="Test", tools=tools) - agent = MockMCPAgent() - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - categorized = agent.categorize_tools() - - assert [tool.name for tool in categorized.generic] == ["google_search"] diff --git a/hud/agents/tests/test_base_runtime.py b/hud/agents/tests/test_base_runtime.py deleted file mode 100644 index 1a4eec41a..000000000 --- a/hud/agents/tests/test_base_runtime.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Runtime tests for MCPAgent base class.""" - -from __future__ import annotations - -from typing import Any - -import mcp.types as types -import pytest - -from hud.agents.base import BaseCreateParams, MCPAgent, text_to_blocks -from hud.environment.router import ToolRouter -from hud.eval.context import EvalContext -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult - - -class DummyConfig(BaseAgentConfig): - model_name: str = "DummyAgent" - model: str = "dummy-model" - - -class DummyCreateParams(BaseCreateParams, DummyConfig): - pass - - -class MockEvalContext(EvalContext): - """Mock EvalContext 
for testing.""" - - def __init__( - self, - prompt: str = "Test prompt", - tools: list[types.Tool] | None = None, - ) -> None: - # Core attributes - self.prompt = prompt - self._tools = tools or [] - self._submitted: str | dict[str, Any] | None = None - self.reward: float | None = None - self._call_tool_handler: Any = None - - self._router = ToolRouter() - - # EvalContext attributes - self._task = None - self.trace_id = "test-trace-id" - self.eval_name = "test-eval" - self.job_id: str | None = None - self.group_id: str | None = None - self.index = 0 - self.variants: dict[str, Any] = {} - self.answer: str | dict[str, Any] | None = None - self.system_prompt: str | None = None - self.error: BaseException | None = None - self.metadata: dict[str, Any] = {} - self.results: list[Any] = [] - self._is_summary = False - - def as_tools(self) -> list[types.Tool]: - return self._tools - - @property - def has_scenario(self) -> bool: - return False - - def set_call_tool_handler(self, handler: Any) -> None: - self._call_tool_handler = handler - - async def list_tools(self) -> list[types.Tool]: - return self._tools - - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - if self._call_tool_handler: - # Parse the call - if isinstance(call, tuple): - tc = MCPToolCall(name=call[0], arguments=call[1] if len(call) > 1 else {}) - elif hasattr(call, "name"): - tc = call - else: - tc = MCPToolCall(name=str(call), arguments=kwargs) - return self._call_tool_handler(tc) - return MCPToolResult( - content=[types.TextContent(type="text", text="ok")], - isError=False, - ) - - async def submit(self, answer: str | dict[str, Any]) -> None: - self._submitted = answer - - -class DummyAgent(MCPAgent): - config_cls = DummyConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for the dummy agent.""" - return AgentType.OPENAI - - def __init__(self, **kwargs: Any) -> None: - params = DummyCreateParams(**kwargs) - super().__init__(params) - - async def 
get_system_messages(self) -> list[types.ContentBlock]: - return [types.TextContent(type="text", text="sys")] - - async def get_response(self, messages: list[Any]) -> InferenceResult: - return InferenceResult(content="ok", tool_calls=[], done=True) - - async def format_blocks(self, blocks: list[Any]) -> list[Any]: - return blocks - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[Any]: - return [types.TextContent(text="tools", type="text")] - - -def test_get_available_tools_before_run_raises() -> None: - """Test that get_available_tools raises before initialization.""" - agent = DummyAgent() - with pytest.raises(RuntimeError): - agent.get_available_tools() - - -@pytest.mark.asyncio -async def test_format_message_invalid_type_raises() -> None: - """Test that format_message raises for invalid types.""" - agent = DummyAgent() - with pytest.raises(ValueError): - await agent.format_message({"oops": 1}) # type: ignore - - -def test_text_to_blocks_shapes() -> None: - """Test text_to_blocks returns correct structure.""" - blocks = text_to_blocks("x") - assert isinstance(blocks, list) and blocks and isinstance(blocks[0], types.TextContent) - - -@pytest.mark.asyncio -async def test_run_with_eval_context() -> None: - """Test basic run() with EvalContext.""" - ctx = MockEvalContext(prompt="hello") - agent = DummyAgent() - result = await agent.run(ctx, max_steps=1) - assert result.done is True - assert result.isError is False - - -@pytest.mark.asyncio -async def test_run_requires_eval_context() -> None: - """Test run() raises TypeError for non-EvalContext.""" - agent = DummyAgent() - with pytest.raises(TypeError, match="must be EvalContext"): - await agent.run("hello") # type: ignore - - -@pytest.mark.asyncio -async def test_run_requires_prompt() -> None: - """Test run() raises ValueError when prompt is empty.""" - ctx = MockEvalContext(prompt="") - agent = DummyAgent() - with pytest.raises(ValueError, 
match="prompt is not set"): - await agent.run(ctx) - - -@pytest.mark.asyncio -async def test_call_tools_error_paths() -> None: - """Test call_tools handles errors correctly.""" - call_count = [0] - ok_result = MCPToolResult(content=text_to_blocks("ok"), isError=False) - - def handler(tool_call: MCPToolCall) -> MCPToolResult: - call_count[0] += 1 - if call_count[0] == 1: - return ok_result - raise RuntimeError("boom") - - ctx = MockEvalContext(prompt="test") - ctx.set_call_tool_handler(handler) - agent = DummyAgent() - - # Initialize the agent with context - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - results = await agent.call_tools( - [MCPToolCall(name="a", arguments={}), MCPToolCall(name="b", arguments={})] - ) - assert results[0].isError is False - assert results[1].isError is True - - -@pytest.mark.asyncio -async def test_call_tools_timeout_raises() -> None: - """Test call_tools raises TimeoutError.""" - - def handler(tool_call: MCPToolCall) -> MCPToolResult: - raise TimeoutError("timeout") - - ctx = MockEvalContext(prompt="test") - ctx.set_call_tool_handler(handler) - agent = DummyAgent() - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - with pytest.raises(TimeoutError): - await agent.call_tools(MCPToolCall(name="x", arguments={})) - - -@pytest.mark.asyncio -async def test_get_available_tools_after_run() -> None: - """Test get_available_tools works after initialization.""" - tools = [types.Tool(name="test_tool", description="Test", inputSchema={})] - ctx = MockEvalContext(prompt="hello", tools=tools) - agent = DummyAgent() - - # Run initializes the agent - await agent.run(ctx, max_steps=1) - - # After cleanup, we can't access tools (ctx is cleared) - # But during run, tools were available - assert agent._initialized is True diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py deleted file mode 100644 index fb3dab557..000000000 --- a/hud/agents/tests/test_claude.py +++ /dev/null @@ -1,1605 +0,0 @@ 
-"""Tests for Claude MCP Agent implementation.""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, cast -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from anthropic import AsyncAnthropic, AsyncAnthropicBedrock -from mcp import types - -from hud.agents.claude import ( - ClaudeAgent, - base64_to_content_block, - text_to_content_block, - tool_use_content_block, -) -from hud.environment import Environment -from hud.environment.router import ToolRouter -from hud.eval.context import EvalContext -from hud.eval.task import Task -from hud.types import MCPToolCall, MCPToolResult - -if TYPE_CHECKING: - from collections.abc import Generator - - from anthropic.types.beta import BetaMessageParam - - -class MockEvalContext(EvalContext): - """Mock EvalContext for testing.""" - - def __init__( - self, - tools: list[types.Tool] | None = None, - environment_capabilities: dict[str, Any] | None = None, - ) -> None: - # Core attributes - self.prompt = "Test prompt" - self._tools = tools or [] - self._submitted: str | dict[str, Any] | None = None - self.reward: float | None = None - - # Environment attributes - self._router = ToolRouter() - - # EvalContext attributes - self._task = None - self.trace_id = "test-trace-id" - self.eval_name = "test-eval" - self.job_id: str | None = None - self.group_id: str | None = None - self.index = 0 - self.variants: dict[str, Any] = {} - self.answer: str | dict[str, Any] | None = None - self.system_prompt: str | None = None - self.scenario_enable_citations: bool = False - self.scenario_returns_schema: dict[str, Any] | None = None - self.error: BaseException | None = None - self.metadata: dict[str, Any] = {} - self.environment_capabilities = environment_capabilities - self.results: list[Any] = [] - self._is_summary = False - - def as_tools(self) -> list[types.Tool]: - return self._tools - - @property - def has_scenario(self) -> bool: - return False - - async def 
list_tools(self) -> list[types.Tool]: - return self._tools - - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - return MCPToolResult( - content=[types.TextContent(type="text", text="ok")], - isError=False, - ) - - async def submit(self, answer: str | dict[str, Any]) -> None: - self._submitted = answer - - -class MockStreamContextManager: - """Mock for Claude's streaming context manager.""" - - def __init__(self, response: MagicMock) -> None: - self.response = response - - async def __aenter__(self) -> MockStreamContextManager: - return self - - async def __aexit__( - self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any - ) -> bool: - return False - - def __aiter__(self) -> MockStreamContextManager: - return self - - async def __anext__(self) -> None: - raise StopAsyncIteration - - async def get_final_message(self) -> MagicMock: - return self.response - - -class MockErrorStreamContextManager: - """Mock stream context manager that raises a fixed error while streaming.""" - - def __init__(self, error: Exception) -> None: - self.error = error - - async def __aenter__(self) -> MockErrorStreamContextManager: - return self - - async def __aexit__( - self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any - ) -> bool: - return False - - def __aiter__(self) -> MockErrorStreamContextManager: - return self - - async def __anext__(self) -> None: - raise self.error - - async def get_final_message(self) -> MagicMock: - raise AssertionError("get_final_message should not be called when stream iteration fails") - - -class TestClaudeHelperFunctions: - """Test helper functions for Claude message formatting.""" - - def test_base64_to_content_block(self) -> None: - """Test base64 image conversion.""" - base64_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk" - result = base64_to_content_block(base64_data) - - assert result["type"] == "image" - assert result["source"]["type"] == "base64" - assert 
result["source"]["media_type"] == "image/png" - assert result["source"]["data"] == base64_data - - def test_text_to_content_block(self) -> None: - """Test text conversion.""" - text = "Hello, world!" - result = text_to_content_block(text) - - assert result["type"] == "text" - assert result["text"] == text - - def test_tool_use_content_block(self) -> None: - """Test tool result content block creation.""" - tool_use_id = "tool_123" - content = [text_to_content_block("Result text")] - - result = tool_use_content_block(tool_use_id, content) - - assert result["type"] == "tool_result" - assert result["tool_use_id"] == tool_use_id - assert result["content"] == content # type: ignore - - -class TestClaudeAgent: - """Test ClaudeAgent class.""" - - @pytest.fixture - def mock_anthropic(self) -> Generator[AsyncAnthropic, None, None]: # type: ignore[misc] - """Create a stub Anthropic client.""" - with patch("hud.agents.claude.agent.AsyncAnthropic") as mock_class: - client = MagicMock(spec=AsyncAnthropic) - client.api_key = "test-key" - mock_class.return_value = client - yield client # type: ignore[misc] - - @pytest.mark.asyncio - async def test_init_with_client(self, mock_anthropic: AsyncAnthropic) -> None: - """Test agent initialization with provided client.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-sonnet-4-6", - validate_api_key=False, - ) - - assert agent.model_name == "Claude" - assert agent.config.model == "claude-sonnet-4-6" - assert agent.anthropic_client == mock_anthropic - - @pytest.mark.asyncio - async def test_init_with_parameters(self, mock_anthropic: AsyncAnthropic) -> None: - """Test agent initialization with various parameters.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-sonnet-4-6", - max_tokens=4096, - validate_api_key=False, - ) - - assert agent.max_tokens == 4096 - - @pytest.mark.asyncio - async def test_format_blocks_text_only(self, mock_anthropic: AsyncAnthropic) -> None: - """Test 
formatting text content blocks.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Hello, world!"), - types.TextContent(type="text", text="How are you?"), - ] - - messages = await agent.format_blocks(blocks) - assert len(messages) == 1 - assert messages[0]["role"] == "user" - content = messages[0]["content"] - assert isinstance(content, list) - assert len(content) == 2 - assert content[0]["type"] == "text" # type: ignore[index] - assert content[0]["text"] == "Hello, world!" # type: ignore[index] - - @pytest.mark.asyncio - async def test_format_blocks_with_image(self, mock_anthropic: AsyncAnthropic) -> None: - """Test formatting image content blocks.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Look at this:"), - types.ImageContent(type="image", data="base64data", mimeType="image/png"), - ] - - messages = await agent.format_blocks(blocks) - assert len(messages) == 1 - content = messages[0]["content"] - assert isinstance(content, list) - assert len(content) == 2 - assert content[1]["type"] == "image" # type: ignore[index] - - @pytest.mark.asyncio - async def test_format_tool_results_text(self, mock_anthropic: AsyncAnthropic) -> None: - """Test formatting tool results with text content.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Tool output")], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - assert messages[0]["role"] == "user" - content = messages[0]["content"] - assert isinstance(content, list) - assert len(content) == 1 - assert 
content[0]["type"] == "tool_result" # type: ignore[index] - assert content[0]["tool_use_id"] == "call_123" # type: ignore[index] - - @pytest.mark.asyncio - async def test_format_tool_results_with_error(self, mock_anthropic: AsyncAnthropic) -> None: - """Test formatting tool results with error.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Error message")], - isError=True, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - content = messages[0]["content"] - # Error content should include "Error:" prefix - assert any("Error" in str(block) for block in content[0]["content"]) # type: ignore[index] - - @pytest.mark.asyncio - async def test_get_system_messages(self, mock_anthropic: AsyncAnthropic) -> None: - """Test that system messages return empty (Claude uses system param).""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - system_prompt="You are a helpful assistant.", - validate_api_key=False, - ) - - messages = await agent.get_system_messages() - # Claude doesn't use system messages in the message list - assert messages == [] - - @pytest.mark.asyncio - async def test_get_response_with_thinking(self, mock_anthropic: AsyncAnthropic) -> None: - """Test getting model response with thinking content.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - # Set up agent as initialized - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - mock_response = MagicMock() - - thinking_block = MagicMock() - thinking_block.type = "thinking" - thinking_block.thinking = "Let me analyze this problem..." 
- - text_block = MagicMock() - text_block.type = "text" - text_block.text = "Here is the answer" - - mock_response.content = [thinking_block, text_block] - mock_response.usage = MagicMock(input_tokens=10, output_tokens=30) - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - messages = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hard question"}]}, - ) - ] - response = await agent.get_response(messages) - - assert response.content == "Here is the answer" - assert response.reasoning == "Let me analyze this problem..." - - @pytest.mark.asyncio - async def test_convert_tools_for_claude(self, mock_anthropic: AsyncAnthropic) -> None: - """Test converting MCP tools to Claude format.""" - tools = [ - types.Tool( - name="my_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - # Check that tools were converted - assert len(agent.claude_tools) == 1 - assert agent.claude_tools[0]["name"] == "my_tool" # type: ignore[typeddict-item] - - @pytest.mark.asyncio - async def test_computer_tool_detection(self, mock_anthropic: AsyncAnthropic) -> None: - """Test that computer tools are detected for beta API.""" - tools = [ - types.Tool( - name="computer", - description="Control computer", - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.has_computer_tool is True - - @pytest.mark.asyncio - async def test_computer_name_activates_agent_side_tool( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude native 
computer calls route through the agent-side tool.""" - tools = [ - types.Tool( - name="computer", - description="HUD computer", - inputSchema={ - "type": "object", - "properties": {"action": {"type": "string"}, "x": {"type": "integer"}}, - }, - _meta={"resolution": {"width": 1280, "height": 720}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock( - return_value=MCPToolResult( - content=[types.TextContent(type="text", text="clicked")], - isError=False, - ) - ) - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - results = await agent.call_tools( - MCPToolCall( - name="computer", - arguments={"action": "left_click", "coordinate": [10, 20]}, - ) - ) - - assert results[0].isError is False - called = ctx.call_tool.call_args.args[0] - assert called.name == "computer" - assert called.arguments == { - "action": "click", - "x": 10, - "y": 20, - "hold_keys": None, - } - - @pytest.mark.asyncio - async def test_env_level_capability_activates_agent_side_tool( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Env-level capabilities are the preferred binding source.""" - tools = [ - types.Tool( - name="desktop", - description="Computer", - inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext( - tools=tools, - environment_capabilities={ - "capabilities": { - "computer": { - "tool": "desktop", - "resolution": {"width": 1600, "height": 900}, - } - } - }, - ) - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.claude_tools[0]["name"] == "computer" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["display_width_px"] == 1600 # type: ignore[typeddict-item] - assert agent.claude_tools[0]["display_height_px"] == 900 # type: ignore[typeddict-item] - - @pytest.mark.asyncio - async def 
test_anthropic_computer_registration_uses_role_as_capability( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Old Claude native metadata acts only as a capability signal.""" - tools = [ - types.Tool( - name="anthropic_computer", - description="Anthropic computer", - inputSchema={ - "type": "object", - "properties": { - "action": {"type": "string"}, - "x": {"type": "integer"}, - "y": {"type": "integer"}, - }, - }, - _meta={ - "native_tools": { - "claude": { - "api_type": "stale_env_computer_spec", - "api_name": "computer", - "beta": "stale-env-beta", - "role": "computer", - "display_width": 1920, - "display_height": 1080, - } - } - }, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock( - return_value=MCPToolResult( - content=[types.TextContent(type="text", text="clicked")], - isError=False, - ) - ) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-sonnet-4-6", - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - assert agent.claude_tools[0]["name"] == "computer" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["type"] == "computer_20251124" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["display_width_px"] != 1920 # type: ignore[typeddict-item] - assert agent.claude_tools[0]["display_height_px"] != 1080 # type: ignore[typeddict-item] - assert agent.claude_tools[0]["display_number"] == 1 # type: ignore[typeddict-item] - assert agent.claude_tools[0]["enable_zoom"] is True # type: ignore[typeddict-item] - assert agent._required_betas == {"computer-use-2025-11-24"} - - await agent.call_tools( - MCPToolCall( - name="computer", - arguments={"action": "left_click", "coordinate": [10, 20]}, - ) - ) - - called = ctx.call_tool.call_args.args[0] - assert called.name == "anthropic_computer" - assert called.arguments == { - "action": "click", - "x": 10, - "y": 20, - "hold_keys": None, - } - - @pytest.mark.asyncio - async def 
test_computer_translates_modifiers_drag_and_hold_key( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude computer actions translate to valid generic environment calls.""" - tools = [ - types.Tool( - name="computer", - description="Computer", - inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - calls: list[MCPToolCall] = [] - - async def call_tool(call: MCPToolCall) -> MCPToolResult: - calls.append(call) - return MCPToolResult( - content=[types.TextContent(type="text", text="ok")], - isError=False, - ) - - ctx.call_tool = call_tool # type: ignore[method-assign] - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - await agent.call_tools( - [ - MCPToolCall( - name="computer", - arguments={ - "action": "right_click", - "coordinate": [10, 20], - "text": "Shift", - }, - ), - MCPToolCall( - name="computer", - arguments={"action": "left_click_drag", "coordinate": [30, 40]}, - ), - MCPToolCall( - name="computer", - arguments={"action": "hold_key", "text": "Control", "duration": 0.5}, - ), - ] - ) - - assert [call.arguments for call in calls] == [ - { - "action": "click", - "x": 10, - "y": 20, - "button": "right", - "hold_keys": ["shift"], - }, - {"action": "mouse_down", "button": "left"}, - {"action": "move", "x": 30, "y": 40}, - {"action": "mouse_up", "button": "left"}, - {"action": "hold_key", "text": "ctrl", "duration": 0.5}, - ] - - @pytest.mark.asyncio - async def test_bash_name_activates_agent_side_tool( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude native bash calls route through the agent-side tool.""" - tools = [ - types.Tool( - name="bash", - description="Bash shell", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock( - return_value=MCPToolResult( - 
content=[types.TextContent(type="text", text="ok")], - isError=False, - ) - ) - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.claude_tools[0]["name"] == "bash" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["type"] == "bash_20250124" # type: ignore[typeddict-item] - - results = await agent.call_tools(MCPToolCall(name="bash", arguments={"command": "echo ok"})) - - assert results[0].isError is False - called = ctx.call_tool.call_args.args[0] - assert called.name == "bash" - assert called.arguments == {"command": "echo ok"} - - @pytest.mark.asyncio - async def test_bash_restart_matches_anthropic_contract( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude bash supports restart without command.""" - tools = [ - types.Tool( - name="bash", - description="Bash shell", - inputSchema={"type": "object", "properties": {}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock( - return_value=MCPToolResult( - content=[types.TextContent(type="text", text="Bash session restarted.")], - isError=False, - ) - ) - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - results = await agent.call_tools(MCPToolCall(name="bash", arguments={"restart": True})) - - assert results[0].isError is False - called = ctx.call_tool.call_args.args[0] - assert called.name == "bash" - assert called.arguments == {"restart": True} - - @pytest.mark.asyncio - async def test_bash_requires_command_unless_restart( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Malformed Claude bash calls fail before reaching the environment.""" - tools = [ - types.Tool( - name="bash", - description="Bash shell", - inputSchema={"type": "object", "properties": {}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock() - agent = 
ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - results = await agent.call_tools(MCPToolCall(name="bash", arguments={})) - - assert results[0].isError is True - assert "command is required" in results[0].content[0].text # type: ignore[attr-defined] - ctx.call_tool.assert_not_called() - - @pytest.mark.asyncio - async def test_edit_name_activates_agent_side_tool( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude native editor calls route through the environment edit tool.""" - tools = [ - types.Tool( - name="edit", - description="File editor", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock( - return_value=MCPToolResult( - content=[types.TextContent(type="text", text="edited")], - isError=False, - ) - ) - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.claude_tools[0]["name"] == "str_replace_based_edit_tool" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["type"] == "text_editor_20250728" # type: ignore[typeddict-item] - - results = await agent.call_tools( - MCPToolCall( - name="str_replace_based_edit_tool", - arguments={ - "command": "str_replace", - "path": "/tmp/file.txt", - "old_str": "old", - "new_str": "new", - }, - ) - ) - - assert results[0].isError is False - called = ctx.call_tool.call_args.args[0] - assert called.name == "edit" - assert called.arguments == { - "command": "replace", - "path": "/tmp/file.txt", - "old_text": "old", - "new_text": "new", - } - - @pytest.mark.asyncio - async def test_claude_3_7_sonnet_editor_stays_generic( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude 3.7 Sonnet editor support is intentionally not advertised.""" - tools = [ - types.Tool( - name="edit", - description="File editor", - 
inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-3-7-sonnet-20250219", - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert "str_replace_editor" not in agent._claude_native_tools - assert "str_replace_based_edit_tool" not in agent._claude_native_tools - assert agent.claude_tools[0]["name"] == "edit" # type: ignore[typeddict-item] - - @pytest.mark.asyncio - async def test_sonnet_4_5_uses_current_native_coding_tools( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Sonnet 4.5 keeps native bash and editor support for compatibility.""" - tools = [ - types.Tool( - name="bash", - description="Bash shell", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ), - types.Tool( - name="edit", - description="File editor", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ), - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-sonnet-4-5", - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - tool_types = {tool["name"]: tool.get("type") for tool in agent.claude_tools} # type: ignore[index] - assert tool_types["bash"] == "bash_20250124" - assert tool_types["str_replace_based_edit_tool"] == "text_editor_20250728" - - @pytest.mark.asyncio - async def test_sonnet_4_5_uses_20250124_native_computer_tool( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Sonnet 4.5 keeps native computer support on its compatible spec.""" - tools = [ - types.Tool( - name="computer", - description="Computer", - inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-sonnet-4-5", - 
validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.claude_tools[0]["name"] == "computer" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["type"] == "computer_20250124" # type: ignore[typeddict-item] - - @pytest.mark.asyncio - async def test_20250728_editor_rejects_unsupported_commands( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude 4 editor shape only forwards commands supported by the provider spec.""" - tools = [ - types.Tool( - name="edit", - description="File editor", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock() - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - results = await agent.call_tools( - MCPToolCall( - name="str_replace_based_edit_tool", - arguments={"command": "undo_edit", "path": "/tmp/file.txt"}, - ) - ) - - assert results[0].isError is True - assert "does not support command 'undo_edit'" in results[0].content[0].text # type: ignore[attr-defined] - results = await agent.call_tools( - MCPToolCall( - name="str_replace_based_edit_tool", - arguments={"command": "undo", "path": "/tmp/file.txt"}, - ) - ) - assert results[0].isError is True - assert "does not support command 'undo'" in results[0].content[0].text # type: ignore[attr-defined] - ctx.call_tool.assert_not_called() - - @pytest.mark.asyncio - async def test_memory_name_activates_agent_side_tool( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Claude native memory calls route through the environment memory tool.""" - tools = [ - types.Tool( - name="memory", - description="Memory", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - ctx.call_tool = AsyncMock( - return_value=MCPToolResult( - content=[types.TextContent(type="text", 
text="remembered")], - isError=False, - ) - ) - agent = ClaudeAgent.create(model_client=mock_anthropic, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.claude_tools[0]["name"] == "memory" # type: ignore[typeddict-item] - assert agent.claude_tools[0]["type"] == "memory_20250818" # type: ignore[typeddict-item] - assert agent._required_betas == set() - - results = await agent.call_tools( - MCPToolCall(name="memory", arguments={"command": "view", "path": "/"}) - ) - - assert results[0].isError is False - called = ctx.call_tool.call_args.args[0] - assert called.name == "memory" - assert called.arguments == {"command": "view", "path": "/"} - - @pytest.mark.asyncio - async def test_old_sonnet_memory_stays_generic(self, mock_anthropic: AsyncAnthropic) -> None: - """Claude memory is only advertised natively for the supported current models.""" - tools = [ - types.Tool( - name="memory", - description="Memory", - inputSchema={"type": "object", "properties": {"command": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - model="claude-sonnet-4-5", - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent.claude_tools[0]["name"] == "memory" # type: ignore[typeddict-item] - assert "type" not in agent.claude_tools[0] # type: ignore[operator] - assert "memory" not in agent._claude_native_tools - - @pytest.mark.asyncio - async def test_get_response_with_text(self, mock_anthropic: AsyncAnthropic) -> None: - """Test getting response with text output.""" - # Create mock response - mock_response = MagicMock() - mock_response.content = [MagicMock(type="text", text="Hello!")] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - 
agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - response = await agent.get_response([]) - assert response.content == "Hello!" - assert response.done is True - assert len(response.tool_calls) == 0 - - @pytest.mark.asyncio - async def test_get_response_with_tool_call(self, mock_anthropic: AsyncAnthropic) -> None: - """Test getting response with tool call.""" - mock_tool_use = MagicMock() - mock_tool_use.type = "tool_use" - mock_tool_use.id = "call_123" - mock_tool_use.name = "my_tool" - mock_tool_use.input = {"x": "value"} - - mock_response = MagicMock() - mock_response.content = [mock_tool_use] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {"my_tool": "my_tool"} - agent.has_computer_tool = False - agent._initialized = True - - response = await agent.get_response([]) - assert response.done is False - assert len(response.tool_calls) == 1 - assert response.tool_calls[0].name == "my_tool" - assert response.tool_calls[0].arguments == {"x": "value"} - - @pytest.mark.asyncio - async def test_get_response_retries_same_generation_once_on_invalid_streamed_tool_json( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """First invalid streamed tool JSON should retry without adding guidance.""" - invalid_json_error = ValueError( - "Unable to parse tool parameter JSON from model. Please retry your request or " - "adjust your " - 'prompt. Error: expected value at line 1 column 10. 
JSON: {"labels": bug}' - ) - first_stream = MockErrorStreamContextManager(invalid_json_error) - - mock_response = MagicMock() - mock_response.content = [MagicMock(type="text", text="Recovered")] - second_stream = MockStreamContextManager(mock_response) - - mock_anthropic.beta.messages.stream = MagicMock(side_effect=[first_stream, second_stream]) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - messages: list[BetaMessageParam] = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Create a Linear ticket"}]}, - ) - ] - - response = await agent.get_response(messages) - - assert response.content == "Recovered" - assert mock_anthropic.beta.messages.stream.call_count == 2 - # Original user message + assistant response (no guidance message needed) - assert len(messages) == 2 - assert messages[1]["role"] == "assistant" - - @pytest.mark.asyncio - async def test_get_response_adds_invalid_json_guidance_after_second_failure( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Second consecutive invalid JSON failure should add INVALID_JSON guidance.""" - invalid_json_error = ValueError( - "Unable to parse tool parameter JSON from model. Please retry your request or " - "adjust your " - 'prompt. Error: expected value at line 1 column 10. 
JSON: {"labels": bug}' - ) - first_stream = MockErrorStreamContextManager(invalid_json_error) - second_stream = MockErrorStreamContextManager(invalid_json_error) - - mock_response = MagicMock() - mock_response.content = [MagicMock(type="text", text="Recovered after guidance")] - third_stream = MockStreamContextManager(mock_response) - - mock_anthropic.beta.messages.stream = MagicMock( - side_effect=[first_stream, second_stream, third_stream] - ) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - messages: list[BetaMessageParam] = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Create a Linear ticket"}]}, - ) - ] - - response = await agent.get_response(messages) - - assert response.content == "Recovered after guidance" - assert mock_anthropic.beta.messages.stream.call_count == 3 - # Original user message + INVALID_JSON guidance + assistant response - assert len(messages) == 3 - retry_message = messages[1] - assert retry_message["role"] == "user" - retry_content = cast("list[dict[str, Any]]", retry_message["content"]) - assert "INVALID_JSON" in retry_content[0]["text"] - - @pytest.mark.asyncio - async def test_get_response_does_not_retry_unrelated_value_error( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Non-tool-json ValueErrors should propagate immediately.""" - unrelated_error = ValueError("stream exploded for unrelated reason") - mock_anthropic.beta.messages.stream = MagicMock( - return_value=MockErrorStreamContextManager(unrelated_error) - ) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - with pytest.raises(ValueError, match="unrelated reason"): - await agent.get_response([]) - - assert 
mock_anthropic.beta.messages.stream.call_count == 1 - - -class TestClaudeAgentBedrock: - """Test ClaudeAgent class with Bedrock.""" - - @pytest.fixture - def bedrock_client(self) -> AsyncAnthropicBedrock: - """Create a real AsyncAnthropicBedrock client and stub networked methods.""" - client = AsyncAnthropicBedrock( - aws_access_key="AKIATEST", - aws_secret_key="secret", - aws_region="us-east-1", - ) - # Stub the actual Bedrock call so tests are hermetic. - client.beta.messages.create = AsyncMock() - return client - - @pytest.mark.asyncio - async def test_init(self, bedrock_client: AsyncAnthropicBedrock) -> None: - """Test agent initialization.""" - agent = ClaudeAgent.create( - model_client=bedrock_client, - model="test-model-arn", - validate_api_key=False, - ) - - assert agent.model_name == "Claude" - assert agent.config.model == "test-model-arn" - assert agent.anthropic_client == bedrock_client - - @pytest.mark.asyncio - async def test_get_response_bedrock_uses_create_not_stream( - self, bedrock_client: AsyncAnthropicBedrock - ) -> None: - """Bedrock path must call messages.create() (Bedrock doesn't support stream()).""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - model_client=bedrock_client, - model="test-model-arn", - validate_api_key=False, - ) - - # Enable computer tool to verify betas list includes computer-use in Bedrock mode. - # In real usage, this beta is added by _convert_tools_for_claude when it detects - # a computer tool. Here we manually set both flags to simulate that. 
- agent.has_computer_tool = True - agent._required_betas.add("computer-use-2025-01-24") - - mock_response = MagicMock() - text_block = MagicMock() - text_block.type = "text" - text_block.text = "Hello from Bedrock" - mock_response.content = [text_block] - - bedrock_client.beta.messages.create.return_value = mock_response # type: ignore[union-attr] - - messages = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - ) - ] - response = await agent.get_response(messages) - - assert response.content == "Hello from Bedrock" - assert response.tool_calls == [] - - # Bedrock-specific behavior: uses create() and appends assistant message directly. - assert not hasattr(bedrock_client.beta.messages, "stream") - bedrock_client.beta.messages.create.assert_awaited_once() # type: ignore[union-attr] - assert len(messages) == 2 - assert messages[-1]["role"] == "assistant" - - # Ensure the Bedrock call shape is stable. - _, kwargs = bedrock_client.beta.messages.create.call_args # type: ignore[union-attr] - assert kwargs["model"] == "test-model-arn" - assert kwargs["tool_choice"] == {"type": "auto", "disable_parallel_tool_use": True} - assert "computer-use-2025-01-24" in kwargs["betas"] - - @pytest.mark.asyncio - async def test_get_response_bedrock_missing_boto3_raises_value_error( - self, bedrock_client: AsyncAnthropicBedrock - ) -> None: - """If boto3 isn't installed, Bedrock client import path should raise a clear ValueError.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - model_client=bedrock_client, - model="test-model-arn", - validate_api_key=False, - ) - - bedrock_client.beta.messages.create.side_effect = ModuleNotFoundError("boto3") # type: ignore[union-attr] - messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}] - - with pytest.raises(ValueError, match=r"boto3 is required for AWS Bedrock"): - await agent.get_response(messages) # type: ignore - - def 
test_init_with_bedrock_client_does_not_require_anthropic_api_key( - self, bedrock_client: AsyncAnthropicBedrock - ) -> None: - """Providing model_client should bypass ANTHROPIC_API_KEY validation.""" - with patch("hud.settings.settings.anthropic_api_key", None): - agent = ClaudeAgent.create( - model_client=bedrock_client, - validate_api_key=False, - ) - assert agent.anthropic_client == bedrock_client - - -class TestClaudeAgentComputerTool20251124: - """Test ClaudeAgent with the new computer_20251124 tool type.""" - - @pytest.fixture - def mock_anthropic(self) -> Any: - from unittest.mock import MagicMock - - return MagicMock(spec=["messages", "beta"]) - - def test_no_fine_grained_streaming_beta(self, mock_anthropic: Any) -> None: - """Test that fine-grained-tool-streaming beta is no longer included.""" - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - assert "fine-grained-tool-streaming-2025-05-14" not in agent._required_betas - - -class TestClaudeAgentBetaHeader: - """Test that the Anthropic-Beta header is handled correctly.""" - - @pytest.fixture - def mock_anthropic(self) -> Any: - return MagicMock(spec=["messages", "beta"]) - - @pytest.mark.asyncio - async def test_empty_betas_sends_omit_not_empty_list(self, mock_anthropic: Any) -> None: - """When no tools require a beta, betas should be Omit() not [].""" - from anthropic import Omit - - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._required_betas = set() # No betas required - agent._initialized = True - - mock_response = MagicMock() - mock_response.content = [MagicMock(type="text", text="Hello")] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - messages = [ - cast( - 
"BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - ) - ] - await agent.get_response(messages) - - _, kwargs = mock_anthropic.beta.messages.stream.call_args - assert isinstance(kwargs["betas"], Omit), ( - f"Expected Omit() when no betas required, got {type(kwargs['betas'])}" - ) - - @pytest.mark.asyncio - async def test_nonempty_betas_sends_list(self, mock_anthropic: Any) -> None: - """When tools require betas, betas should be a list of strings.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = True - agent._required_betas = {"computer-use-2025-01-24"} - agent._initialized = True - - mock_response = MagicMock() - mock_response.content = [MagicMock(type="text", text="Hello")] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - messages = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - ) - ] - await agent.get_response(messages) - - _, kwargs = mock_anthropic.beta.messages.stream.call_args - assert isinstance(kwargs["betas"], list) - assert "computer-use-2025-01-24" in kwargs["betas"] - - @pytest.mark.asyncio - async def test_generic_tools_only_no_beta_header(self, mock_anthropic: Any) -> None: - """Generic function tools should not produce a beta header.""" - with patch("hud.settings.settings.telemetry_enabled", False): - tools = [ - types.Tool( - name="my_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - # Generic tools should not add any betas - assert 
len(agent._required_betas) == 0 - - mock_response = MagicMock() - mock_response.content = [MagicMock(type="text", text="Hello")] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - from anthropic import Omit - - messages = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - ) - ] - await agent.get_response(messages) - - _, kwargs = mock_anthropic.beta.messages.stream.call_args - assert isinstance(kwargs["betas"], Omit) - - -class TestCitationExtraction: - """Test citation extraction from BetaTextBlock.citations (modern SDK path).""" - - @pytest.fixture - def mock_anthropic(self) -> AsyncAnthropic: - client = MagicMock(spec=AsyncAnthropic) - client.beta = MagicMock() - client.beta.messages = MagicMock() - return client - - @pytest.mark.asyncio - async def test_inline_citations_extracted_from_text_block( - self, mock_anthropic: AsyncAnthropic - ) -> None: - """Text blocks with inline citations should populate result.citations.""" - cit1 = MagicMock() - cit1.cited_text = "Revenue was $1M" - cit1.document_index = 0 - cit1.document_title = "financials.pdf" - cit1.start_char_index = 0 - cit1.end_char_index = 15 - - text_block = MagicMock() - text_block.type = "text" - text_block.text = "Revenue was $1M last quarter." - text_block.citations = [cit1] - - mock_response = MagicMock() - mock_response.content = [text_block] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - result = await agent.get_response([]) - - assert result.content == "Revenue was $1M last quarter." 
- assert len(result.citations) == 1 - assert result.citations[0]["text"] == "Revenue was $1M" - assert result.citations[0]["source"] == "0" - assert result.citations[0]["title"] == "financials.pdf" - assert result.citations[0]["start_index"] == 0 - assert result.citations[0]["end_index"] == 15 - - @pytest.mark.asyncio - async def test_no_citations_when_field_is_none(self, mock_anthropic: AsyncAnthropic) -> None: - """Text blocks without citations should not populate result.citations.""" - text_block = MagicMock() - text_block.type = "text" - text_block.text = "No citations here." - text_block.citations = None - - mock_response = MagicMock() - mock_response.content = [text_block] - - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - agent = ClaudeAgent.create( - model_client=mock_anthropic, - validate_api_key=False, - ) - agent.claude_tools = [] - agent.tool_mapping = {} - agent.has_computer_tool = False - agent._initialized = True - - result = await agent.get_response([]) - assert result.citations == [] - - -class TestDocumentBlockCitations: - """Test that document_to_content_block threads enable_citations.""" - - def test_citations_disabled_by_default(self) -> None: - from hud.agents.claude import document_to_content_block - - block = document_to_content_block(base64_data="AAAA") - assert "citations" not in block - - def test_citations_enabled(self) -> None: - from hud.agents.claude import document_to_content_block - - block = document_to_content_block(base64_data="AAAA", enable_citations=True) - assert block["citations"] == {"enabled": True} # type: ignore[typeddict-item] - - @pytest.mark.asyncio - async def test_format_tool_results_threads_citations_to_documents(self) -> None: - """When scenario_enable_citations is True, PDF document blocks become siblings with citations.""" # noqa: E501 - ctx = MockEvalContext() - ctx.scenario_enable_citations = True - - client = 
MagicMock(spec=AsyncAnthropic) - client.beta = MagicMock() - client.beta.messages = MagicMock() - agent = ClaudeAgent.create( - model_client=client, - validate_api_key=False, - ) - agent.ctx = ctx - agent._initialized = True - agent.claude_tools = [] - agent.tool_mapping = {} - - pdf_blob = "JVBERi0xLjQ=" - tool_calls = [MCPToolCall(id="call_1", name="get_doc", arguments={})] - tool_results = [ - MCPToolResult( - content=[ - types.EmbeddedResource( - type="resource", - resource=types.BlobResourceContents( - uri="file:///doc.pdf", # type: ignore[arg-type] - mimeType="application/pdf", - blob=pdf_blob, - ), - ) - ], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - content_blocks = cast("list[dict[str, Any]]", messages[0]["content"]) - tool_result_block = content_blocks[0] - assert tool_result_block["type"] == "tool_result" - assert tool_result_block["content"], "tool_result should contain the PDF block" - assert tool_result_block["content"][0]["type"] == "document" - doc_block = content_blocks[1] - assert doc_block["type"] == "document" - assert doc_block["citations"] == {"enabled": True} - - @pytest.mark.asyncio - async def test_format_tool_results_wraps_text_as_document_when_citations_enabled(self) -> None: - """Text tool results produce a sibling document block for citations.""" - ctx = MockEvalContext() - ctx.scenario_enable_citations = True - - client = MagicMock(spec=AsyncAnthropic) - client.beta = MagicMock() - client.beta.messages = MagicMock() - agent = ClaudeAgent.create( - model_client=client, - validate_api_key=False, - ) - agent.ctx = ctx - agent._initialized = True - agent.claude_tools = [] - agent.tool_mapping = {} - - tool_calls = [MCPToolCall(id="call_1", name="get_sales", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Revenue was $1M last quarter.")], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) 
- content_blocks = cast("list[dict[str, Any]]", messages[0]["content"]) - tool_result_block = content_blocks[0] - assert tool_result_block["type"] == "tool_result" - text_block = tool_result_block["content"][0] - assert text_block["type"] == "text" - assert text_block["text"] == "Revenue was $1M last quarter." - doc_block = content_blocks[1] - assert doc_block["type"] == "document" - assert doc_block["source"]["type"] == "text" - assert doc_block["source"]["data"] == "Revenue was $1M last quarter." - assert doc_block["citations"] == {"enabled": True} - assert doc_block["title"] == "get_sales" - - @pytest.mark.asyncio - async def test_remote_task_setup_preserves_citations_for_tool_results( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Remote task setup should propagate enable_citations into Claude formatting.""" - env = Environment("test-env") - task = Task(env=env, scenario="remote-env:solve-task", args={}) - ctx = EvalContext.from_task(task) - - async def successful_get_prompt( - _name: str, _arguments: dict[str, str] | None = None - ) -> Any: - return SimpleNamespace( - messages=[ - SimpleNamespace( - role="user", - content=SimpleNamespace(text="Prompt"), - ) - ], - meta={ - "enable_citations": True, - "returns_schema": { - "type": "object", - "properties": {"summary": {"type": "string"}}, - }, - }, - ) - - monkeypatch.setattr(ctx, "get_prompt", successful_get_prompt) - monkeypatch.setattr(ctx._router, "get_prompt_connection", lambda _name: "remote") - - await ctx._run_task_scenario_setup() - - assert ctx.scenario_enable_citations is True - session = ctx._get_session() - assert session is not None - assert session.enable_citations is True - - client = MagicMock(spec=AsyncAnthropic) - client.beta = MagicMock() - client.beta.messages = MagicMock() - agent = ClaudeAgent.create( - model_client=client, - validate_api_key=False, - ) - agent.ctx = ctx - agent._initialized = True - agent.claude_tools = [] - agent.tool_mapping = {} - - tool_calls = 
[MCPToolCall(id="call_1", name="get_sales", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Revenue was $1M last quarter.")], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - content_blocks = cast("list[dict[str, Any]]", messages[0]["content"]) - doc_block = content_blocks[1] - - assert doc_block["type"] == "document" - assert doc_block["source"]["type"] == "text" - assert doc_block["source"]["data"] == "Revenue was $1M last quarter." - assert doc_block["citations"] == {"enabled": True} - - @pytest.mark.asyncio - async def test_format_tool_results_keeps_text_block_when_citations_disabled(self) -> None: - """Text tool results stay as plain text blocks when citations are off.""" - ctx = MockEvalContext() - ctx.scenario_enable_citations = False - - client = MagicMock(spec=AsyncAnthropic) - client.beta = MagicMock() - client.beta.messages = MagicMock() - agent = ClaudeAgent.create( - model_client=client, - validate_api_key=False, - ) - agent.ctx = ctx - agent._initialized = True - agent.claude_tools = [] - agent.tool_mapping = {} - - tool_calls = [MCPToolCall(id="call_1", name="get_sales", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Revenue was $1M.")], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - content_blocks = cast("list[dict[str, Any]]", messages[0]["content"]) - tool_result_block = content_blocks[0] - text_block = tool_result_block["content"][0] - assert text_block["type"] == "text" - assert text_block["text"] == "Revenue was $1M." 
diff --git a/hud/agents/tests/test_gateway_resolution.py b/hud/agents/tests/test_gateway_resolution.py new file mode 100644 index 000000000..ab016d40a --- /dev/null +++ b/hud/agents/tests/test_gateway_resolution.py @@ -0,0 +1,197 @@ +"""HUD gateway agent resolution tests.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from hud.agents import OpenAIAgent, create_agent +from hud.agents.claude import ClaudeAgent +from hud.agents.gateway import GatewayModelsResponse, build_gateway_client +from hud.agents.openai_compatible import OpenAIChatAgent + +MODELS = GatewayModelsResponse.model_validate( + { + "models": [ + { + "id": "uuid-openai", + "name": "GPT 5.4", + "model_name": "gpt-5.4", + "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"}, + }, + { + "id": "uuid-claude", + "name": "Claude Sonnet 4.6", + "model_name": "claude-sonnet-4-6", + "provider": {"name": "Anthropic", "default_sdk_agent_type": "claude"}, + }, + { + "id": "uuid-grok", + "name": "Grok 4.1 Fast", + "model_name": "grok-4-1-fast", + "provider": {"name": "xAI", "default_sdk_agent_type": "openai_compatible"}, + }, + { + "id": "uuid-operator", + "name": "Operator", + "model_name": "computer-use-preview", + "sdk_agent_type": "operator", + "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"}, + }, + { + "id": "uuid-gemini-cua", + "name": "Gemini Computer Use", + "model_name": "gemini-2.5-computer-use-preview", + "sdk_agent_type": "gemini_cua", + "provider": {"name": "Gemini", "default_sdk_agent_type": "gemini"}, + }, + ] + } +).models + + +def test_create_agent_resolves_gateway_model_to_provider_agent() -> None: + expected = MagicMock() + client = MagicMock() + with ( + patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS), + patch("hud.agents.gateway.build_gateway_client", return_value=client) as build_client, + patch.object(OpenAIAgent, "create", return_value=expected) as create, + ): + agent = 
create_agent("gpt-5.4", temperature=0.5) + + assert agent is expected + build_client.assert_called_once_with("OpenAI") + create.assert_called_once() + assert create.call_args.kwargs["model"] == "gpt-5.4" + assert create.call_args.kwargs["model_client"] is client + assert create.call_args.kwargs["temperature"] == 0.5 + + +@pytest.mark.parametrize("model_alias", ["uuid-openai", "GPT 5.4", "gpt-5.4"]) +def test_create_agent_resolves_gateway_model_aliases(model_alias: str) -> None: + expected = MagicMock() + with ( + patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS), + patch("hud.agents.gateway.build_gateway_client", return_value=MagicMock()), + patch.object(OpenAIAgent, "create", return_value=expected) as create, + ): + agent = create_agent(model_alias) + + assert agent is expected + assert create.call_args.kwargs["model"] == "gpt-5.4" + + +def test_create_agent_shortcut_uses_gateway_provider() -> None: + expected = MagicMock() + with ( + patch("hud.agents.gateway.build_gateway_client", return_value=MagicMock()) as build_client, + patch.object(ClaudeAgent, "create", return_value=expected), + ): + agent = create_agent("claude") + + assert agent is expected + build_client.assert_called_once_with("anthropic") + + +def test_create_agent_openai_compatible_models_use_chat_agent_client() -> None: + expected = MagicMock() + client = MagicMock() + with ( + patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS), + patch("hud.agents.gateway.build_gateway_client", return_value=client), + patch.object(OpenAIChatAgent, "create", return_value=expected) as create, + ): + agent = create_agent("grok-4-1-fast") + + assert agent is expected + assert create.call_args.kwargs["openai_client"] is client + assert "model_client" not in create.call_args.kwargs + + +@pytest.mark.parametrize( + ("model", "message"), + [ + ("missing-model", "not found"), + ("computer-use-preview", "Operator agent is no longer supported"), + ("gemini-2.5-computer-use-preview", 
"Gemini CUA agent is no longer supported"), + ], +) +def test_create_agent_rejects_unknown_or_stale_gateway_models(model: str, message: str) -> None: + with ( + patch("hud.agents.gateway._fetch_gateway_models", return_value=MODELS), + pytest.raises(ValueError, match=message), + ): + create_agent(model) + + +def test_create_agent_rejects_gateway_model_with_invalid_agent_metadata() -> None: + models = GatewayModelsResponse.model_validate( + { + "models": [ + { + "id": "bad-model", + "name": "Bad Model", + "model_name": "bad-model", + "provider": {"name": "OpenAI", "default_sdk_agent_type": None}, + } + ] + } + ).models + + with ( + patch("hud.agents.gateway._fetch_gateway_models", return_value=models), + pytest.raises(ValueError, match="invalid agent type metadata"), + ): + create_agent("bad-model") + + +def test_build_gateway_client_uses_openai_compatible_client_by_default() -> None: + with ( + patch("hud.agents.gateway.settings") as settings, + patch("hud.agents.gateway.AsyncOpenAI") as client_cls, + ): + settings.api_key = "hud-key" + settings.hud_gateway_url = "https://gateway.example" + + build_gateway_client("together") + + client_cls.assert_called_once_with( + api_key="hud-key", + base_url="https://gateway.example", + ) + + +def test_build_gateway_client_uses_anthropic_client_for_anthropic_provider() -> None: + with ( + patch("hud.agents.gateway.settings") as settings, + patch("anthropic.AsyncAnthropic") as client_cls, + ): + settings.api_key = "hud-key" + settings.hud_gateway_url = "https://gateway.example" + + build_gateway_client("anthropic") + + client_cls.assert_called_once_with( + api_key="hud-key", + base_url="https://gateway.example", + ) + + +def test_build_gateway_client_uses_genai_client_for_gemini_provider() -> None: + with ( + patch("hud.agents.gateway.settings") as settings, + patch("google.genai.Client") as client_cls, + ): + settings.api_key = "hud-key" + settings.hud_gateway_url = "https://gateway.example" + + build_gateway_client("gemini") + 
+ client_cls.assert_called_once() + assert client_cls.call_args.kwargs["api_key"] == "PLACEHOLDER" + http_options = client_cls.call_args.kwargs["http_options"] + assert http_options.api_version == "v1beta" + assert http_options.base_url == "https://gateway.example" + assert http_options.headers == {"Authorization": "Bearer hud-key"} diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py deleted file mode 100644 index dcaa7f309..000000000 --- a/hud/agents/tests/test_gemini.py +++ /dev/null @@ -1,1064 +0,0 @@ -"""Tests for Gemini MCP Agent implementation.""" - -from __future__ import annotations - -import base64 -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from google import genai -from google.genai import types as genai_types -from mcp import types - -from hud.agents.gemini import GeminiAgent -from hud.agents.gemini.tools import GeminiComputerTool as AgentGeminiComputerTool -from hud.environment.router import ToolRouter -from hud.eval.context import EvalContext -from hud.types import MCPToolCall, MCPToolResult - - -class MockEvalContext(EvalContext): - """Mock EvalContext for testing.""" - - def __init__(self, tools: list[types.Tool] | None = None) -> None: - # Core attributes - self.prompt = "Test prompt" - self._tools = tools or [] - self._submitted: str | dict[str, Any] | None = None - self.reward: float | None = None - - # Environment attributes - self._router = ToolRouter() - - # EvalContext attributes - self._task = None - self.trace_id = "test-trace-id" - self.eval_name = "test-eval" - self.job_id: str | None = None - self.group_id: str | None = None - self.index = 0 - self.variants: dict[str, Any] = {} - self.answer: str | dict[str, Any] | None = None - self.system_prompt: str | None = None - self.error: BaseException | None = None - self.scenario_enable_citations: bool = False - self.scenario_returns_schema: dict[str, Any] | None = None - self.metadata: dict[str, Any] = {} - 
self.results: list[Any] = [] - self._is_summary = False - - def as_tools(self) -> list[types.Tool]: - return self._tools - - @property - def has_scenario(self) -> bool: - return False - - async def list_tools(self) -> list[types.Tool]: - return self._tools - - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - return MCPToolResult( - content=[types.TextContent(type="text", text="ok")], - isError=False, - ) - - async def submit(self, answer: str | dict[str, Any]) -> None: - self._submitted = answer - - -class TestGeminiAgent: - """Test GeminiAgent base class.""" - - @pytest.fixture - def mock_gemini_client(self) -> MagicMock: - """Create a stub Gemini client.""" - client = MagicMock(spec=genai.Client) - client.api_key = "test_key" - client.models = MagicMock() - client.models.list = MagicMock(return_value=iter([])) - client.models.generate_content = MagicMock() - # Set up async interface (aio.models.generate_content) - client.aio = MagicMock() - client.aio.models = MagicMock() - client.aio.models.generate_content = AsyncMock() - return client - - @pytest.mark.asyncio - async def test_init(self, mock_gemini_client: MagicMock) -> None: - """Test agent initialization.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - model="gemini-2.5-flash", - validate_api_key=False, - ) - - assert agent.model_name == "Gemini" - assert agent.config.model == "gemini-2.5-flash" - assert agent.gemini_client == mock_gemini_client - - @pytest.mark.asyncio - async def test_init_without_model_client(self) -> None: - """Test agent initialization without model client.""" - with ( - patch("hud.settings.settings.gemini_api_key", "test_key"), - patch("hud.agents.gemini.agent.genai.Client") as mock_client_class, - ): - mock_client = MagicMock() - mock_client.api_key = "test_key" - mock_client.models = MagicMock() - mock_client.models.list = MagicMock(return_value=iter([])) - mock_client_class.return_value = mock_client - - agent = GeminiAgent.create( - 
model="gemini-2.5-flash", - validate_api_key=False, - ) - - assert agent.gemini_client is not None - - @pytest.mark.asyncio - async def test_format_blocks_text_only(self, mock_gemini_client: MagicMock) -> None: - """Test formatting text content blocks.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Hello, world!"), - types.TextContent(type="text", text="How are you?"), - ] - - messages = await agent.format_blocks(blocks) - assert len(messages) == 1 - assert messages[0].role == "user" - assert messages[0].parts is not None - assert len(messages[0].parts) == 2 - - @pytest.mark.asyncio - async def test_format_blocks_with_image(self, mock_gemini_client: MagicMock) -> None: - """Test formatting image content blocks.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - # Create a tiny valid base64 PNG - png_data = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode() - - blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Look at this:"), - types.ImageContent(type="image", data=png_data, mimeType="image/png"), - ] - - messages = await agent.format_blocks(blocks) - assert len(messages) == 1 - assert messages[0].parts is not None - assert len(messages[0].parts) == 2 - - @pytest.mark.asyncio - async def test_format_tool_results(self, mock_gemini_client: MagicMock) -> None: - """Test formatting tool results.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Tool output")], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - assert messages[0].role == "user" - - @pytest.mark.asyncio - async def 
test_get_system_messages(self, mock_gemini_client: MagicMock) -> None: - """Test that system messages return empty (Gemini uses system_instruction).""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - system_prompt="You are a helpful assistant.", - validate_api_key=False, - ) - - messages = await agent.get_system_messages() - # Gemini doesn't use system messages in the message list - assert messages == [] - - @pytest.mark.asyncio - async def test_get_response_text_only(self, mock_gemini_client: MagicMock) -> None: - """Test getting text-only response.""" - # Disable telemetry for this test - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - # Set up agent as initialized (no tools needed for this test) - agent.gemini_tools = [] - agent._initialized = True - - # Mock the API response with text only - mock_response = MagicMock() - mock_candidate = MagicMock() - - text_part = MagicMock() - text_part.text = "Task completed successfully" - text_part.function_call = None - - mock_candidate.content = MagicMock() - mock_candidate.content.parts = [text_part] - - mock_response.candidates = [mock_candidate] - - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response) - - messages = [ - genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")]) - ] - response = await agent.get_response(messages) - - assert response.content == "Task completed successfully" - assert response.tool_calls == [] - assert response.done is True - - @pytest.mark.asyncio - async def test_get_response_raises_on_no_candidates( - self, mock_gemini_client: MagicMock - ) -> None: - """A no-candidate Gemini response should fail loudly, not submit an empty answer.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiAgent.create( - model_client=mock_gemini_client, - model="gemini-3-flash-preview", - 
validate_api_key=False, - ) - agent.gemini_tools = [] - agent._initialized = True - - mock_response = MagicMock() - mock_response.candidates = [] - mock_response.prompt_feedback = "blocked" - mock_response.usage_metadata = None - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response) - - messages = [ - genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")]) - ] - - with pytest.raises(RuntimeError, match="returned no candidates"): - await agent.get_response(messages) - - @pytest.mark.asyncio - async def test_get_response_with_thinking(self, mock_gemini_client: MagicMock) -> None: - """Test getting response with thinking content.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - # Set up agent as initialized (no tools needed for this test) - agent.gemini_tools = [] - agent._initialized = True - - mock_response = MagicMock() - mock_candidate = MagicMock() - - thinking_part = MagicMock() - thinking_part.text = "Let me reason through this..." - thinking_part.function_call = None - thinking_part.thought = True - - text_part = MagicMock() - text_part.text = "Here is my answer" - text_part.function_call = None - text_part.thought = False - - mock_candidate.content = MagicMock() - mock_candidate.content.parts = [thinking_part, text_part] - - mock_response.candidates = [mock_candidate] - - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response) - - messages = [ - genai_types.Content( - role="user", parts=[genai_types.Part.from_text(text="Hard question")] - ) - ] - response = await agent.get_response(messages) - - assert response.content == "Here is my answer" - assert response.reasoning == "Let me reason through this..." 
- - @pytest.mark.asyncio - async def test_get_response_passes_thinking_config(self, mock_gemini_client: MagicMock) -> None: - """Gemini 3 thinking options should be passed to GenerateContentConfig.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiAgent.create( - model_client=mock_gemini_client, - model="gemini-3-flash-preview", - validate_api_key=False, - thinking_level="high", - include_thoughts=True, - ) - agent.gemini_tools = [] - agent._initialized = True - - mock_response = MagicMock() - mock_candidate = MagicMock() - text_part = MagicMock() - text_part.text = "Answer" - text_part.function_call = None - text_part.thought = False - mock_candidate.content = MagicMock() - mock_candidate.content.parts = [text_part] - mock_response.candidates = [mock_candidate] - - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=mock_response) - - messages = [ - genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Hi")]) - ] - await agent.get_response(messages) - - config = mock_gemini_client.aio.models.generate_content.call_args.kwargs["config"] - assert config.thinking_config is not None - assert config.thinking_config.include_thoughts is True - assert config.thinking_config.thinking_level.value == "HIGH" - - @pytest.mark.asyncio - async def test_convert_tools_for_gemini(self, mock_gemini_client: MagicMock) -> None: - """Test converting MCP tools to Gemini format.""" - tools = [ - types.Tool( - name="my_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - # Check that tools were converted - assert len(agent.gemini_tools) == 1 - # Gemini tools have function_declarations - cast to genai Tool type - gemini_tool = agent.gemini_tools[0] - assert 
isinstance(gemini_tool, genai_types.Tool) - assert gemini_tool.function_declarations is not None - assert gemini_tool.function_declarations[0].name == "my_tool" - - @pytest.mark.asyncio - async def test_regular_agent_uses_native_computer_use( - self, mock_gemini_client: MagicMock - ) -> None: - """GeminiAgent should register GeminiComputerTool as native Computer Use.""" - computer_tool = types.Tool( - name="gemini_computer", - description="Control computer with mouse, keyboard, and screenshots", - inputSchema={"type": "object", "properties": {}}, - ) - computer_tool.meta = { - "native_tools": { - "gemini": { - "api_type": "computer_use", - "api_name": "gemini_computer", - "role": "computer", - "supported_models": ["gemini-3-flash-preview"], - } - } - } - tools = [ - computer_tool, - ] - ctx = MockEvalContext(tools=tools) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - model="gemini-3-flash-preview", - validate_api_key=False, - excluded_predefined_functions=["drag_and_drop"], - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent._computer_tool_name == "computer_use" - assert agent._gemini_native_tools["computer_use"].env_tool_name == "gemini_computer" - assert "gemini_computer" not in agent._gemini_native_tools - assert len(agent.gemini_tools) == 1 - computer_tool = agent.gemini_tools[0] - assert isinstance(computer_tool, genai_types.Tool) - assert computer_tool.computer_use is not None - assert computer_tool.computer_use.excluded_predefined_functions == ["drag_and_drop"] - - @pytest.mark.asyncio - async def test_computer_use_excludes_colliding_generic_tool_names( - self, mock_gemini_client: MagicMock - ) -> None: - """Generic tools named like predefined actions should not be hijacked.""" - computer_tool = types.Tool( - name="gemini_computer", - description="Control computer with mouse, keyboard, and screenshots", - inputSchema={"type": "object", "properties": {}}, - ) - computer_tool.meta = { - "native_tools": { - 
"gemini": { - "api_type": "computer_use", - "api_name": "gemini_computer", - "role": "computer", - "supported_models": ["gemini-3-flash-preview"], - } - } - } - navigate_tool = types.Tool( - name="navigate", - description="A non-computer navigation helper", - inputSchema={"type": "object", "properties": {"url": {"type": "string"}}}, - ) - ctx = MockEvalContext(tools=[computer_tool, navigate_tool]) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - model="gemini-3-flash-preview", - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - computer_use_tool = next( - tool for tool in agent.gemini_tools if getattr(tool, "computer_use", None) is not None - ) - computer_use = getattr(computer_use_tool, "computer_use", None) - assert computer_use is not None - assert "navigate" in (computer_use.excluded_predefined_functions or []) - function_call = MagicMock() - function_call.name = "navigate" - function_call.args = {"url": "https://example.com"} - tool_call = agent._extract_tool_call(MagicMock(function_call=function_call)) - assert tool_call is not None - assert tool_call.name == "navigate" - assert tool_call.arguments == {"url": "https://example.com"} - - @pytest.mark.asyncio - async def test_agent_owns_gemini_cli_tool_surface(self, mock_gemini_client: MagicMock) -> None: - """GeminiAgent exposes Gemini-shaped tools backed by generic env primitives.""" - tools = [ - types.Tool(name="bash", description="Run shell", inputSchema={"type": "object"}), - types.Tool(name="edit", description="Edit files", inputSchema={"type": "object"}), - types.Tool(name="read", description="Read files", inputSchema={"type": "object"}), - types.Tool(name="grep", description="Search files", inputSchema={"type": "object"}), - types.Tool(name="glob", description="Find files", inputSchema={"type": "object"}), - types.Tool(name="list", description="List files", inputSchema={"type": "object"}), - types.Tool(name="memory", description="Remember 
facts", inputSchema={"type": "object"}), - ] - ctx = MockEvalContext(tools=tools) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - agent.console.info = MagicMock() - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - declaration_names = { - declaration.name - for tool in agent.gemini_tools - for declaration in (getattr(tool, "function_declarations", None) or []) - } - assert { - "run_shell_command", - "replace", - "write_file", - "read_file", - "grep_search", - "glob", - "list_directory", - "save_memory", - } <= declaration_names - assert agent._gemini_native_tools["run_shell_command"].env_tool_name == "bash" - assert agent._gemini_native_tools["replace"].env_tool_name == "edit" - assert agent._gemini_native_tools["write_file"].env_tool_name == "edit" - assert agent._gemini_native_tools["read_file"].env_tool_name == "read" - assert agent._gemini_native_tools["grep_search"].env_tool_name == "grep" - assert agent._gemini_native_tools["glob"].env_tool_name == "glob" - assert agent._gemini_native_tools["list_directory"].env_tool_name == "list" - assert agent._gemini_native_tools["save_memory"].env_tool_name == "memory" - declarations = { - declaration.name: declaration - for tool in agent.gemini_tools - for declaration in (getattr(tool, "function_declarations", None) or []) - } - assert "allow_multiple" not in declarations["replace"].parameters_json_schema["properties"] - assert ( - "exclude_pattern" - not in declarations["grep_search"].parameters_json_schema["properties"] - ) - assert "names_only" not in declarations["grep_search"].parameters_json_schema["properties"] - assert "respect_git_ignore" not in declarations["glob"].parameters_json_schema["properties"] - agent.console.info.assert_called_with( - "Agent initialized with 8 tools: " - "glob, grep_search, list_directory, read_file, replace, run_shell_command, " - "save_memory, write_file" - ) - - @pytest.mark.asyncio - async def 
test_gemini_legacy_env_tools_activate_harness_tools( - self, mock_gemini_client: MagicMock - ) -> None: - """Old Gemini env constructors register canonical names for harness activation.""" - from hud.tools import ( - GeminiGlobTool, - GeminiListTool, - GeminiMemoryTool, - GeminiReadTool, - GeminiSearchTool, - ) - - env_tools = [ - GeminiReadTool(), - GeminiSearchTool(), - GeminiGlobTool(), - GeminiListTool(), - GeminiMemoryTool(), - ] - tools = [ - types.Tool(name=tool.name, description=tool.description, inputSchema={"type": "object"}) - for tool in env_tools - ] - ctx = MockEvalContext(tools=tools) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert agent._gemini_native_tools["read_file"].env_tool_name == "read" - assert agent._gemini_native_tools["grep_search"].env_tool_name == "grep" - assert agent._gemini_native_tools["glob"].env_tool_name == "glob" - assert agent._gemini_native_tools["list_directory"].env_tool_name == "list" - assert agent._gemini_native_tools["save_memory"].env_tool_name == "memory" - - def test_regular_agent_routes_computer_use_function_call( - self, mock_gemini_client: MagicMock - ) -> None: - """Gemini Computer Use calls should route to the MCP computer tool.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - agent._computer_tool_name = "computer_use" - - function_call = MagicMock() - function_call.name = "click_at" - function_call.args = {"x": 500, "y": 250, "safety_decision": {"decision": "allowed"}} - part = MagicMock(function_call=function_call) - - tool_call = agent._extract_tool_call(part) - - assert tool_call is not None - assert tool_call.name == "computer_use" - assert tool_call.arguments == { - "action": "click_at", - "safety_decision": {"decision": "allowed"}, - "x": 500, - "y": 250, - } - assert getattr(tool_call, "gemini_name") == "click_at" - - def 
test_gemini_computer_drag_insets_edge_coordinates(self) -> None: - """Gemini drag endpoints should be inset before calling the environment tool.""" - spec = AgentGeminiComputerTool.default_spec("gemini-3-flash-preview") - assert spec is not None - tool = AgentGeminiComputerTool(env_tool_name="computer", spec=spec) - - calls = tool._env_calls( - "drag_and_drop", - {"x": 0, "y": 500, "destination_x": 1000, "destination_y": 500}, - ) - - assert calls == [ - { - "action": "drag", - "path": [ - {"x": 25, "y": 500}, - {"x": 975, "y": 500}, - ], - } - ] - - def test_gemini_computer_normalizes_keys_and_optional_type_coordinates(self) -> None: - """Gemini key strings should map cleanly to the environment press contract.""" - spec = AgentGeminiComputerTool.default_spec("gemini-3-flash-preview") - assert spec is not None - tool = AgentGeminiComputerTool(env_tool_name="computer", spec=spec) - - assert tool._env_calls("key_combination", {"keys": "Control+A"}) == [ - {"action": "press", "keys": ["ctrl", "a"]} - ] - assert tool._env_calls("type_text_at", {"text": "hello", "clear_before_typing": False}) == [ - {"action": "write", "text": "hello", "enter_after": False} - ] - - @pytest.mark.asyncio - async def test_gemini_computer_blocks_confirmation_required_actions(self) -> None: - """Gemini require_confirmation actions need HITL before execution.""" - spec = AgentGeminiComputerTool.default_spec("gemini-3-flash-preview") - assert spec is not None - tool = AgentGeminiComputerTool(env_tool_name="computer", spec=spec) - calls: list[MCPToolCall] = [] - - async def call_tool(call: MCPToolCall) -> MCPToolResult: - calls.append(call) - return MCPToolResult( - content=[types.TextContent(type="text", text="executed")], - isError=False, - ) - - result = await tool.execute( - call_tool, - { - "action": "click_at", - "x": 10, - "y": 20, - "safety_decision": {"decision": "require_confirmation"}, - }, - ) - - assert result.isError is False - assert isinstance(result.content[0], 
types.TextContent) - assert result.content[0].text.startswith("__GEMINI_SAFETY_BLOCKED__:") - assert calls == [] - - @pytest.mark.asyncio - async def test_regular_agent_formats_computer_use_results( - self, mock_gemini_client: MagicMock - ) -> None: - """GeminiAgent should return URL and screenshot parts for native computer use.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - agent._computer_tool_name = "computer_use" - screenshot = base64.b64encode(b"png bytes").decode() - tool_calls = [ - MCPToolCall( - name="computer_use", - arguments={"action": "click_at", "safety_decision": {"decision": "allowed"}}, - gemini_name="click_at", # type: ignore[arg-type] - ) - ] - tool_results = [ - MCPToolResult( - content=[ - types.TextContent(type="text", text="__URL__:https://example.com"), - types.ImageContent(type="image", data=screenshot, mimeType="image/png"), - ], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - parts = messages[0].parts - assert parts is not None - function_response = parts[0].function_response - assert function_response is not None - assert function_response.name == "click_at" - response = function_response.response - assert response is not None - assert response["url"] == "https://example.com" - assert response["safety_acknowledgement"] is True - assert function_response.parts is not None - inline_data = function_response.parts[0].inline_data - assert inline_data is not None - assert inline_data.mime_type == "image/png" - - @pytest.mark.asyncio - async def test_regular_agent_formats_blocked_computer_use_results( - self, mock_gemini_client: MagicMock - ) -> None: - """Blocked Gemini safety actions should not be reported as tool errors.""" - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - agent._computer_tool_name = "computer_use" - tool_calls = [ - MCPToolCall( - name="computer_use", - arguments={ - 
"action": "click_at", - "safety_decision": {"decision": "require_confirmation"}, - }, - gemini_name="click_at", # type: ignore[arg-type] - ) - ] - tool_results = [ - MCPToolResult( - content=[ - types.TextContent( - type="text", - text=( - "__GEMINI_SAFETY_BLOCKED__:Gemini Computer Use action requires " - "user confirmation before execution." - ), - ), - ], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - parts = messages[0].parts - assert parts is not None - function_response = parts[0].function_response - assert function_response is not None - response = function_response.response - assert response is not None - assert response["blocked"] is True - assert "success" not in response - assert response["url"] == "about:blank" - assert "safety_acknowledgement" not in response - - -class TestGeminiToolConversion: - """Tests for tool conversion to Gemini format.""" - - @pytest.fixture - def mock_gemini_client(self) -> MagicMock: - """Create a stub Gemini client.""" - client = MagicMock(spec=genai.Client) - client.api_key = "test_key" - client.models = MagicMock() - client.models.list = MagicMock(return_value=iter([])) - # Set up async interface - client.aio = MagicMock() - client.aio.models = MagicMock() - client.aio.models.generate_content = AsyncMock() - return client - - @pytest.mark.asyncio - async def test_tool_with_properties(self, mock_gemini_client: MagicMock) -> None: - """Test tool with input properties.""" - tools = [ - types.Tool( - name="search", - description="Search the web", - inputSchema={ - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "limit": {"type": "integer", "description": "Max results"}, - }, - "required": ["query"], - }, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert 
len(agent.gemini_tools) == 1 - gemini_tool = agent.gemini_tools[0] - # Gemini tools have function_declarations - cast to genai Tool type - assert isinstance(gemini_tool, genai_types.Tool) - assert gemini_tool.function_declarations is not None - assert gemini_tool.function_declarations[0].name == "search" - assert gemini_tool.function_declarations[0].parameters_json_schema is not None - - @pytest.mark.asyncio - async def test_tool_without_schema(self, mock_gemini_client: MagicMock) -> None: - """Test tool without description raises error.""" - # Create a tool with inputSchema but no description - tools = [ - types.Tool( - name="incomplete", - description=None, - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = GeminiAgent.create( - model_client=mock_gemini_client, - validate_api_key=False, - ) - - agent.ctx = ctx - with pytest.raises(ValueError, match="requires both a description"): - await agent._initialize_from_ctx(ctx) - - -class TestGeminiCitations: - """Tests for Gemini grounding citation extraction.""" - - @pytest.fixture - def mock_gemini_client(self) -> MagicMock: - client = MagicMock(spec=genai.Client) - client.aio = MagicMock() - client.aio.models = MagicMock() - client.aio.models.generate_content = AsyncMock() - return client - - def _make_agent(self, client: MagicMock) -> GeminiAgent: - agent = GeminiAgent.create(model_client=client, validate_api_key=False) - agent.gemini_tools = [] - agent._initialized = True - return agent - - def _text_candidate(self, text: str = "answer") -> MagicMock: - candidate = MagicMock() - part = MagicMock() - part.text = text - part.function_call = None - part.thought = False - candidate.content = MagicMock() - candidate.content.parts = [part] - return candidate - - @pytest.mark.asyncio - async def test_no_grounding_metadata(self, mock_gemini_client: MagicMock) -> None: - """No citations when groundingMetadata is absent.""" - agent = self._make_agent(mock_gemini_client) - candidate = 
self._text_candidate() - candidate.grounding_metadata = None - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - result = await agent.get_response([]) - assert result.citations == [] - - @pytest.mark.asyncio - async def test_grounding_chunks_only(self, mock_gemini_client: MagicMock) -> None: - """Chunks without supports produce citations with source but no anchoring.""" - agent = self._make_agent(mock_gemini_client) - candidate = self._text_candidate() - - chunk = MagicMock() - chunk.web = MagicMock() - chunk.web.uri = "https://example.com" - chunk.web.title = "Example" - - grounding_meta = MagicMock() - grounding_meta.grounding_chunks = [chunk] - grounding_meta.grounding_supports = [] - candidate.grounding_metadata = grounding_meta - - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - result = await agent.get_response([]) - assert len(result.citations) == 1 - assert result.citations[0]["source"] == "https://example.com" - assert result.citations[0]["title"] == "Example" - assert result.citations[0]["text"] == "" - - @pytest.mark.asyncio - async def test_grounding_supports_with_anchoring(self, mock_gemini_client: MagicMock) -> None: - """Supports produce citations with start_index/end_index from segments.""" - agent = self._make_agent(mock_gemini_client) - candidate = self._text_candidate("The sky is blue because of Rayleigh scattering.") - - chunk = MagicMock() - chunk.web = MagicMock() - chunk.web.uri = "https://physics.org/scattering" - chunk.web.title = "Scattering" - - support = MagicMock() - support.segment = MagicMock() - support.segment.text = "Rayleigh scattering" - support.segment.start_index = 28 - support.segment.end_index = 47 - support.grounding_chunk_indices = [0] - - grounding_meta = MagicMock() - grounding_meta.grounding_chunks = [chunk] - grounding_meta.grounding_supports = 
[support] - candidate.grounding_metadata = grounding_meta - - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - result = await agent.get_response([]) - assert len(result.citations) == 1 - cit = result.citations[0] - assert cit["type"] == "grounding" - assert cit["text"] == "Rayleigh scattering" - assert cit["source"] == "https://physics.org/scattering" - assert cit["start_index"] == 28 - assert cit["end_index"] == 47 - - @pytest.mark.asyncio - async def test_multiple_supports_and_chunks(self, mock_gemini_client: MagicMock) -> None: - """Multiple supports across multiple chunks produce the right citations.""" - agent = self._make_agent(mock_gemini_client) - candidate = self._text_candidate() - - chunk_a = MagicMock() - chunk_a.web = MagicMock() - chunk_a.web.uri = "https://a.com" - chunk_a.web.title = "A" - - chunk_b = MagicMock() - chunk_b.web = MagicMock() - chunk_b.web.uri = "https://b.com" - chunk_b.web.title = "B" - - support1 = MagicMock() - support1.segment = MagicMock() - support1.segment.text = "fact one" - support1.segment.start_index = 0 - support1.segment.end_index = 8 - support1.grounding_chunk_indices = [0] - - support2 = MagicMock() - support2.segment = MagicMock() - support2.segment.text = "fact two" - support2.segment.start_index = 10 - support2.segment.end_index = 18 - support2.grounding_chunk_indices = [1] - - grounding_meta = MagicMock() - grounding_meta.grounding_chunks = [chunk_a, chunk_b] - grounding_meta.grounding_supports = [support1, support2] - candidate.grounding_metadata = grounding_meta - - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - result = await agent.get_response([]) - assert len(result.citations) == 2 - assert result.citations[0]["source"] == "https://a.com" - assert result.citations[0]["text"] == "fact one" - assert result.citations[1]["source"] == 
"https://b.com" - assert result.citations[1]["text"] == "fact two" - - -class TestGeminiCitationInjection: - """Test that enable_citations injects google_search when missing.""" - - @pytest.fixture - def mock_gemini_client(self) -> MagicMock: - client = MagicMock(spec=genai.Client) - client.aio = MagicMock() - client.aio.models = MagicMock() - client.aio.models.generate_content = AsyncMock() - return client - - def _make_agent(self, client: MagicMock) -> GeminiAgent: - agent = GeminiAgent.create(model_client=client, validate_api_key=False) - agent.gemini_tools = [] - agent._gemini_to_mcp_tool_map = {} - agent._initialized = True - return agent - - @pytest.mark.asyncio - async def test_google_search_injected_when_citations_enabled( - self, mock_gemini_client: MagicMock - ) -> None: - """When scenario_enable_citations=True and no google_search tool, inject one.""" - agent = self._make_agent(mock_gemini_client) - ctx = MockEvalContext() - ctx.scenario_enable_citations = True - agent.ctx = ctx - - candidate = MagicMock() - candidate.content = MagicMock() - candidate.content.parts = [MagicMock(function_call=None, thought=False, text="Hi")] - candidate.grounding_metadata = None - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - await agent.get_response([]) - - call_kwargs = mock_gemini_client.aio.models.generate_content.call_args - config = call_kwargs.kwargs["config"] - tools_passed = config.tools - assert any( - isinstance(t, genai_types.Tool) and t.google_search is not None for t in tools_passed - ) - - @pytest.mark.asyncio - async def test_no_duplicate_google_search_when_already_present( - self, mock_gemini_client: MagicMock - ) -> None: - """When google_search tool already exists, don't add a second one.""" - agent = self._make_agent(mock_gemini_client) - existing_search_tool = genai_types.Tool(google_search=genai_types.GoogleSearch()) - agent.gemini_tools = [existing_search_tool] 
- ctx = MockEvalContext() - ctx.scenario_enable_citations = True - agent.ctx = ctx - - candidate = MagicMock() - candidate.content = MagicMock() - candidate.content.parts = [MagicMock(function_call=None, thought=False, text="Hi")] - candidate.grounding_metadata = None - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - await agent.get_response([]) - - call_kwargs = mock_gemini_client.aio.models.generate_content.call_args - config = call_kwargs.kwargs["config"] - tools_passed = config.tools - search_count = sum( - 1 - for t in tools_passed - if isinstance(t, genai_types.Tool) and t.google_search is not None - ) - assert search_count == 1 - - @pytest.mark.asyncio - async def test_no_injection_when_citations_disabled( - self, mock_gemini_client: MagicMock - ) -> None: - """When scenario_enable_citations=False, no google_search is injected.""" - agent = self._make_agent(mock_gemini_client) - ctx = MockEvalContext() - ctx.scenario_enable_citations = False - agent.ctx = ctx - - candidate = MagicMock() - candidate.content = MagicMock() - candidate.content.parts = [MagicMock(function_call=None, thought=False, text="Hi")] - candidate.grounding_metadata = None - resp = MagicMock() - resp.candidates = [candidate] - mock_gemini_client.aio.models.generate_content = AsyncMock(return_value=resp) - - await agent.get_response([]) - - call_kwargs = mock_gemini_client.aio.models.generate_content.call_args - config = call_kwargs.kwargs["config"] - tools_passed = config.tools - assert not any( - isinstance(t, genai_types.Tool) and t.google_search is not None for t in tools_passed - ) diff --git a/hud/agents/tests/test_hosted_tools.py b/hud/agents/tests/test_hosted_tools.py index deee000f3..ce4d76aea 100644 --- a/hud/agents/tests/test_hosted_tools.py +++ b/hud/agents/tests/test_hosted_tools.py @@ -1,48 +1,137 @@ +"""Provider-hosted tool configuration tests.""" + from __future__ import annotations +from 
types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + import pytest +from google.genai import types as genai_types +from openai.types.responses import ResponseOutputMessage, ResponseOutputText -from hud.agents.base import CategorizedTools +from hud.agents.base import AgentContext from hud.agents.claude import ( ClaudeAgent, ClaudeToolSearchTool, ClaudeWebFetchTool, ClaudeWebSearchTool, ) -from hud.agents.gemini import ( - GeminiAgent, - GeminiCodeExecutionTool, - GeminiGoogleSearchTool, - GeminiUrlContextTool, -) -from hud.agents.openai import ( - OpenAIAgent, - OpenAICodeInterpreterTool, - OpenAIToolSearchTool, -) +from hud.agents.gemini import GeminiAgent, GeminiCodeExecutionTool, GeminiGoogleSearchTool +from hud.agents.openai import OpenAIAgent, OpenAICodeInterpreterTool, OpenAIToolSearchTool +from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt -def test_claude_agent_configured_hosted_tools() -> None: - agent = ClaudeAgent.create( - model_client=object(), - hosted_tools=[ - ClaudeWebSearchTool(max_uses=3), - ClaudeWebFetchTool(citations_enabled=True), - ClaudeToolSearchTool(threshold=7), +def _message_response(text: str) -> SimpleNamespace: + return SimpleNamespace( + id="resp", + output=[ + ResponseOutputMessage( + id="msg", + type="message", + role="assistant", + status="completed", + content=[ResponseOutputText(type="output_text", text=text, annotations=[])], + ) ], ) - agent._available_tools = [] - agent._categorized_tools = CategorizedTools() - agent._convert_tools_for_claude() - assert {tool.get("type") for tool in agent.claude_tools if isinstance(tool, dict)} == { - "web_search_20250305", - "web_fetch_20250910", - "tool_search_tool_bm25_20251119", - } - assert agent._required_betas == set() - assert agent._tool_search_threshold == 7 +class Stream: + def __init__(self, text: str) -> None: + block = MagicMock() + block.type = "text" + block.text = text + block.citations = None + self.response = 
MagicMock() + self.response.content = [block] + + async def __aenter__(self) -> Stream: + return self + + async def __aexit__(self, *args: object) -> bool: + return False + + def __aiter__(self) -> Stream: + return self + + async def __anext__(self) -> None: + raise StopAsyncIteration + + async def get_final_message(self) -> MagicMock: + return self.response + + +def _gemini_response(text: str) -> genai_types.GenerateContentResponse: + return genai_types.GenerateContentResponse( + candidates=[ + genai_types.Candidate( + content=genai_types.Content(role="model", parts=[genai_types.Part(text=text)]) + ) + ] + ) + + +def _gemini_client(response: genai_types.GenerateContentResponse) -> MagicMock: + client = MagicMock() + client.aio = MagicMock() + client.aio.models = MagicMock() + client.aio.models.generate_content = AsyncMock(return_value=response) + return client + + +def test_openai_hosted_tools_are_model_gated() -> None: + tool = OpenAICodeInterpreterTool(container={"type": "auto"}) + + assert tool.supports_model("gpt-5.4") + assert not tool.supports_model("gpt-4.1") + + +@pytest.mark.asyncio +async def test_supported_openai_hosted_tool_is_sent_to_provider() -> None: + client = SimpleNamespace( + responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("done"))) + ) + agent = OpenAIAgent.create( + model="gpt-5.4", + model_client=client, + validate_api_key=False, + hosted_tools=[OpenAICodeInterpreterTool(container={"type": "auto"})], + ) + + result = await agent.run( + AgentContext( + messages=[text_prompt("use hosted code")], + tool_client=RecordingToolEnvironment().client, + ) + ) + + assert result.content == "done" + tools = client.responses.create.await_args.kwargs["tools"] + assert any(tool["type"] == "code_interpreter" for tool in tools) + + +@pytest.mark.asyncio +async def test_unsupported_openai_hosted_tool_is_not_sent_to_provider() -> None: + client = SimpleNamespace( + 
responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("done"))) + ) + agent = OpenAIAgent.create( + model="gpt-4.1", + model_client=client, + validate_api_key=False, + hosted_tools=[OpenAICodeInterpreterTool(container={"type": "auto"})], + ) + + result = await agent.run( + AgentContext( + messages=[text_prompt("use hosted code")], + tool_client=RecordingToolEnvironment().client, + ) + ) + + assert result.content == "done" + tools = client.responses.create.await_args.kwargs["tools"] + assert not isinstance(tools, list) def test_claude_hosted_domain_filters_are_mutually_exclusive() -> None: @@ -59,68 +148,126 @@ def test_claude_hosted_domain_filters_are_mutually_exclusive() -> None: ).to_params() -def test_openai_agent_configured_hosted_tools() -> None: +def test_gemini_google_search_rejects_unsupported_dynamic_threshold() -> None: + with pytest.raises(ValueError, match="dynamic_threshold"): + GeminiGoogleSearchTool(dynamic_threshold=0.2).to_params() + + +@pytest.mark.asyncio +async def test_openai_tool_search_threshold_defers_function_loading() -> None: + client = SimpleNamespace( + responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("done"))) + ) agent = OpenAIAgent.create( - model_client=object(), - hosted_tools=[ - OpenAICodeInterpreterTool(container={"type": "auto"}), - OpenAIToolSearchTool(threshold=4), - ], + model="gpt-5.4", + model_client=client, + validate_api_key=False, + hosted_tools=[OpenAIToolSearchTool(threshold=1)], ) - agent._available_tools = [] - agent._categorized_tools = CategorizedTools() + environment = RecordingToolEnvironment([mcp_tool("first"), mcp_tool("second")]) - agent._convert_tools_for_openai() + result = await agent.run( + AgentContext( + messages=[text_prompt("use tools")], + tool_client=environment.client, + ) + ) - assert {"code_interpreter", "tool_search"} <= { - tool.get("type") for tool in agent._openai_tools if isinstance(tool, dict) - } - assert agent._tool_search_threshold == 4 + 
assert result.content == "done" + tools = client.responses.create.await_args.kwargs["tools"] + function_tools = [tool for tool in tools if tool["type"] == "function"] + assert len(function_tools) == 2 + assert all(tool["defer_loading"] is True for tool in function_tools) -def test_openai_hosted_tools_are_model_gated() -> None: - agent = OpenAIAgent.create( - model_client=object(), - model="gpt-4.1", +@pytest.mark.asyncio +async def test_claude_hosted_web_fetch_payload_is_sent_to_provider() -> None: + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace(stream=MagicMock(return_value=Stream("done"))) + ) + ) + agent = ClaudeAgent.create( + model="claude-sonnet-4-6", + model_client=client, + validate_api_key=False, hosted_tools=[ - OpenAICodeInterpreterTool(container={"type": "auto"}), - OpenAIToolSearchTool(threshold=4), + ClaudeWebFetchTool( + max_uses=2, + allowed_domains=["example.com"], + max_content_tokens=500, + citations_enabled=True, + ) ], ) - agent._available_tools = [] - agent._categorized_tools = CategorizedTools() - agent._convert_tools_for_openai() + result = await agent.run( + AgentContext( + messages=[text_prompt("fetch")], + tool_client=RecordingToolEnvironment().client, + ) + ) - assert agent._openai_tools == [] - assert agent._tool_search_threshold is None + assert result.content == "done" + tools = client.beta.messages.stream.call_args.kwargs["tools"] + assert tools == [ + { + "type": "web_fetch_20250910", + "name": "web_fetch", + "max_uses": 2, + "allowed_domains": ["example.com"], + "max_content_tokens": 500, + "citations": {"enabled": True}, + } + ] -def test_gemini_agent_configured_hosted_tools() -> None: - agent = GeminiAgent.create( - model_client=object(), - hosted_tools=[ - GeminiGoogleSearchTool(), - GeminiUrlContextTool(), - GeminiCodeExecutionTool(), - ], +@pytest.mark.asyncio +async def test_claude_tool_search_threshold_defers_generic_tools() -> None: + client = SimpleNamespace( + beta=SimpleNamespace( + 
messages=SimpleNamespace(stream=MagicMock(return_value=Stream("done"))) + ) + ) + agent = ClaudeAgent.create( + model="claude-sonnet-4-6", + model_client=client, + validate_api_key=False, + hosted_tools=[ClaudeToolSearchTool(threshold=1)], ) - agent._available_tools = [] - agent._categorized_tools = CategorizedTools() - agent._convert_tools_for_gemini() + result = await agent.run( + AgentContext( + messages=[text_prompt("use tools")], + tool_client=RecordingToolEnvironment([mcp_tool("first"), mcp_tool("second")]).client, + ) + ) - assert any(getattr(tool, "google_search", None) is not None for tool in agent.gemini_tools) - assert any(getattr(tool, "url_context", None) is not None for tool in agent.gemini_tools) - assert any(getattr(tool, "code_execution", None) is not None for tool in agent.gemini_tools) + assert result.content == "done" + tools = client.beta.messages.stream.call_args.kwargs["tools"] + generic_tools = [tool for tool in tools if "input_schema" in tool] + assert len(generic_tools) == 2 + assert all(tool["defer_loading"] is True for tool in generic_tools) + assert any(tool["type"] == "tool_search_tool_bm25_20251119" for tool in tools) -def test_gemini_google_search_rejects_unsupported_dynamic_threshold() -> None: - tool = GeminiGoogleSearchTool(dynamic_threshold=0.2) - - try: - tool.to_params() - except ValueError as exc: - assert "dynamic_threshold" in str(exc) - else: - raise AssertionError("dynamic_threshold should be rejected") +@pytest.mark.asyncio +async def test_gemini_hosted_code_execution_payload_is_sent_to_provider() -> None: + client = _gemini_client(_gemini_response("done")) + agent = GeminiAgent.create( + model_client=client, + validate_api_key=False, + hosted_tools=[GeminiCodeExecutionTool()], + ) + + result = await agent.run( + AgentContext( + messages=[text_prompt("run code")], + tool_client=RecordingToolEnvironment().client, + ) + ) + + assert result.content == "done" + config = 
client.aio.models.generate_content.await_args.kwargs["config"] + assert len(config.tools) == 1 + assert config.tools[0].code_execution is not None diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py deleted file mode 100644 index cd8438628..000000000 --- a/hud/agents/tests/test_openai.py +++ /dev/null @@ -1,824 +0,0 @@ -"""Tests for OpenAI MCP Agent implementation.""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, cast -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from mcp import types -from openai import AsyncOpenAI -from openai.types.responses import ( - ResponseFunctionToolCall, - ResponseOutputMessage, - ResponseOutputText, - ResponseReasoningItem, -) -from openai.types.responses.response_reasoning_item import Summary - -from hud.agents.openai import OpenAIAgent -from hud.environment.router import ToolRouter -from hud.eval.context import EvalContext -from hud.types import MCPToolCall, MCPToolResult - -if TYPE_CHECKING: - from collections.abc import Generator - - -class MockEvalContext(EvalContext): - """Mock EvalContext for testing.""" - - def __init__(self, tools: list[types.Tool] | None = None) -> None: - # Core attributes - self.prompt = "Test prompt" - self._tools = tools or [] - self._submitted: str | dict[str, Any] | None = None - self.reward: float | None = None - - # Environment attributes - self._router = ToolRouter() - - # EvalContext attributes - self._task = None - self.trace_id = "test-trace-id" - self.eval_name = "test-eval" - self.job_id: str | None = None - self.group_id: str | None = None - self.index = 0 - self.variants: dict[str, Any] = {} - self.answer: str | dict[str, Any] | None = None - self.system_prompt: str | None = None - self.scenario_enable_citations: bool = False - self.scenario_returns_schema: dict[str, Any] | None = None - self.error: BaseException | None = None - self.metadata: dict[str, Any] = {} - 
self.results: list[Any] = [] - self.calls: list[Any] = [] - self._is_summary = False - - def as_tools(self) -> list[types.Tool]: - return self._tools - - @property - def has_scenario(self) -> bool: - return False - - async def list_tools(self) -> list[types.Tool]: - return self._tools - - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - self.calls.append(call) - return MCPToolResult( - content=[types.TextContent(type="text", text="ok")], - isError=False, - ) - - async def submit(self, answer: str | dict[str, Any]) -> None: - self._submitted = answer - - -class TestOpenAIAgent: - """Test OpenAIAgent class.""" - - @pytest.fixture - def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: # type: ignore[misc] - """Create a stub OpenAI client.""" - with patch("hud.agents.openai.agent.AsyncOpenAI") as mock_class: - client = AsyncOpenAI(api_key="test", base_url="http://localhost") - client.chat.completions.create = AsyncMock() - client.responses.create = AsyncMock() - mock_class.return_value = client - yield client # type: ignore[misc] - - @pytest.mark.asyncio - async def test_init_with_client(self, mock_openai: AsyncOpenAI) -> None: - """Test agent initialization with provided client.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - model="gpt-4o", - validate_api_key=False, - ) - - assert agent.model_name == "OpenAI" - assert agent.config.model == "gpt-4o" - assert agent.model == "gpt-4o" - assert agent.openai_client == mock_openai - assert agent.max_output_tokens is None - assert agent.temperature is None - - @pytest.mark.asyncio - async def test_init_with_parameters(self, mock_openai: AsyncOpenAI) -> None: - """Test agent initialization with various parameters.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - model="gpt-4o", - max_output_tokens=2048, - temperature=0.7, - reasoning={"effort": "high"}, - tool_choice="auto", - parallel_tool_calls=True, - validate_api_key=False, - ) - - assert 
agent.max_output_tokens == 2048 - assert agent.temperature == 0.7 - assert agent.reasoning == {"effort": "high"} - assert agent.tool_choice == "auto" - assert agent.parallel_tool_calls is True - - @pytest.mark.asyncio - async def test_init_without_client_no_api_key(self) -> None: - """Test agent initialization fails without API key.""" - with patch("hud.agents.openai.agent.settings") as mock_settings: - mock_settings.api_key = None - mock_settings.openai_api_key = None - with pytest.raises(ValueError, match="No API key found"): - OpenAIAgent.create() - - @pytest.mark.asyncio - async def test_format_blocks_text_only(self, mock_openai: AsyncOpenAI) -> None: - """Test formatting text content blocks.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Hello, world!"), - types.TextContent(type="text", text="How are you?"), - ] - - messages = await agent.format_blocks(blocks) - assert len(messages) == 1 - assert messages[0]["role"] == "user" - assert len(messages[0]["content"]) == 2 - assert messages[0]["content"][0]["type"] == "input_text" - assert messages[0]["content"][0]["text"] == "Hello, world!" 
- - @pytest.mark.asyncio - async def test_format_blocks_with_image(self, mock_openai: AsyncOpenAI) -> None: - """Test formatting image content blocks.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Look at this:"), - types.ImageContent(type="image", data="base64data", mimeType="image/png"), - ] - - messages = await agent.format_blocks(blocks) - assert len(messages) == 1 - assert len(messages[0]["content"]) == 2 - assert messages[0]["content"][1]["type"] == "input_image" - assert messages[0]["content"][1]["image_url"] == "data:image/png;base64,base64data" # type: ignore[typeddict-item] - - @pytest.mark.asyncio - async def test_format_blocks_empty(self, mock_openai: AsyncOpenAI) -> None: - """Test formatting empty content blocks.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - messages = await agent.format_blocks([]) - assert len(messages) == 1 - # Empty blocks produce a single empty text item - assert len(messages[0]["content"]) == 1 - assert messages[0]["content"][0]["type"] == "input_text" - assert messages[0]["content"][0]["text"] == "" - - @pytest.mark.asyncio - async def test_format_tool_results_text(self, mock_openai: AsyncOpenAI) -> None: - """Test formatting tool results with text content.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Tool output")], - isError=False, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - assert messages[0]["type"] == "function_call_output" - assert messages[0]["call_id"] == "call_123" - # Output is a list of content items - assert len(messages[0]["output"]) == 1 - assert 
messages[0]["output"][0]["text"] == "Tool output" # type: ignore[index] - - @pytest.mark.asyncio - async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> None: - """Test formatting tool results with error.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Error message")], - isError=True, - ) - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - # Output is a list; first item is error indicator, second is the message - msg = cast("dict[str, Any]", messages[0]) - output = cast("list[dict[str, Any]]", msg["output"]) - assert any(item.get("text") == "[tool_error] true" for item in output) - assert any(item.get("text") == "Error message" for item in output) - - @pytest.mark.asyncio - async def test_get_system_messages(self, mock_openai: AsyncOpenAI) -> None: - """Test getting system messages - OpenAI uses instructions field instead.""" - agent = OpenAIAgent.create( - model_client=mock_openai, - system_prompt="You are a helpful assistant.", - validate_api_key=False, - ) - - # OpenAI agent returns empty list - system prompt is passed via instructions - messages = await agent.get_system_messages() - assert len(messages) == 0 - - @pytest.mark.asyncio - async def test_convert_tools_for_openai(self, mock_openai: AsyncOpenAI) -> None: - """Test converting MCP tools to OpenAI format.""" - tools = [ - types.Tool( - name="my_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - # Initialize with context to trigger tool conversion - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - # Check that tools 
were converted - assert len(agent._openai_tools) >= 1 - # Find our tool - tool = next((t for t in agent._openai_tools if t.get("name") == "my_tool"), None) - assert tool is not None - assert tool["type"] == "function" - - @pytest.mark.asyncio - async def test_convert_tools_raises_on_incomplete(self, mock_openai: AsyncOpenAI) -> None: - """Test that tools without description raise error.""" - tools = [ - types.Tool( - name="incomplete_tool", - description=None, # Missing description - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - agent.ctx = ctx - with pytest.raises(ValueError, match="requires both a description"): - await agent._initialize_from_ctx(ctx) - - @pytest.mark.asyncio - async def test_get_response_with_text(self, mock_openai: AsyncOpenAI) -> None: - """Test getting response with text output.""" - # Setup mock response - mock_response = AsyncMock() - mock_response.output = [ - ResponseOutputMessage( - id="msg_123", - type="message", - role="assistant", - status="completed", - content=[ResponseOutputText(type="output_text", text="Hello!", annotations=[])], - ) - ] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - # Set empty tools to avoid needing initialization - agent._openai_tools = [] - agent._initialized = True - - response = await agent.get_response([]) - assert response.content == "Hello!" 
- assert response.done is True - assert len(response.tool_calls) == 0 - - @pytest.mark.asyncio - async def test_get_response_with_tool_call(self, mock_openai: AsyncOpenAI) -> None: - """Test getting response with tool call.""" - mock_response = AsyncMock() - # Tool calls come as separate output items, not inside message content - mock_response.output = [ - ResponseFunctionToolCall( - id="call_123", - type="function_call", - call_id="call_123", - name="my_tool", - arguments='{"x": "value"}', - ) - ] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - agent._openai_tools = [] - agent._tool_name_map = {"my_tool": "my_tool"} - agent._initialized = True - - response = await agent.get_response([]) - assert response.done is False - assert len(response.tool_calls) == 1 - assert response.tool_calls[0].name == "my_tool" - assert response.tool_calls[0].arguments == {"x": "value"} - - @pytest.mark.asyncio - async def test_get_response_with_reasoning(self, mock_openai: AsyncOpenAI) -> None: - """Test getting response with reasoning.""" - mock_response = AsyncMock() - mock_response.output = [ - ResponseReasoningItem( - id="reason_123", - type="reasoning", - summary=[Summary(type="summary_text", text="Thinking about it...")], - ), - ResponseOutputMessage( - id="msg_123", - type="message", - role="assistant", - status="completed", - content=[ResponseOutputText(type="output_text", text="Answer!", annotations=[])], - ), - ] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - agent._openai_tools = [] - agent._initialized = True - - response = await agent.get_response([]) - # Reasoning is stored separately from content - assert response.reasoning == "Thinking about it..." - assert response.content == "Answer!" 
- - @pytest.mark.asyncio - async def test_get_response_requests_sources_when_citations_enabled( - self, mock_openai: AsyncOpenAI - ) -> None: - """Scenario citation mode should request source payloads from Responses API.""" - mock_response = AsyncMock() - mock_response.id = "resp_123" - mock_response.output = [ - ResponseOutputMessage( - id="msg_123", - type="message", - role="assistant", - status="completed", - content=[ResponseOutputText(type="output_text", text="Hello!", annotations=[])], - ) - ] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - agent._openai_tools = [] - agent._initialized = True - - ctx = MockEvalContext() - ctx.scenario_enable_citations = True - agent.ctx = ctx - - await agent.get_response([]) - - call_kwargs = mock_openai.responses.create.await_args.kwargs # type: ignore[union-attr] - assert call_kwargs.get("include") == ["web_search_call.action.sources"] - - -class TestOpenAIToolConversion: - """Tests for tool conversion to OpenAI format.""" - - @pytest.fixture - def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: # type: ignore[misc] - """Create a stub OpenAI client.""" - with patch("hud.agents.openai.agent.AsyncOpenAI") as mock_class: - client = AsyncOpenAI(api_key="test", base_url="http://localhost") - client.responses.create = AsyncMock() - mock_class.return_value = client - yield client # type: ignore[misc] - - @pytest.mark.asyncio - async def test_shell_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: - """Test that the agent converts shell capability to OpenAI native format.""" - tools = [ - types.Tool( - name="bash", - description="Execute shell commands", - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - # Check for native 
shell tool - shell_tool = next((t for t in agent._openai_tools if t.get("type") == "shell"), None) - assert shell_tool == {"type": "shell", "environment": {"type": "local"}} - assert agent._tool_name_map["shell"] == "shell" - assert agent._openai_native_tools["shell"].env_tool_name == "bash" - - @pytest.mark.asyncio - async def test_editor_tool_stays_generic(self, mock_openai: AsyncOpenAI) -> None: - """Editor capabilities are not advertised as OpenAI apply_patch.""" - tools = [ - types.Tool( - name="edit", - description="Apply V4A patches", - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert all(t.get("type") != "apply_patch" for t in agent._openai_tools) - assert "apply_patch" not in agent._tool_name_map - assert "apply_patch" not in agent._openai_native_tools - assert [tool.get("type") for tool in agent._openai_tools] == ["function"] - assert agent._openai_tools[0].get("name") == "edit" - - @pytest.mark.asyncio - async def test_capability_metadata_routes_openai_tools(self, mock_openai: AsyncOpenAI) -> None: - """Test env-level capabilities can bind OpenAI tools to non-public names.""" - tools = [ - types.Tool( - name="run_shell", - description="Execute shell commands", - inputSchema={"type": "object"}, - ), - types.Tool( - name="patch_files", - description="Apply V4A patches", - inputSchema={"type": "object"}, - ), - ] - ctx = MockEvalContext(tools=tools) - ctx.metadata["environment_capabilities"] = { - "capabilities": { - "shell": "run_shell", - "editor": {"tool": "patch_files"}, - } - } - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert {t.get("type") for t in agent._openai_tools} == {"shell", "function"} - assert agent._tool_name_map["shell"] == "shell" - assert 
agent._openai_native_tools["shell"].env_tool_name == "run_shell" - assert "apply_patch" not in agent._tool_name_map - assert "apply_patch" not in agent._openai_native_tools - assert [tool.name for tool in agent._categorized_tools.generic] == [ - "run_shell", - "patch_files", - ] - - @pytest.mark.asyncio - async def test_non_hosted_native_metadata_is_generic(self, mock_openai: AsyncOpenAI) -> None: - """OpenAI ignores env-owned provider metadata.""" - tools = [ - types.Tool( - name="custom_tool", - description="Custom tool", - inputSchema={"type": "object", "properties": {}}, - _meta={ - "native_tools": { - "openai": { - "api_type": "custom_native", - "api_name": "custom_native", - "role": "custom", - } - } - }, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create(model_client=mock_openai, validate_api_key=False) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - assert [tool.name for tool in agent._categorized_tools.generic] == ["custom_tool"] - assert {tool.get("type") for tool in agent._openai_tools} == {"function"} - - @pytest.mark.asyncio - async def test_openai_shell_call_routes_directly_to_bash( - self, mock_openai: AsyncOpenAI - ) -> None: - """Test OpenAI shell calls stay provider-owned until execution.""" - tools = [ - types.Tool( - name="bash", - description="Execute shell commands", - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create(model_client=mock_openai, validate_api_key=False) - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - tool_call = agent._extract_tool_call( - SimpleNamespace( - type="shell_call", - action=SimpleNamespace( - to_dict=lambda: {"commands": ["pwd", "ls"], "timeout_ms": 5000} - ), - call_id="call_1", - ) - ) - - assert tool_call == MCPToolCall( - name="shell", - arguments={"commands": ["pwd", "ls"], "timeout_ms": 5000}, - id="call_1", - ) - - results = await agent.call_tools(tool_call) - assert [(call.name, call.arguments) for call 
in ctx.calls] == [ - ("bash", {"command": "pwd", "timeout_seconds": 5.0}), - ("bash", {"command": "ls", "timeout_seconds": 5.0}), - ] - assert results[0].structuredContent["provider_tool"] == "shell" # type: ignore[index] - - @pytest.mark.asyncio - async def test_computer_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: - """Test that the agent converts computer capability to OpenAI native format.""" - tools = [ - types.Tool( - name="computer", - description="Control computer", - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - agent = OpenAIAgent.create( - model_client=mock_openai, - validate_api_key=False, - ) - - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - computer_tool = next( - (t for t in agent._openai_tools if t.get("type") == "computer"), - None, - ) - assert computer_tool is not None - assert agent._tool_name_map["computer"] == "computer" - assert agent._openai_native_tools["computer"].env_tool_name == "computer" - - @pytest.mark.asyncio - async def test_openai_computer_call_routes_directly_to_generic_computer( - self, mock_openai: AsyncOpenAI - ) -> None: - """Test OpenAI computer calls stay provider-owned until execution.""" - tools = [ - types.Tool( - name="computer", - description="Control computer", - inputSchema={"type": "object"}, - ) - ] - ctx = MockEvalContext(tools=tools) - - async def call_tool(call: Any, /, **kwargs: Any) -> MCPToolResult: - del kwargs - ctx.calls.append(call) - if call.arguments["action"] == "screenshot": - return MCPToolResult( - content=[types.ImageContent(type="image", data="img", mimeType="image/png")], - isError=False, - ) - return MCPToolResult( - content=[types.TextContent(type="text", text="clicked")], - isError=False, - ) - - ctx.call_tool = call_tool # type: ignore[method-assign] - agent = OpenAIAgent.create(model_client=mock_openai, validate_api_key=False) - agent.ctx = ctx - await agent._initialize_from_ctx(ctx) - - tool_call = agent._extract_tool_call( - 
SimpleNamespace( - type="computer_call", - pending_safety_checks=[], - action=SimpleNamespace( - to_dict=lambda: { - "type": "click", - "x": 10, - "y": 20, - "button": "left", - "keys": ["CTRL"], - } - ), - call_id="call_1", - ) - ) - - assert tool_call is not None - assert tool_call == MCPToolCall( - name="computer", - arguments={"type": "click", "x": 10, "y": 20, "button": "left", "keys": ["CTRL"]}, - id="call_1", - ) - - results = await agent.call_tools(tool_call) - assert [(call.name, call.arguments) for call in ctx.calls] == [ - ( - "computer", - {"action": "click", "x": 10, "y": 20, "button": "left", "hold_keys": ["ctrl"]}, - ), - ("computer", {"action": "screenshot"}), - ] - - messages = await agent.format_tool_results([tool_call], results) - assert messages == [ - { - "type": "computer_call_output", - "call_id": "call_1", - "output": { - "type": "computer_screenshot", - "image_url": "data:image/png;base64,img", - "detail": "original", - }, - } - ] - - -class TestOpenAICitations: - """Tests for OpenAI annotation citation extraction.""" - - @pytest.fixture - def mock_openai(self) -> AsyncOpenAI: - client = AsyncOpenAI(api_key="test", base_url="http://localhost") - client.responses.create = AsyncMock() - return client - - def _make_response(self, output: list[Any]) -> MagicMock: - response = MagicMock() - response.id = "resp_1" - response.output = output - return response - - @pytest.mark.asyncio - async def test_url_citation_extracted(self, mock_openai: AsyncOpenAI) -> None: - """url_citation annotations are extracted as citations.""" - from openai.types.responses.response_output_text import AnnotationURLCitation - - agent = OpenAIAgent.create( - model_client=mock_openai, - model="gpt-4o", - validate_api_key=False, - ) - agent._openai_tools = [] - agent._initialized = True - - ann = AnnotationURLCitation( - type="url_citation", - url="https://example.com/article", - title="Article", - start_index=10, - end_index=25, - ) - text_block = 
ResponseOutputText(type="output_text", text="Hello world", annotations=[ann]) - msg_item = ResponseOutputMessage( - id="msg_1", - type="message", - role="assistant", - content=[text_block], - status="completed", - ) - mock_openai.responses.create = AsyncMock(return_value=self._make_response([msg_item])) - - result = await agent.get_response( - [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}] - ) - - assert len(result.citations) == 1 - cit = result.citations[0] - assert cit["type"] == "url_citation" - assert cit["source"] == "https://example.com/article" - assert cit["title"] == "Article" - assert cit["start_index"] == 10 - assert cit["end_index"] == 25 - - @pytest.mark.asyncio - async def test_file_citation_extracted(self, mock_openai: AsyncOpenAI) -> None: - """file_citation annotations are extracted as citations.""" - from openai.types.responses.response_output_text import AnnotationFileCitation - - agent = OpenAIAgent.create( - model_client=mock_openai, - model="gpt-4o", - validate_api_key=False, - ) - agent._openai_tools = [] - agent._initialized = True - - ann = AnnotationFileCitation( - type="file_citation", - file_id="file-abc123", - filename="report.pdf", - index=0, - ) - text_block = ResponseOutputText(type="output_text", text="Facts", annotations=[ann]) - msg_item = ResponseOutputMessage( - id="msg_1", - type="message", - role="assistant", - content=[text_block], - status="completed", - ) - mock_openai.responses.create = AsyncMock(return_value=self._make_response([msg_item])) - - result = await agent.get_response( - [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}] - ) - - assert len(result.citations) == 1 - cit = result.citations[0] - assert cit["type"] == "file_citation" - assert cit["source"] == "file-abc123" - assert cit["title"] == "report.pdf" - - @pytest.mark.asyncio - async def test_no_annotations_no_citations(self, mock_openai: AsyncOpenAI) -> None: - """No citations when annotations list is empty.""" - 
agent = OpenAIAgent.create( - model_client=mock_openai, - model="gpt-4o", - validate_api_key=False, - ) - agent._openai_tools = [] - agent._initialized = True - - text_block = ResponseOutputText(type="output_text", text="Plain answer", annotations=[]) - msg_item = ResponseOutputMessage( - id="msg_1", - type="message", - role="assistant", - content=[text_block], - status="completed", - ) - mock_openai.responses.create = AsyncMock(return_value=self._make_response([msg_item])) - - result = await agent.get_response( - [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}] - ) - - assert result.citations == [] diff --git a/hud/agents/tests/test_openai_compatible.py b/hud/agents/tests/test_openai_compatible.py deleted file mode 100644 index 77aaa2d04..000000000 --- a/hud/agents/tests/test_openai_compatible.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, cast - -import mcp.types as types -import pytest - -from hud.agents.openai_compatible import OpenAIChatAgent -from hud.agents.openai_compatible.tools import openai_compatible_tools -from hud.agents.openai_compatible.tools.computer import ( - GLMComputerTool, - QwenComputerTool, - _fix_glm_xml_args, - _parse_glm_box, -) -from hud.agents.openai_compatible.tools.filesystem import ReadTool -from hud.agents.tools import EnvironmentCapability -from hud.types import MCPToolCall, MCPToolResult - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionToolParam - - -def computer_tool(name: str = "computer") -> types.Tool: - return types.Tool( - name=name, - description="Control computer with mouse, keyboard, and screenshots", - inputSchema={ - "type": "object", - "properties": { - "action": {"type": "string"}, - "x": {"type": "integer"}, - "y": {"type": "integer"}, - }, - "required": ["action"], - }, - _meta={"resolution": {"width": 1024, "height": 768}}, - ) - - -def capability(tool: types.Tool) -> EnvironmentCapability: - return 
EnvironmentCapability(name="computer", tool_name=tool.name, tool=tool) - - -def filesystem_tool(name: str) -> types.Tool: - return types.Tool( - name=name, - description=f"{name} environment tool", - inputSchema={"type": "object", "properties": {}}, - ) - - -def filesystem_capability(tool_name: str = "read") -> EnvironmentCapability: - tool = filesystem_tool(tool_name) - return EnvironmentCapability( - name="filesystem", - tool_name=tool.name, - tool=tool, - metadata={"tools": {"read": "read", "grep": "grep", "glob": "glob", "list": "list"}}, - ) - - -def test_openai_compatible_agent_uses_glm_computer_tool() -> None: - agent = OpenAIChatAgent.create( - model="glm-4.6v", - api_key="test-key", - base_url="http://example.com/v1", - ) - tool = computer_tool() - agent._available_tools = [tool] - agent._categorized_tools = agent.categorize_tools([tool]) - agent._initialized = True - agent._on_tools_ready() - - schemas = agent.get_tool_schemas() - schema = cast("dict[str, Any]", schemas[0]) - - assert schema["type"] == "function" - assert schema["function"]["name"] == "computer" - assert len(schemas) == 1 - assert "computer" in agent._openai_compatible_native_tools - actions = schema["function"]["parameters"]["properties"]["action"]["enum"] - assert "DONE" not in actions - assert "FAIL" not in actions - - -def test_openai_compatible_agent_uses_qwen_computer_tool() -> None: - agent = OpenAIChatAgent.create( - model="qwen2.5-vl", - api_key="test-key", - base_url="http://example.com/v1", - ) - tool = computer_tool() - agent._available_tools = [tool] - agent._categorized_tools = agent.categorize_tools([tool]) - agent._initialized = True - agent._on_tools_ready() - - schemas = agent.get_tool_schemas() - schema = cast("dict[str, Any]", schemas[0]) - - assert schema["type"] == "computer_use" - assert schema["name"] == "computer_use" - assert len(schemas) == 1 - assert "computer_use" in agent._openai_compatible_native_tools - actions = 
schema["parameters"]["properties"]["action"]["enum"] - assert "terminate" not in actions - assert "answer" not in actions - - -def test_openai_compatible_registry_ignores_legacy_native_metadata() -> None: - tool = types.Tool( - name="glm_computer", - description="legacy GLM computer", - inputSchema={"type": "object", "properties": {}}, - _meta={ - "native_tools": { - "openai_compatible": { - "api_type": "gui_agent_glm45v", - "api_name": "computer", - "role": "computer", - } - } - }, - ) - agent = OpenAIChatAgent.create( - model="glm-4.6v", - api_key="test-key", - base_url="http://example.com/v1", - ) - - categorized = agent.categorize_tools([tool]) - - assert categorized.generic == [tool] - assert categorized.skipped == [] - - -def test_openai_compatible_agent_uses_filesystem_tool_shapes() -> None: - agent = OpenAIChatAgent.create( - model="gpt-4o", - api_key="test-key", - base_url="http://example.com/v1", - ) - tools = [filesystem_tool(name) for name in ("read", "grep", "glob", "list")] - agent._available_tools = tools - agent._categorized_tools = agent.categorize_tools(tools) - agent._initialized = True - agent._on_tools_ready() - - schemas = agent.get_tool_schemas() - function_schemas = [cast("ChatCompletionToolParam", schema) for schema in schemas] - - assert [schema["function"]["name"] for schema in function_schemas] == [ - "read", - "grep", - "glob", - "list", - ] - assert len(schemas) == 4 - assert set(agent._openai_compatible_backing_tools) == {"read", "grep", "glob", "list"} - filesystem = agent._environment_capabilities["filesystem"] - assert filesystem.metadata["tools"] == { - "read": "read", - "grep": "grep", - "glob": "glob", - "list": "list", - } - - -def test_openai_compatible_registry_maps_filesystem_capability_to_read_tool() -> None: - tool = openai_compatible_tools.tool_for_capability( - filesystem_capability(), - "gpt-4o", - ) - - assert isinstance(tool, ReadTool) - assert tool.to_params()["function"]["name"] == "read" - - -def 
test_parse_glm_box() -> None: - assert _parse_glm_box("[513,438]") == (513, 438) - assert _parse_glm_box("513, 438") == (513, 438) - assert _parse_glm_box([513, 438]) == (513, 438) - assert _parse_glm_box([[513, 438]]) == (513, 438) - assert _parse_glm_box("bad") is None - - -def test_fix_glm_xml_args() -> None: - result = _fix_glm_xml_args( - {"action": "left_click\nstart_box\n[114, 167]"} - ) - - assert result == {"action": "left_click", "start_box": "[114, 167]"} - - -@pytest.mark.asyncio -async def test_glm_computer_translates_to_environment_calls() -> None: - tool = GLMComputerTool.from_capability( - capability(computer_tool()), - GLMComputerTool.default_spec("glm-4.6v"), # type: ignore[arg-type] - "glm-4.6v", - ) - calls: list[MCPToolCall] = [] - - async def caller(call: MCPToolCall) -> MCPToolResult: - calls.append(call) - return MCPToolResult(content=[], isError=False) - - await tool.execute(caller, {"action": "left_click", "start_box": "[500,300]"}) - - assert calls[0].name == "computer" - assert calls[0].arguments == { - "action": "click", - "x": 512, - "y": 230, - "button": "left", - } - assert calls[1].arguments == {"action": "screenshot"} - - -@pytest.mark.asyncio -async def test_qwen_computer_translates_to_environment_calls() -> None: - tool = QwenComputerTool.from_capability( - capability(computer_tool()), - QwenComputerTool.default_spec("qwen2.5-vl"), # type: ignore[arg-type] - "qwen2.5-vl", - ) - calls: list[MCPToolCall] = [] - - async def caller(call: MCPToolCall) -> MCPToolResult: - calls.append(call) - return MCPToolResult(content=[], isError=False) - - await tool.execute(caller, {"action": "scroll", "coordinate": [100, 200], "pixels": 50}) - - assert calls[0].name == "computer" - assert calls[0].arguments == { - "action": "scroll", - "x": 100, - "y": 200, - "scroll_y": -50, - } - assert calls[1].arguments == {"action": "screenshot"} - - -@pytest.mark.asyncio -async def test_qwen_left_click_drag_uses_mouse_drag_sequence() -> None: - tool = 
QwenComputerTool.from_capability( - capability(computer_tool()), - QwenComputerTool.default_spec("qwen2.5-vl"), # type: ignore[arg-type] - "qwen2.5-vl", - ) - calls: list[MCPToolCall] = [] - - async def caller(call: MCPToolCall) -> MCPToolResult: - calls.append(call) - return MCPToolResult(content=[], isError=False) - - await tool.execute(caller, {"action": "left_click_drag", "coordinate": [300, 400]}) - - assert [call.name for call in calls] == ["computer", "computer", "computer", "computer"] - assert [call.arguments for call in calls] == [ - {"action": "mouse_down", "button": "left"}, - {"action": "move", "x": 300, "y": 400}, - {"action": "mouse_up", "button": "left"}, - {"action": "screenshot"}, - ] - - -@pytest.mark.asyncio -async def test_openai_compatible_filesystem_tool_forwards_to_environment_tool() -> None: - tool = ReadTool.from_capability( - filesystem_capability(), - ReadTool.default_spec("gpt-4o"), - "gpt-4o", - ) - calls: list[MCPToolCall] = [] - - async def caller(call: MCPToolCall) -> MCPToolResult: - calls.append(call) - return MCPToolResult(content=[], isError=False) - - await tool.execute(caller, {"filePath": "/workspace/app.py", "offset": 10, "limit": 5}) - - assert len(calls) == 1 - assert calls[0].name == "read" - assert calls[0].arguments == {"filePath": "/workspace/app.py", "offset": 10, "limit": 5} - - -def test_openai_compatible_tool_registry_selects_model_specific_tool() -> None: - tool = computer_tool() - cap = capability(tool) - - glm_tool = openai_compatible_tools.tool_for_capability(cap, "glm-4.6v") - qwen_tool = openai_compatible_tools.tool_for_capability(cap, "qwen2.5-vl") - unsupported = openai_compatible_tools.tool_for_capability(cap, "llama") - - assert isinstance(glm_tool, GLMComputerTool) - assert isinstance(qwen_tool, QwenComputerTool) - assert unsupported is None diff --git a/hud/agents/tests/test_provider_claude_messages.py b/hud/agents/tests/test_provider_claude_messages.py new file mode 100644 index 000000000..be7fb162b 
--- /dev/null +++ b/hud/agents/tests/test_provider_claude_messages.py @@ -0,0 +1,257 @@ +"""Claude agent tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest + +from hud.agents.base import AgentContext +from hud.agents.claude import ClaudeAgent +from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result + + +class Stream: + def __init__(self, response: MagicMock) -> None: + self.response = response + + async def __aenter__(self) -> Stream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> bool: + return False + + def __aiter__(self) -> Stream: + return self + + async def __anext__(self) -> None: + raise StopAsyncIteration + + async def get_final_message(self) -> MagicMock: + return self.response + + +class ErrorStream: + def __init__(self, error: Exception) -> None: + self.error = error + + async def __aenter__(self) -> ErrorStream: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> bool: + return False + + def __aiter__(self) -> ErrorStream: + return self + + async def __anext__(self) -> None: + raise self.error + + +def _tool_use(name: str, arguments: dict[str, object]) -> MagicMock: + block = MagicMock() + block.type = "tool_use" + block.id = "call_1" + block.name = name + block.input = arguments + return block + + +def _text_block(text: str, *, thinking: bool = False) -> MagicMock: + block = MagicMock() + block.type = "thinking" if thinking else "text" + block.text = text + block.thinking = text + block.citations = None + return block + + +def _message(*blocks: MagicMock) -> MagicMock: + response = MagicMock() + response.content = list(blocks) + return response + + +@pytest.mark.asyncio +async def 
test_claude_run_executes_model_tool_call_and_returns_final_answer() -> None: + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace( + stream=MagicMock( + side_effect=[ + Stream(_message(_tool_use("lookup", {"query": "hud"}))), + Stream(_message(_text_block("final answer"))), + ] + ) + ) + ) + ) + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("tool result")}, + ) + agent = ClaudeAgent.create(model_client=client, validate_api_key=False) + + result = await agent.run( + AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client) + ) + + assert result.content == "final answer" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("lookup", {"query": "hud"}) + ] + assert client.beta.messages.stream.call_count == 2 + second_messages = client.beta.messages.stream.call_args_list[1].kwargs["messages"] + assert second_messages[-1]["role"] == "user" + assert second_messages[-1]["content"][0]["type"] == "tool_result" + + +@pytest.mark.asyncio +async def test_claude_retries_streamed_invalid_tool_json_once() -> None: + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace( + stream=MagicMock( + side_effect=[ + ErrorStream( + ValueError("Unable to parse tool parameter JSON from model. JSON: {bad") + ), + Stream(_message(_text_block("ok"))), + ] + ) + ) + ) + ) + agent = ClaudeAgent.create(model_client=client, validate_api_key=False) + + response = await agent.get_response( + [{"role": "user", "content": [{"type": "text", "text": "hello"}]}] + ) + + assert response.content == "ok" + assert response.done is True + assert client.beta.messages.stream.call_count == 2 + + +@pytest.mark.asyncio +async def test_claude_second_invalid_json_retry_adds_guidance_message() -> None: + invalid_json_error = ValueError("Unable to parse tool parameter JSON from model. 
JSON: {bad") + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace( + stream=MagicMock( + side_effect=[ + ErrorStream(invalid_json_error), + ErrorStream(invalid_json_error), + Stream(_message(_text_block("ok"))), + ] + ) + ) + ) + ) + agent = ClaudeAgent.create(model_client=client, validate_api_key=False) + messages = [{"role": "user", "content": [{"type": "text", "text": "hello"}]}] + + response = await agent.get_response(cast("Any", messages)) + + assert response.content == "ok" + assert client.beta.messages.stream.call_count == 3 + retry_messages = client.beta.messages.stream.call_args_list[2].kwargs["messages"] + retry_text = retry_messages[-1]["content"][0]["text"] + assert "INVALID_JSON" in retry_text + assert "Retry the same intended tool call" in retry_text + + +@pytest.mark.asyncio +async def test_claude_response_preserves_thinking_as_reasoning() -> None: + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace( + stream=MagicMock( + return_value=Stream( + _message(_text_block("answer"), _text_block("plan", thinking=True)) + ) + ) + ) + ) + ) + agent = ClaudeAgent.create(model_client=client, validate_api_key=False) + + response = await agent.get_response( + [{"role": "user", "content": [{"type": "text", "text": "hello"}]}] + ) + + assert response.content == "answer" + assert response.reasoning == "plan" + + +@pytest.mark.asyncio +async def test_claude_extracts_document_citations_from_text_blocks() -> None: + citation = MagicMock() + citation.type = "char_location" + citation.cited_text = "Revenue" + citation.document_index = 0 + citation.document_title = "financials.pdf" + citation.start_char_index = 0 + citation.end_char_index = 7 + text_block = _text_block("Revenue") + text_block.citations = [citation] + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace(stream=MagicMock(return_value=Stream(_message(text_block)))) + ) + ) + agent = ClaudeAgent.create(model_client=client, 
validate_api_key=False) + + response = await agent.get_response( + [{"role": "user", "content": [{"type": "text", "text": "hello"}]}] + ) + + assert response.citations == [ + { + "type": "document_citation", + "text": "Revenue", + "source": "0", + "title": "financials.pdf", + "start_index": 0, + "end_index": 7, + } + ] + + +@pytest.mark.asyncio +async def test_claude_native_computer_requests_required_beta_header() -> None: + client = SimpleNamespace( + beta=SimpleNamespace( + messages=SimpleNamespace( + stream=MagicMock(return_value=Stream(_message(_text_block("answer")))) + ) + ) + ) + agent = ClaudeAgent.create( + model="claude-sonnet-4-6", + model_client=client, + validate_api_key=False, + ) + agent.tools.prepare(model=agent.config.model, tools=[mcp_tool("computer")]) + + response = await agent.get_response( + [{"role": "user", "content": [{"type": "text", "text": "hello"}]}] + ) + + assert response.content == "answer" + kwargs = client.beta.messages.stream.call_args.kwargs + assert "computer-use-2025-11-24" in kwargs["betas"] + assert kwargs["tool_choice"] == {"type": "auto", "disable_parallel_tool_use": True} diff --git a/hud/agents/tests/test_provider_computer_tools.py b/hud/agents/tests/test_provider_computer_tools.py new file mode 100644 index 000000000..5504382e6 --- /dev/null +++ b/hud/agents/tests/test_provider_computer_tools.py @@ -0,0 +1,226 @@ +"""Computer tool contracts shared across provider adapters.""" + +from __future__ import annotations + +from typing import Any, cast + +import pytest +from mcp import types + +from hud.agents.gemini.tools.computer import ( + GEMINI_COMPUTER_SPEC, + GEMINI_SAFETY_BLOCKED_PREFIX, + GEMINI_URL_PREFIX, + GeminiComputerTool, +) +from hud.agents.openai.tools.computer import OpenAIComputerTool +from hud.agents.openai_compatible.tools.glm_computer import GLM_COMPUTER_SPEC, GLMComputerTool +from hud.agents.openai_compatible.tools.qwen_computer import ( + QWEN_COMPUTER_SPEC, + QwenComputerTool, +) +from 
hud.agents.tests.conftest import RecordingToolEnvironment, text_result +from hud.agents.tools.computer import execute_computer_calls +from hud.types import MCPToolCall, MCPToolResult + + +def _image_result(data: str = "screenshot") -> MCPToolResult: + return MCPToolResult( + content=[types.ImageContent(type="image", data=data, mimeType="image/png")], + isError=False, + ) + + +@pytest.mark.asyncio +async def test_shared_computer_execution_appends_screenshot_when_required() -> None: + calls: list[MCPToolCall] = [] + + async def call_tool(call: MCPToolCall) -> MCPToolResult: + calls.append(call) + if (call.arguments or {}).get("action") == "screenshot": + return _image_result("after") + return text_result("clicked") + + result = await execute_computer_calls( + call_tool, + env_tool_name="computer", + calls=[{"action": "click", "x": 1, "y": 2}], + ensure_screenshot=True, + ) + + assert [(call.name, call.arguments) for call in calls] == [ + ("computer", {"action": "click", "x": 1, "y": 2}), + ("computer", {"action": "screenshot"}), + ] + assert [type(block).__name__ for block in result.content] == ["TextContent", "ImageContent"] + + +@pytest.mark.asyncio +async def test_openai_computer_translates_actions_and_requires_final_screenshot() -> None: + spec = OpenAIComputerTool.default_spec("gpt-5.4") + assert spec is not None + tool = OpenAIComputerTool(env_tool_name="computer", spec=spec) + calls: list[MCPToolCall] = [] + + async def call_tool(call: MCPToolCall) -> MCPToolResult: + calls.append(call) + if (call.arguments or {}).get("action") == "screenshot": + return _image_result("after") + return text_result("acted") + + result = await tool.execute( + call_tool, + {"type": "click", "x": 10, "y": 20, "button": "wheel", "keys": ["ctrl"]}, + ) + + assert result.content == [ + types.TextContent(type="text", text="acted"), + types.ImageContent(type="image", data="after", mimeType="image/png"), + ] + assert [(call.name, call.arguments) for call in calls] == [ + ( + "computer", 
+ { + "action": "click", + "x": 10, + "y": 20, + "button": "middle", + "hold_keys": ["ctrl"], + }, + ), + ("computer", {"action": "screenshot"}), + ] + + +def test_openai_computer_formats_screenshot_for_provider_continuation() -> None: + spec = OpenAIComputerTool.default_spec("gpt-5.4") + assert spec is not None + tool = OpenAIComputerTool(env_tool_name="computer", spec=spec) + + formatted = tool.format_result( + MCPToolCall(name="computer", id="call_1", arguments={}), + _image_result("after"), + ) + + output = cast("dict[str, Any]", formatted) + assert output["type"] == "computer_call_output" + assert output["call_id"] == "call_1" + assert output["output"] == { + "type": "computer_screenshot", + "image_url": "data:image/png;base64,after", + "detail": "original", + } + + +def test_openai_computer_rejects_provider_continuation_without_screenshot() -> None: + spec = OpenAIComputerTool.default_spec("gpt-5.4") + assert spec is not None + tool = OpenAIComputerTool(env_tool_name="computer", spec=spec) + + with pytest.raises(ValueError, match="missing screenshot"): + tool.format_result( + MCPToolCall(name="computer", id="call_1", arguments={}), + text_result("no screenshot"), + ) + + +@pytest.mark.asyncio +async def test_gemini_computer_blocks_unconfirmed_safety_decision_without_environment_call() -> ( + None +): + tool = GeminiComputerTool(env_tool_name="computer", spec=GEMINI_COMPUTER_SPEC) + environment = RecordingToolEnvironment() + + result = await tool.execute( + environment.call_tool, + { + "action": "click_at", + "safety_decision": {"decision": "require_confirmation"}, + }, + ) + + assert environment.calls == [] + assert result.isError is False + assert result.content == [ + types.TextContent( + type="text", + text=( + f"{GEMINI_SAFETY_BLOCKED_PREFIX}" + "Gemini Computer Use action requires user confirmation before execution." 
+ ), + ) + ] + + +def test_gemini_computer_formats_url_safety_and_inline_screenshot_parts() -> None: + tool = GeminiComputerTool(env_tool_name="computer", spec=GEMINI_COMPUTER_SPEC) + + content = tool.format_result( + MCPToolCall( + name="computer_use", + provider_name="click_at", + arguments={"safety_decision": {"decision": "allow"}}, + ), + MCPToolResult( + content=[ + types.TextContent(type="text", text="clicked"), + types.TextContent(type="text", text=f"{GEMINI_URL_PREFIX}https://example.com"), + types.ImageContent(type="image", data="YWJj", mimeType="image/png"), + ], + isError=False, + ), + ) + + parts = content.parts or [] + response = parts[0].function_response + assert response is not None + assert response.name == "click_at" + assert response.response == { + "success": True, + "output": "clicked", + "url": "https://example.com", + "safety_acknowledgement": True, + } + response_parts = response.parts or [] + assert response_parts[0].inline_data is not None + assert response_parts[0].inline_data.data == b"abc" + + +@pytest.mark.asyncio +async def test_glm_computer_scales_normalized_click_coordinates() -> None: + tool = GLMComputerTool( + env_tool_name="computer", + spec=GLM_COMPUTER_SPEC, + display_width=1000, + display_height=500, + coordinate_space=None, + ) + environment = RecordingToolEnvironment(results={"computer": text_result("ok")}) + + await tool.execute( + environment.call_tool, + {"action": "left_click", "start_box": "[999,999]"}, + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("computer", {"action": "click", "x": 999, "y": 499, "button": "left"}), + ("computer", {"action": "screenshot"}), + ] + + +@pytest.mark.asyncio +async def test_qwen_computer_translates_wait_seconds_to_milliseconds() -> None: + tool = QwenComputerTool( + env_tool_name="computer", + spec=QWEN_COMPUTER_SPEC, + display_width=1000, + display_height=500, + description="computer", + ) + environment = RecordingToolEnvironment(results={"computer": 
text_result("waited")}) + + await tool.execute(environment.call_tool, {"action": "wait", "time": 1.5}) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("computer", {"action": "wait", "time": 1500}) + ] diff --git a/hud/agents/tests/test_provider_gemini_generate_content.py b/hud/agents/tests/test_provider_gemini_generate_content.py new file mode 100644 index 000000000..524072625 --- /dev/null +++ b/hud/agents/tests/test_provider_gemini_generate_content.py @@ -0,0 +1,154 @@ +"""Gemini agent tests.""" + +from __future__ import annotations + +from typing import cast +from unittest.mock import AsyncMock, MagicMock + +import pytest +from google.genai import types as genai_types + +from hud.agents.base import AgentContext +from hud.agents.gemini import GeminiAgent +from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result + + +def _gemini_response(*parts: genai_types.Part) -> genai_types.GenerateContentResponse: + return genai_types.GenerateContentResponse( + candidates=[ + genai_types.Candidate( + content=genai_types.Content( + role="model", + parts=list(parts), + ) + ) + ] + ) + + +def _gemini_client(*responses: genai_types.GenerateContentResponse) -> MagicMock: + client = MagicMock() + client.aio = MagicMock() + client.aio.models = MagicMock() + client.aio.models.generate_content = AsyncMock(side_effect=list(responses)) + return client + + +@pytest.mark.asyncio +async def test_gemini_run_executes_model_tool_call_and_returns_final_answer() -> None: + client = _gemini_client( + _gemini_response( + genai_types.Part( + function_call=genai_types.FunctionCall( + name="lookup", + args={"query": "hud"}, + ) + ) + ), + _gemini_response(genai_types.Part(text="final answer")), + ) + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("tool result")}, + ) + agent = GeminiAgent.create(model_client=client, validate_api_key=False) + + result = await agent.run( + 
AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client) + ) + + assert result.content == "final answer" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("lookup", {"query": "hud"}) + ] + assert client.aio.models.generate_content.await_count == 2 + second_contents = cast( + "list[genai_types.Content]", + client.aio.models.generate_content.await_args_list[1].kwargs["contents"], + ) + function_response_names: list[str] = [] + for content in second_contents: + for part in content.parts or []: + function_response = part.function_response + if function_response is not None: + function_response_names.append(function_response.name or "") + assert "lookup" in function_response_names + + +@pytest.mark.asyncio +async def test_gemini_no_candidates_is_a_user_visible_error() -> None: + client = _gemini_client(genai_types.GenerateContentResponse(candidates=[])) + agent = GeminiAgent.create(model_client=client, validate_api_key=False) + + with pytest.raises(RuntimeError, match="returned no candidates"): + await agent.get_response([]) + + +@pytest.mark.asyncio +async def test_gemini_citations_enable_google_search_at_provider_boundary() -> None: + client = _gemini_client(_gemini_response(genai_types.Part(text="answer"))) + agent = GeminiAgent.create(model_client=client, validate_api_key=False) + agent.enable_citations = True + + response = await agent.get_response([]) + + assert response.content == "answer" + config = client.aio.models.generate_content.await_args.kwargs["config"] + assert any(tool.google_search is not None for tool in config.tools) + + +@pytest.mark.asyncio +async def test_gemini_preserves_thought_parts_as_reasoning() -> None: + client = _gemini_client( + _gemini_response( + genai_types.Part(text="private reasoning", thought=True), + genai_types.Part(text="answer"), + ) + ) + agent = GeminiAgent.create(model_client=client, validate_api_key=False) + + response = await agent.get_response([]) + + assert 
response.content == "answer" + assert response.reasoning == "private reasoning" + + +@pytest.mark.asyncio +async def test_gemini_prunes_older_computer_screenshots_before_request() -> None: + def computer_response(name: str) -> genai_types.FunctionResponse: + return genai_types.FunctionResponse( + name=name, + response={"success": True}, + parts=[ + genai_types.FunctionResponsePart( + inline_data=genai_types.FunctionResponseBlob( + mime_type="image/png", + data=b"image-bytes", + ) + ) + ], + ) + + old_response = computer_response("click_at") + recent_response = computer_response("navigate") + messages = [ + genai_types.Content( + role="user", + parts=[genai_types.Part(function_response=old_response)], + ), + genai_types.Content( + role="user", + parts=[genai_types.Part(function_response=recent_response)], + ), + ] + client = _gemini_client(_gemini_response(genai_types.Part(text="answer"))) + agent = GeminiAgent.create(model_client=client, validate_api_key=False) + agent.max_recent_turn_with_screenshots = 1 + + response = await agent.get_response(messages) + + assert response.content == "answer" + assert old_response.parts is None + assert recent_response.parts is not None + requested_contents = client.aio.models.generate_content.await_args.kwargs["contents"] + assert requested_contents is messages diff --git a/hud/agents/tests/test_provider_native_tools.py b/hud/agents/tests/test_provider_native_tools.py new file mode 100644 index 000000000..866b66851 --- /dev/null +++ b/hud/agents/tests/test_provider_native_tools.py @@ -0,0 +1,147 @@ +"""Native provider tool contracts for translation and model gating.""" + +from __future__ import annotations + +import hashlib +from typing import Any, cast + +import pytest + +from hud.agents.claude.tools.coding import ClaudeBashTool, ClaudeTextEditorTool +from hud.agents.gemini.tools.coding import GeminiShellTool +from hud.agents.gemini.tools.filesystem import GeminiReadTool +from hud.agents.gemini.tools.memory import 
GeminiMemoryTool +from hud.agents.openai.tools.coding import OpenAIShellTool +from hud.agents.tests.conftest import RecordingToolEnvironment, text_result +from hud.types import MCPToolCall + + +@pytest.mark.asyncio +async def test_openai_shell_translates_commands_timeout_and_structured_output() -> None: + spec = OpenAIShellTool.default_spec("gpt-5.4") + assert spec is not None + tool = OpenAIShellTool(env_tool_name="bash", spec=spec) + environment = RecordingToolEnvironment( + results={ + "bash": text_result("pwd output"), + }, + ) + + result = await tool.execute( + environment.call_tool, + {"commands": ["pwd"], "timeout_ms": 2500, "max_output_length": 80}, + ) + formatted = tool.format_result(MCPToolCall(name="shell", id="call_1", arguments={}), result) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("bash", {"command": "pwd", "timeout_seconds": 2.5}) + ] + assert result.structuredContent == { + "provider_tool": "shell", + "output": [ + {"stdout": "pwd output", "stderr": "", "outcome": {"type": "exit", "exit_code": 0}} + ], + "max_output_length": 80, + } + formatted_dict = cast("dict[str, Any]", formatted) + assert formatted_dict["type"] == "shell_call_output" + assert formatted_dict["call_id"] == "call_1" + assert formatted_dict["max_output_length"] == 80 + + +@pytest.mark.asyncio +async def test_openai_shell_rejects_invalid_commands_without_environment_call() -> None: + spec = OpenAIShellTool.default_spec("gpt-5.4") + assert spec is not None + tool = OpenAIShellTool(env_tool_name="bash", spec=spec) + environment = RecordingToolEnvironment() + + result = await tool.execute(environment.call_tool, {"commands": 123}) + + assert result.isError is True + assert environment.calls == [] + + +@pytest.mark.asyncio +async def test_claude_text_editor_translates_str_replace_arguments() -> None: + spec = ClaudeTextEditorTool.default_spec("claude-sonnet-4-6") + assert spec is not None + tool = ClaudeTextEditorTool(env_tool_name="edit", spec=spec) 
+ environment = RecordingToolEnvironment(results={"edit": text_result("edited")}) + + result = await tool.execute( + environment.call_tool, + { + "command": "str_replace", + "path": "/tmp/file.txt", + "old_str": "old", + "new_str": "new", + }, + ) + + assert result.isError is False + assert [(call.name, call.arguments) for call in environment.calls] == [ + ( + "edit", + { + "command": "replace", + "path": "/tmp/file.txt", + "old_text": "old", + "new_text": "new", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_gemini_shell_scopes_command_to_directory() -> None: + tool = GeminiShellTool(env_tool_name="bash", spec=GeminiShellTool.default_spec("gemini")) + environment = RecordingToolEnvironment(results={"bash": text_result("ok")}) + + await tool.execute(environment.call_tool, {"command": "ls -la", "dir_path": "/tmp/my dir"}) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("bash", {"command": "cd '/tmp/my dir' && ls -la"}) + ] + + +@pytest.mark.asyncio +async def test_gemini_read_translates_line_range_to_offset_and_limit() -> None: + tool = GeminiReadTool(env_tool_name="read", spec=GeminiReadTool.default_spec("gemini")) + environment = RecordingToolEnvironment(results={"read": text_result("lines")}) + + await tool.execute( + environment.call_tool, + {"file_path": "/repo/file.py", "start_line": 3, "end_line": 7}, + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("read", {"filePath": "/repo/file.py", "offset": 2, "limit": 5}) + ] + + +@pytest.mark.asyncio +async def test_gemini_memory_persists_trimmed_fact_under_stable_path() -> None: + tool = GeminiMemoryTool(env_tool_name="edit", spec=GeminiMemoryTool.default_spec("gemini")) + environment = RecordingToolEnvironment(results={"edit": text_result("saved")}) + + await tool.execute(environment.call_tool, {"fact": " user likes concise tests "}) + + digest = hashlib.sha256(b"user likes concise tests").hexdigest()[:12] + assert [(call.name, call.arguments) 
for call in environment.calls] == [ + ( + "edit", + { + "command": "create", + "path": f"/memories/gemini-{digest}.md", + "file_text": "user likes concise tests\n", + }, + ) + ] + + +def test_native_tool_model_gating_uses_provider_supported_model_contracts() -> None: + assert OpenAIShellTool.default_spec("gpt-5.4") is not None + assert OpenAIShellTool.default_spec("gpt-4.1") is None + assert ClaudeBashTool.default_spec("claude-sonnet-4-6") is not None + assert ClaudeBashTool.default_spec("claude-3-5-sonnet") is None diff --git a/hud/agents/tests/test_provider_openai_compatible_chat.py b/hud/agents/tests/test_provider_openai_compatible_chat.py new file mode 100644 index 000000000..373c7b4db --- /dev/null +++ b/hud/agents/tests/test_provider_openai_compatible_chat.py @@ -0,0 +1,215 @@ +"""OpenAI-compatible chat agent tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest +from openai.types.chat.chat_completion import ChatCompletion + +from hud.agents.base import AgentContext +from hud.agents.openai_compatible import OpenAIChatAgent +from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result + + +def _chat_completion(message: dict[str, Any], *, finish_reason: str = "stop") -> ChatCompletion: + return ChatCompletion.model_validate( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 0, + "model": "test-model", + "choices": [ + { + "index": 0, + "finish_reason": finish_reason, + "message": message, + } + ], + } + ) + + +def _client(*responses: ChatCompletion) -> SimpleNamespace: + return SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace(create=AsyncMock(side_effect=list(responses))) + ) + ) + + +def _chat_completion_with_token_ids( + message: dict[str, Any], + *, + prompt_token_ids: list[int], + token_ids: list[int], +) -> ChatCompletion: + completion = 
_chat_completion(message) + choice = completion.choices[0] + object.__setattr__(choice, "prompt_token_ids", prompt_token_ids) + object.__setattr__(choice, "token_ids", token_ids) + return completion + + +@pytest.mark.asyncio +async def test_openai_compatible_run_executes_model_tool_call_and_returns_final_answer() -> None: + client = _client( + _chat_completion( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "lookup", + "arguments": '{"query":"hud"}', + }, + } + ], + }, + finish_reason="tool_calls", + ), + _chat_completion({"role": "assistant", "content": "final answer"}), + ) + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("tool result")}, + ) + agent = OpenAIChatAgent.create(model="test-model", openai_client=client) + + result = await agent.run( + AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client) + ) + + assert result.content == "final answer" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("lookup", {"query": "hud"}) + ] + assert client.chat.completions.create.await_count == 2 + second_messages = client.chat.completions.create.await_args_list[1].kwargs["messages"] + assert { + "role": "tool", + "tool_call_id": "call_1", + "content": "tool result", + } in second_messages + + +@pytest.mark.asyncio +async def test_openai_compatible_auto_respond_followup_does_not_repeat_system_prompt( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def continue_once(content: str | None, *, enabled: bool) -> object: + assert enabled is True + if content == "need input": + return text_prompt("continue") + return None + + monkeypatch.setattr("hud.agents.base.auto_respond", continue_once) + client = _client( + _chat_completion({"role": "assistant", "content": "need input"}), + _chat_completion({"role": "assistant", "content": "final answer"}), + ) + agent = 
OpenAIChatAgent.create( + model="test-model", + openai_client=client, + system_prompt="system rules", + auto_respond=True, + ) + + result = await agent.run(AgentContext(messages=[text_prompt("start")])) + + assert result.content == "final answer" + second_messages = client.chat.completions.create.await_args_list[1].kwargs["messages"] + system_messages = [message for message in second_messages if message["role"] == "system"] + assert system_messages == [{"role": "system", "content": "system rules"}] + + +@pytest.mark.asyncio +async def test_openai_compatible_preserves_reasoning_fields_on_assistant_message() -> None: + reasoning_details = [{"type": "reasoning.text", "text": "step"}] + client = _client( + _chat_completion( + { + "role": "assistant", + "content": "answer", + "reasoning": "private reasoning", + "reasoning_details": reasoning_details, + } + ) + ) + agent = OpenAIChatAgent.create(model="reasoning-model", openai_client=client) + messages: list[dict[str, Any]] = [{"role": "user", "content": "question"}] + + result = await agent.get_response(cast("Any", messages)) + + assert result.content == "answer" + assert result.reasoning == "private reasoning" + assert messages[-1]["reasoning"] == "private reasoning" + assert messages[-1]["reasoning_details"] == reasoning_details + + +@pytest.mark.asyncio +async def test_openai_compatible_api_error_returns_error_response() -> None: + client = SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace(create=AsyncMock(side_effect=RuntimeError("boom"))) + ) + ) + agent = OpenAIChatAgent.create(model="test-model", openai_client=client) + + response = await agent.get_response(cast("Any", [{"role": "user", "content": "question"}])) + + assert response.done is True + assert response.isError is True + assert response.content == "Error getting response boom" + + +@pytest.mark.asyncio +async def test_openai_compatible_checkpoint_is_sent_in_provider_body() -> None: + client = _client(_chat_completion({"role": 
"assistant", "content": "answer"})) + agent = OpenAIChatAgent.create( + model="test-model", + openai_client=client, + checkpoint="checkpoint-123", + ) + + response = await agent.get_response(cast("Any", [{"role": "user", "content": "question"}])) + + assert response.content == "answer" + assert client.chat.completions.create.await_args.kwargs["extra_body"] == { + "checkpoint": "checkpoint-123" + } + + +@pytest.mark.asyncio +async def test_openai_compatible_token_continuation_is_sent_after_first_response() -> None: + client = _client( + _chat_completion_with_token_ids( + {"role": "assistant", "content": "first"}, + prompt_token_ids=[1, 2], + token_ids=[3], + ), + _chat_completion({"role": "assistant", "content": "second"}), + ) + agent = OpenAIChatAgent.create( + model="test-model", + openai_client=client, + completion_kwargs={"extra_body": {"return_token_ids": True}}, + ) + messages = cast("Any", [{"role": "user", "content": "question"}]) + + first = await agent.get_response(messages) + second = await agent.get_response(messages) + + assert first.content == "first" + assert second.content == "second" + second_body = client.chat.completions.create.await_args_list[1].kwargs["extra_body"] + assert second_body == { + "return_token_ids": True, + "prompt_token_ids": [1, 2, 3], + "continuation_from": 2, + } diff --git a/hud/agents/tests/test_provider_openai_responses.py b/hud/agents/tests/test_provider_openai_responses.py new file mode 100644 index 000000000..5cd82108f --- /dev/null +++ b/hud/agents/tests/test_provider_openai_responses.py @@ -0,0 +1,206 @@ +"""OpenAI Responses agent tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) +from openai.types.responses.response_reasoning_item import Summary + +from hud.agents.base 
import AgentContext +from hud.agents.openai import OpenAIAgent +from hud.agents.tests.conftest import RecordingToolEnvironment, mcp_tool, text_prompt, text_result + + +def _message_response(text: str, *, response_id: str = "resp_final") -> SimpleNamespace: + return SimpleNamespace( + id=response_id, + output=[ + ResponseOutputMessage( + id=f"msg_{response_id}", + type="message", + role="assistant", + status="completed", + content=[ResponseOutputText(type="output_text", text=text, annotations=[])], + ) + ], + ) + + +@pytest.mark.asyncio +async def test_openai_run_executes_model_tool_call_and_returns_final_answer() -> None: + client = SimpleNamespace( + responses=SimpleNamespace( + create=AsyncMock( + side_effect=[ + SimpleNamespace( + id="resp_tool", + output=[ + ResponseFunctionToolCall( + id="item_1", + type="function_call", + call_id="call_1", + name="lookup", + arguments='{"query":"hud"}', + ) + ], + ), + _message_response("final answer"), + ] + ) + ) + ) + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("tool result")}, + ) + agent = OpenAIAgent.create(model_client=client, validate_api_key=False) + + result = await agent.run( + AgentContext(messages=[text_prompt("answer with lookup")], tool_client=environment.client) + ) + + assert result.content == "final answer" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("lookup", {"query": "hud"}) + ] + assert client.responses.create.await_count == 2 + second_input = client.responses.create.await_args_list[1].kwargs["input"] + assert client.responses.create.await_args_list[1].kwargs["previous_response_id"] == "resp_tool" + assert second_input[-1]["type"] == "function_call_output" + assert second_input[-1]["call_id"] == "call_1" + + +@pytest.mark.asyncio +async def test_openai_get_response_preserves_reasoning_and_citations() -> None: + text = ResponseOutputText.model_validate( + { + "type": "output_text", + "text": "Example", + "annotations": [ 
+ { + "type": "url_citation", + "url": "https://example.com", + "title": "Example", + "start_index": 0, + "end_index": 7, + } + ], + } + ) + client = SimpleNamespace( + responses=SimpleNamespace( + create=AsyncMock( + return_value=SimpleNamespace( + id="resp", + output=[ + ResponseReasoningItem( + id="reason", + type="reasoning", + summary=[Summary(type="summary_text", text="thought")], + ), + ResponseOutputMessage( + id="msg", + type="message", + role="assistant", + status="completed", + content=[text], + ), + ], + ) + ) + ) + ) + agent = OpenAIAgent.create(model_client=client, validate_api_key=False) + + response = await agent.get_response([]) + + assert response.content == "Example" + assert response.reasoning == "thought" + assert response.citations == [ + { + "type": "url_citation", + "text": "Example", + "source": "https://example.com", + "title": "Example", + "start_index": 0, + "end_index": 7, + } + ] + + +@pytest.mark.asyncio +async def test_openai_citation_mode_requests_provider_source_metadata() -> None: + client = SimpleNamespace( + responses=SimpleNamespace(create=AsyncMock(return_value=_message_response("answer"))) + ) + agent = OpenAIAgent.create(model_client=client, validate_api_key=False) + agent.enable_citations = True + + response = await agent.get_response([]) + + assert response.content == "answer" + assert client.responses.create.await_args.kwargs["include"] == [ + "web_search_call.action.sources" + ] + + +@pytest.mark.asyncio +async def test_openai_get_response_parses_native_computer_and_shell_calls() -> None: + def _action(payload: dict[str, Any]) -> SimpleNamespace: + return SimpleNamespace(to_dict=lambda: payload) + + client = SimpleNamespace( + responses=SimpleNamespace( + create=AsyncMock( + return_value=SimpleNamespace( + id="resp", + output=[ + SimpleNamespace( + type="computer_call", + call_id="computer_call_1", + actions=[_action({"type": "click", "x": 1, "y": 2})], + action=None, + pending_safety_checks=[], + ), + SimpleNamespace( + 
type="shell_call", + call_id="shell_call_1", + action=_action({"commands": ["pwd"]}), + ), + ], + ) + ) + ) + ) + agent = OpenAIAgent.create(model_client=client, validate_api_key=False) + + response = await agent.get_response([]) + + assert response.done is False + assert [(call.name, call.arguments, call.id) for call in response.tool_calls] == [ + ("computer", {"actions": [{"type": "click", "x": 1, "y": 2}]}, "computer_call_1"), + ("shell", {"commands": ["pwd"]}, "shell_call_1"), + ] + + +@pytest.mark.asyncio +async def test_openai_run_returns_error_trace_for_provider_failure() -> None: + client = SimpleNamespace( + responses=SimpleNamespace(create=AsyncMock(side_effect=RuntimeError("provider down"))) + ) + agent = OpenAIAgent.create(model_client=client, validate_api_key=False) + + result = await agent.run(AgentContext(messages=[text_prompt("hello")])) + + assert result.isError is True + assert result.content == "provider down" + assert result.info["error"] == "provider down" diff --git a/hud/agents/tests/test_provider_tool_results.py b/hud/agents/tests/test_provider_tool_results.py new file mode 100644 index 000000000..8ae5f1974 --- /dev/null +++ b/hud/agents/tests/test_provider_tool_results.py @@ -0,0 +1,174 @@ +"""Provider continuation contracts for environment tool results.""" + +from __future__ import annotations + +from typing import Any, cast + +from mcp import types + +from hud.agents.claude.tools.base import ClaudeFunctionTool +from hud.agents.gemini.tools.base import GeminiFunctionTool +from hud.agents.openai.tools.base import OpenAIFunctionTool +from hud.agents.openai_compatible.tools.base import OpenAICompatibleFunctionTool +from hud.agents.tests.conftest import mcp_tool +from hud.types import MCPToolCall, MCPToolResult + + +def _text_image_result() -> MCPToolResult: + return MCPToolResult( + content=[ + types.TextContent(type="text", text="text output"), + types.ImageContent(type="image", data="image-bytes", mimeType="image/png"), + ], + 
isError=False, + ) + + +def test_openai_formats_text_image_structured_and_error_results() -> None: + tool = OpenAIFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things")) + assert tool is not None + + output = tool.format_result( + MCPToolCall(name="lookup", id="call_1", arguments={}), + MCPToolResult( + content=[ + types.TextContent(type="text", text="failed"), + types.ImageContent(type="image", data="image-bytes", mimeType="image/png"), + ], + isError=True, + structuredContent={"code": 500}, + ), + ) + + assert output is not None + output_dict = cast("dict[str, Any]", output) + assert output_dict["type"] == "function_call_output" + assert output_dict["call_id"] == "call_1" + blocks = cast("list[dict[str, Any]]", output_dict["output"]) + assert {"type": "input_text", "text": "[tool_error] true"} in blocks + assert {"type": "input_text", "text": '{"code": 500}'} in blocks + assert {"type": "input_text", "text": "failed"} in blocks + assert { + "type": "input_image", + "image_url": "data:image/png;base64,image-bytes", + } in blocks + + +def test_openai_formats_empty_result_as_empty_function_output() -> None: + tool = OpenAIFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things")) + assert tool is not None + + output = tool.format_result( + MCPToolCall(name="lookup", id="call_1", arguments={}), + MCPToolResult(content=[], isError=False), + ) + + assert output is not None + blocks = cast("list[dict[str, Any]]", cast("dict[str, Any]", output)["output"]) + assert blocks == [{"type": "input_text", "text": ""}] + + +def test_claude_formats_result_blocks_and_citation_documents() -> None: + tool = ClaudeFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things")) + + message = tool.format_result( + MCPToolCall( + name="lookup", + id="call_1", + arguments={}, + _meta=types.RequestParams.Meta.model_validate({"enable_citations": True}), + ), + _text_image_result(), + ) + + assert message is not None + assert message["role"] == "user" + 
content = cast("list[dict[str, Any]]", message["content"]) + tool_result = content[0] + assert tool_result["type"] == "tool_result" + assert tool_result["tool_use_id"] == "call_1" + assert cast("list[dict[str, Any]]", tool_result["content"]) == [ + {"type": "text", "text": "text output"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "image-bytes", + }, + }, + ] + assert content[1]["type"] == "document" + assert content[1]["citations"] == {"enabled": True} + + +def test_claude_formats_errors_as_tool_result_text() -> None: + tool = ClaudeFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things")) + + message = tool.format_result( + MCPToolCall(name="lookup", id="call_1", arguments={}), + MCPToolResult( + content=[types.TextContent(type="text", text="boom")], + isError=True, + ), + ) + + assert message is not None + tool_result = cast("list[dict[str, Any]]", message["content"])[0] + assert tool_result["content"] == [{"type": "text", "text": "Error: boom"}] + + +def test_gemini_formats_success_and_error_function_responses() -> None: + tool = GeminiFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things")) + + success = tool.format_result( + MCPToolCall(name="lookup", provider_name="provider_lookup", arguments={}), + MCPToolResult( + content=[types.TextContent(type="text", text="found")], + isError=False, + ), + ) + error = tool.format_result( + MCPToolCall(name="lookup", arguments={}), + MCPToolResult( + content=[types.TextContent(type="text", text="failed")], + isError=True, + ), + ) + + success_parts = success.parts or [] + error_parts = error.parts or [] + success_response = success_parts[0].function_response + error_response = error_parts[0].function_response + assert success_response is not None + assert success_response.name == "provider_lookup" + assert success_response.response == {"success": True, "output": "found"} + assert error_response is not None + assert error_response.response 
== {"error": "failed"} + + +def test_openai_compatible_formats_text_image_and_structured_results() -> None: + tool = OpenAICompatibleFunctionTool.from_tool(mcp_tool("lookup", description="Lookup things")) + + image_output = tool.format_result( + MCPToolCall(name="lookup", id="call_1", arguments={}), + _text_image_result(), + ) + structured_output = tool.format_result( + MCPToolCall(name="lookup", id="call_2", arguments={}), + MCPToolResult( + content=[], isError=False, structuredContent={"result": {"type": "text", "text": "ok"}} + ), + ) + + assert image_output == [ + {"role": "tool", "tool_call_id": "call_1", "content": "text output"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Tool returned the following:"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,image-bytes"}}, + ], + }, + ] + assert structured_output == {"role": "tool", "tool_call_id": "call_2", "content": "ok"} diff --git a/hud/agents/tests/test_resolver.py b/hud/agents/tests/test_resolver.py deleted file mode 100644 index 05f06b6b7..000000000 --- a/hud/agents/tests/test_resolver.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Tests for model resolution and create_agent.""" - -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -import pytest - -from hud.agents import create_agent -from hud.agents.resolver import resolve_cls - - -@pytest.fixture(autouse=True) -def clear_cache() -> None: - """Clear the models cache before each test.""" - import hud.agents.resolver as resolver_module - - resolver_module._models_cache = None - - -# Mock API response data matching the platform backend format -MOCK_MODELS = [ - { - "id": "uuid-1", - "name": "Claude Sonnet 4.6", - "model_name": "claude-sonnet-4-6", - "sdk_agent_type": None, - "provider": {"name": "Anthropic", "default_sdk_agent_type": "claude"}, - }, - { - "id": "uuid-2", - "name": "GPT 5.4", - "model_name": "gpt-5.4", - "sdk_agent_type": None, - "provider": {"name": "OpenAI", 
"default_sdk_agent_type": "openai"}, - }, - { - "id": "uuid-3", - "name": "Operator", - "model_name": "computer-use-preview", - "sdk_agent_type": "operator", - "provider": {"name": "OpenAI", "default_sdk_agent_type": "openai"}, - }, - { - "id": "uuid-4", - "name": "Gemini 3 Pro", - "model_name": "gemini-3-pro-preview", - "sdk_agent_type": None, - "provider": {"name": "Gemini", "default_sdk_agent_type": "gemini"}, - }, - { - "id": "uuid-5", - "name": "Gemini 2.5 Computer Use Preview", - "model_name": "gemini-2.5-computer-use-preview", - "sdk_agent_type": "gemini_cua", - "provider": {"name": "Gemini", "default_sdk_agent_type": "gemini"}, - }, - { - "id": "uuid-6", - "name": "Grok 4.1 Fast", - "model_name": "grok-4-1-fast", - "sdk_agent_type": None, - "provider": {"name": "xAI", "default_sdk_agent_type": "openai_compatible"}, - }, -] - - -class TestResolveCls: - """Tests for resolve_cls function.""" - - def test_resolves_known_agent_type(self) -> None: - """Known AgentType strings resolve to their class.""" - from hud.agents.claude import ClaudeAgent - - cls, gateway_info = resolve_cls("claude") - assert cls == ClaudeAgent - assert gateway_info is None - - def test_resolves_openai(self) -> None: - """Resolves 'openai' to OpenAIAgent.""" - from hud.agents import OpenAIAgent - - cls, _gateway_info = resolve_cls("openai") - assert cls == OpenAIAgent - - def test_resolves_gemini(self) -> None: - """Resolves 'gemini' to GeminiAgent.""" - from hud.agents.gemini import GeminiAgent - - cls, _gateway_info = resolve_cls("gemini") - assert cls == GeminiAgent - - def test_unknown_model_raises(self) -> None: - """Unknown model raises ValueError.""" - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - pytest.raises(ValueError, match="not found"), - ): - resolve_cls("unknown-model-xyz-123") - - def test_resolves_claude_model(self) -> None: - """Resolves Claude model to ClaudeAgent via sdk_agent_type.""" - from hud.agents.claude import 
ClaudeAgent - - with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS): - cls, info = resolve_cls("claude-sonnet-4-6") - assert cls == ClaudeAgent - assert info is not None - assert info["model_name"] == "claude-sonnet-4-6" - - def test_resolves_openai_model(self) -> None: - """Resolves OpenAI model to OpenAIAgent via sdk_agent_type.""" - from hud.agents import OpenAIAgent - - with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS): - cls, info = resolve_cls("gpt-5.4") - assert cls == OpenAIAgent - assert info is not None - - def test_operator_model_is_not_supported(self) -> None: - """Stale gateway Operator models fail with a clear message.""" - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - pytest.raises(ValueError, match="Operator agent is no longer supported"), - ): - resolve_cls("computer-use-preview") - - def test_resolves_gemini_model(self) -> None: - """Resolves Gemini model to GeminiAgent via provider default.""" - from hud.agents.gemini import GeminiAgent - - with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS): - cls, info = resolve_cls("gemini-3-pro-preview") - assert cls == GeminiAgent - assert info is not None - - def test_gemini_cua_model_is_not_supported(self) -> None: - """Stale gateway Gemini CUA models fail with a clear message.""" - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - pytest.raises(ValueError, match="Gemini CUA agent is no longer supported"), - ): - resolve_cls("gemini-2.5-computer-use-preview") - - def test_resolves_openai_compatible_model(self) -> None: - """Resolves OpenAI-compatible model to OpenAIChatAgent via provider default.""" - from hud.agents.openai_compatible import OpenAIChatAgent - - with patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS): - cls, info = resolve_cls("grok-4-1-fast") - assert cls == OpenAIChatAgent - assert info is 
not None - - def test_unsupported_sdk_agent_type_is_rejected(self) -> None: - """Unsupported sdk_agent_type values are not silently remapped.""" - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - pytest.raises(ValueError, match="Operator agent is no longer supported"), - ): - resolve_cls("computer-use-preview") - - -class TestCreateAgent: - """Tests for create_agent function - gateway-only.""" - - def test_creates_with_gateway_client(self) -> None: - """create_agent always uses gateway routing.""" - from hud.agents import OpenAIAgent - - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - patch.object(OpenAIAgent, "create") as mock_create, - patch("hud.agents.gateway.build_gateway_client") as mock_build_client, - ): - mock_client = MagicMock() - mock_build_client.return_value = mock_client - mock_agent = MagicMock() - mock_create.return_value = mock_agent - - agent = create_agent("gpt-5.4") - - call_kwargs = mock_create.call_args.kwargs - assert call_kwargs["model"] == "gpt-5.4" - assert "model_client" in call_kwargs - assert agent == mock_agent - - def test_passes_kwargs_to_create(self) -> None: - """Extra kwargs are passed to agent.create().""" - from hud.agents import OpenAIAgent - - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - patch.object(OpenAIAgent, "create") as mock_create, - patch("hud.agents.gateway.build_gateway_client"), - ): - mock_create.return_value = MagicMock() - - create_agent("gpt-5.4", temperature=0.5, max_tokens=1000) - - call_kwargs = mock_create.call_args.kwargs - assert call_kwargs["temperature"] == 0.5 - assert call_kwargs["max_tokens"] == 1000 - - def test_known_agent_type_also_uses_gateway(self) -> None: - """Even 'claude' string uses gateway (it's a gateway shortcut).""" - from hud.agents.claude import ClaudeAgent - - with ( - patch.object(ClaudeAgent, "create") as mock_create, - 
patch("hud.agents.gateway.build_gateway_client") as mock_build_client, - ): - mock_client = MagicMock() - mock_build_client.return_value = mock_client - mock_create.return_value = MagicMock() - - create_agent("claude") - - mock_build_client.assert_called_once() - call_kwargs = mock_create.call_args.kwargs - assert "model_client" in call_kwargs - - def test_uses_correct_provider_from_gateway_info(self) -> None: - """Provider name is extracted from gateway info.""" - from hud.agents.claude import ClaudeAgent - - with ( - patch("hud.agents.resolver._fetch_gateway_models", return_value=MOCK_MODELS), - patch.object(ClaudeAgent, "create") as mock_create, - patch("hud.agents.gateway.build_gateway_client") as mock_build_client, - ): - mock_build_client.return_value = MagicMock() - mock_create.return_value = MagicMock() - - create_agent("claude-sonnet-4-6") - - mock_build_client.assert_called_once_with("Anthropic") - - -class TestBuildGatewayClient: - """Tests for build_gateway_client function.""" - - def test_builds_anthropic_client(self) -> None: - """Builds AsyncAnthropic for anthropic provider.""" - from hud.agents.gateway import build_gateway_client - - with patch("hud.settings.settings") as mock_settings: - mock_settings.api_key = "test-key" - mock_settings.hud_gateway_url = "https://gateway.hud.ai" - - with patch("anthropic.AsyncAnthropic") as mock_client_cls: - build_gateway_client("anthropic") - mock_client_cls.assert_called_once() - - def test_builds_openai_client_for_openai(self) -> None: - """Builds AsyncOpenAI for openai provider.""" - from hud.agents.gateway import build_gateway_client - - with patch("hud.settings.settings") as mock_settings: - mock_settings.api_key = "test-key" - mock_settings.hud_gateway_url = "https://gateway.hud.ai" - - with patch("openai.AsyncOpenAI") as mock_client_cls: - build_gateway_client("openai") - mock_client_cls.assert_called_once() - - def test_builds_openai_client_for_unknown(self) -> None: - """Builds AsyncOpenAI for unknown 
providers (openai-compatible).""" - from hud.agents.gateway import build_gateway_client - - with patch("hud.settings.settings") as mock_settings: - mock_settings.api_key = "test-key" - mock_settings.hud_gateway_url = "https://gateway.hud.ai" - - with patch("openai.AsyncOpenAI") as mock_client_cls: - build_gateway_client("together") - mock_client_cls.assert_called_once() diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py deleted file mode 100644 index c818e3a7b..000000000 --- a/hud/agents/tests/test_run_eval.py +++ /dev/null @@ -1,269 +0,0 @@ -"""Tests for MCPAgent.run() with EvalContext.""" - -from __future__ import annotations - -from typing import Any, ClassVar - -import pytest -from mcp import types - -from hud.agents import MCPAgent -from hud.agents.base import BaseCreateParams -from hud.environment.router import ToolRouter -from hud.eval.context import EvalContext -from hud.types import AgentType, BaseAgentConfig, InferenceResult, MCPToolCall, MCPToolResult - - -class MockConfig(BaseAgentConfig): - model_name: str = "MockAgent" - model: str = "mock-model" - - -class MockCreateParams(BaseCreateParams, MockConfig): - pass - - -class MockMCPAgent(MCPAgent): - """Mock agent for testing run().""" - - metadata: ClassVar[dict[str, Any] | None] = {} - config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig - - @classmethod - def agent_type(cls) -> AgentType: - """Return the AgentType for the mock agent.""" - return AgentType.OPENAI - - def __init__(self, **kwargs: Any) -> None: - params = MockCreateParams(**kwargs) - super().__init__(params) - self._response = InferenceResult(content="Test response", tool_calls=[], done=True) - - def set_response(self, response: InferenceResult) -> None: - self._response = response - - async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: - return self._response - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> 
list[dict[str, Any]]: - return [{"role": "tool", "content": str(r)} for r in tool_results] - - async def get_system_messages(self) -> list[Any]: - return [] - - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - return [{"type": "text", "text": getattr(b, "text")} for b in blocks if hasattr(b, "text")] - - -class MockEvalContext(EvalContext): - """Mock EvalContext for testing - inherits from real EvalContext.""" - - def __init__(self, prompt: str = "Test prompt", tools: list[types.Tool] | None = None) -> None: - # Core attributes - self.prompt = prompt - self._tools = tools or [types.Tool(name="test_tool", description="Test", inputSchema={})] - self._submitted: str | dict[str, Any] | None = None - self.reward: float | None = None - self._initialized = True - - # Environment attributes - self._router = ToolRouter() - - # EvalContext attributes - self._task = None - self.trace_id = "test-trace-id" - self.eval_name = "test-eval" - self.job_id: str | None = None - self.group_id: str | None = None - self.index = 0 - self.variants: dict[str, Any] = {} - self.answer: str | dict[str, Any] | None = None - self.system_prompt: str | None = None - self.error: BaseException | None = None - self.metadata: dict[str, Any] = {} - self.results: list[Any] = [] - self._is_summary = False - - def as_tools(self) -> list[types.Tool]: - return self._tools - - @property - def has_scenario(self) -> bool: - return True - - async def list_tools(self) -> list[types.Tool]: - return self._tools - - async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: - # Handle tuple format (name, args) - if isinstance(call, tuple): - name = call[0] - elif hasattr(call, "name"): - name = call.name - else: - name = str(call) - return MCPToolResult( - content=[types.TextContent(type="text", text=f"Result from {name}")], - isError=False, - ) - - async def submit(self, answer: str | dict[str, Any]) -> None: - self._submitted = answer - - -class TestRun: - """Tests for 
MCPAgent.run() with EvalContext.""" - - @pytest.mark.asyncio - async def test_run_basic(self) -> None: - """Test basic run() flow.""" - ctx = MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - - result = await agent.run(ctx) - - assert result.done - assert result.content == "Test response" - assert ctx._submitted == "Test response" - - @pytest.mark.asyncio - async def test_run_no_prompt_raises(self) -> None: - """Test run() raises when prompt is not set.""" - ctx = MockEvalContext(prompt="") - agent = MockMCPAgent() - - with pytest.raises(ValueError, match="prompt is not set"): - await agent.run(ctx) - - @pytest.mark.asyncio - async def test_run_wrong_type_raises(self) -> None: - """Test run() raises TypeError for non-EvalContext.""" - agent = MockMCPAgent() - - with pytest.raises(TypeError, match="must be EvalContext"): - await agent.run("not an eval context") # type: ignore[arg-type] - - @pytest.mark.asyncio - async def test_run_clears_ctx(self) -> None: - """Test run() clears ctx after completion.""" - ctx = MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - - await agent.run(ctx) - assert agent.ctx is None - - @pytest.mark.asyncio - async def test_run_no_submit_on_empty_content(self) -> None: - """Test run() doesn't submit when content is empty.""" - ctx = MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - agent.set_response(InferenceResult(content="", tool_calls=[], done=True)) - - await agent.run(ctx) - assert ctx._submitted is None - - @pytest.mark.asyncio - async def test_run_initializes_tools(self) -> None: - """Test run() initializes tools from context.""" - ctx = MockEvalContext( - prompt="Do the task", - tools=[ - types.Tool(name="tool1", description="Tool 1", inputSchema={}), - types.Tool(name="tool2", description="Tool 2", inputSchema={}), - ], - ) - agent = MockMCPAgent() - - await agent.run(ctx) - - assert agent._initialized - # After cleanup, ctx is None but tools were discovered - - -class 
TestRunCitations: - """Tests for citation flow through run() -> Trace -> submit().""" - - @pytest.mark.asyncio - async def test_run_submits_plain_string_without_citations(self) -> None: - """When no citations, submit() receives a plain string.""" - ctx = MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - agent.set_response(InferenceResult(content="answer", done=True)) - - await agent.run(ctx) - - assert ctx._submitted == "answer" - - @pytest.mark.asyncio - async def test_run_submits_dict_with_citations(self) -> None: - """When citations are present, submit() receives a dict.""" - ctx = MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - agent.set_response( - InferenceResult( - content="answer with sources", - done=True, - citations=[ - {"type": "url_citation", "source": "https://example.com", "title": "Ex"}, - ], - ) - ) - - await agent.run(ctx) - - assert isinstance(ctx._submitted, dict) - assert ctx._submitted["content"] == "answer with sources" - assert len(ctx._submitted["citations"]) == 1 - assert ctx._submitted["citations"][0]["source"] == "https://example.com" - - @pytest.mark.asyncio - async def test_trace_carries_citations_from_inference(self) -> None: - """Trace.citations is populated from the final InferenceResult.""" - ctx = MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - citations = [ - {"type": "grounding", "source": "https://a.com", "text": "fact"}, - {"type": "url_citation", "source": "https://b.com", "title": "B"}, - ] - agent.set_response( - InferenceResult( - content="sourced answer", - done=True, - citations=citations, - ) - ) - - trace = await agent.run(ctx) - - assert len(trace.citations) == 2 - assert trace.citations[0]["source"] == "https://a.com" - assert trace.citations[1]["source"] == "https://b.com" - - @pytest.mark.asyncio - async def test_trace_empty_citations_on_no_citations(self) -> None: - """Trace.citations is empty when InferenceResult has no citations.""" - ctx = 
MockEvalContext(prompt="Do the task") - agent = MockMCPAgent() - agent.set_response(InferenceResult(content="plain answer", done=True)) - - trace = await agent.run(ctx) - - assert trace.citations == [] - - @pytest.mark.asyncio - async def test_trace_empty_citations_on_error(self) -> None: - """Trace.citations is empty when agent errors out.""" - - class FailingAgent(MockMCPAgent): - async def get_response(self, messages: list[dict[str, Any]]) -> InferenceResult: - raise RuntimeError("boom") - - ctx = MockEvalContext(prompt="Do the task") - agent = FailingAgent() - - trace = await agent.run(ctx) - - assert trace.isError is True - assert trace.citations == [] diff --git a/hud/agents/tests/test_shared_eval_boundary.py b/hud/agents/tests/test_shared_eval_boundary.py new file mode 100644 index 000000000..9c2c98f21 --- /dev/null +++ b/hud/agents/tests/test_shared_eval_boundary.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from typing import Any + +import pytest +from mcp import types + +from hud.agents.tests.conftest import ( + HarnessEvalContext, + RoutingHarnessTools, + ScriptedAgent, + mcp_tool, + text_prompt, + text_result, +) +from hud.types import AgentResponse, MCPToolCall, Trace + + +@pytest.mark.asyncio +async def test_eval_run_submits_final_content() -> None: + ctx = HarnessEvalContext(prompt="Do the task") + agent = ScriptedAgent([AgentResponse(content="answer", done=True)]) + + result = await ctx.run_agent(agent) + + assert result.content == "answer" + assert ctx.submitted == "answer" + + +@pytest.mark.asyncio +async def test_eval_run_submits_citations_with_content() -> None: + citations = [{"type": "url", "source": "https://example.com"}] + ctx = HarnessEvalContext(prompt="Do the task") + agent = ScriptedAgent( + [AgentResponse(content="answer with sources", citations=citations, done=True)] + ) + + result = await ctx.run_agent(agent) + + assert result.citations == citations + assert ctx.submitted == {"content": "answer with sources", 
"citations": citations} + + +@pytest.mark.asyncio +async def test_eval_run_does_not_submit_empty_content() -> None: + ctx = HarnessEvalContext(prompt="Do the task") + agent = ScriptedAgent([AgentResponse(content="", done=True)]) + + result = await ctx.run_agent(agent) + + assert result.content == "" + assert ctx.submitted is None + + +@pytest.mark.asyncio +async def test_eval_run_records_error_without_submission() -> None: + ctx = HarnessEvalContext(prompt="Do the task") + agent = ScriptedAgent([AgentResponse(content="bad", isError=True, done=True)]) + + result = await ctx.run_agent(agent) + + assert result.isError is True + assert isinstance(ctx.error, Exception) + assert str(ctx.error) == "bad" + assert ctx.submitted is None + + +@pytest.mark.asyncio +async def test_eval_run_requires_prompt_when_no_conversation_or_scenario_messages() -> None: + ctx = HarnessEvalContext(prompt="") + agent = ScriptedAgent([AgentResponse(content="unused", done=True)]) + + with pytest.raises(ValueError, match=r"ctx\.prompt is not set"): + await ctx.run_agent(agent) + + +@pytest.mark.asyncio +async def test_prompt_messages_prefer_scenario_messages_over_conversation_and_prompt() -> None: + scenario_message = text_prompt("scenario message", role="assistant") + ctx = HarnessEvalContext(prompt="fallback prompt") + ctx.conversation = [{"role": "user", "content": "conversation message"}] + ctx.set_scenario_messages([scenario_message]) + agent = ScriptedAgent([AgentResponse(content="answer", done=True)]) + + await ctx.run_agent(agent) + + assert agent.seen_messages[0] == [{"role": "assistant", "content": "scenario message"}] + + +@pytest.mark.asyncio +async def test_prompt_messages_use_conversation_before_prompt() -> None: + ctx = HarnessEvalContext(prompt="fallback prompt") + ctx.conversation = [ + {"role": "assistant", "content": "previous"}, + {"role": "user", "content": "next"}, + ] + agent = ScriptedAgent([AgentResponse(content="answer", done=True)]) + + await ctx.run_agent(agent) + + 
assert agent.seen_messages[0] == [ + {"role": "assistant", "content": "previous"}, + {"role": "user", "content": "next"}, + ] + + +@pytest.mark.asyncio +async def test_eval_run_passes_citation_flag_to_agent() -> None: + ctx = HarnessEvalContext(prompt="Do the task") + ctx.enable_citations = True + agent = ScriptedAgent([AgentResponse(content="answer", done=True)]) + + await ctx.run_agent(agent) + + assert agent.enable_citations is True + + +@pytest.mark.asyncio +async def test_eval_run_executes_environment_tool_and_submits_final_answer() -> None: + ctx = HarnessEvalContext( + prompt="Use a tool", + tools=[mcp_tool("lookup")], + tool_results={"lookup": text_result("looked up")}, + ) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={"q": "hud"})]), + AgentResponse(content="answer", done=True), + ] + ) + + result = await ctx.run_agent(agent) + + assert result.content == "answer" + assert ctx.submitted == "answer" + assert [(call.name, call.arguments) for call in ctx.environment.calls] == [ + ("lookup", {"q": "hud"}) + ] + + +@pytest.mark.asyncio +async def test_eval_tool_metadata_routes_native_provider_tool_to_environment_tool() -> None: + ctx = HarnessEvalContext( + prompt="Use shell", + tools=[mcp_tool("run_shell")], + metadata={"capabilities": {"shell": "run_shell"}}, + ) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="shell", arguments={"command": "pwd"})]), + AgentResponse(content="done", done=True), + ], + tools_factory=RoutingHarnessTools, + ) + + result = await ctx.run_agent(agent) + + assert result.content == "done" + assert [(call.name, call.arguments) for call in ctx.environment.calls] == [ + ("run_shell", {"command": "pwd"}) + ] + + +@pytest.mark.asyncio +async def test_eval_run_passes_max_steps_to_agent_run() -> None: + ctx = HarnessEvalContext(prompt="Use a tool", tools=[mcp_tool("lookup")]) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="lookup", 
arguments={})]), + AgentResponse(content="too late", done=True), + ] + ) + + result = await ctx.run_agent(agent, max_steps=1) + + assert result.content is None + assert ctx.submitted is None + assert [(call.name, call.arguments) for call in ctx.environment.calls] == [("lookup", {})] + + +@pytest.mark.asyncio +async def test_eval_run_records_agent_step_error_on_context() -> None: + ctx = HarnessEvalContext(prompt="Do the task") + agent = ScriptedAgent([RuntimeError("agent failed")]) + + result = await ctx.run_agent(agent) + + assert result.isError is True + assert isinstance(ctx.error, Exception) + assert str(ctx.error) == "agent failed" + assert ctx.submitted is None + + +@pytest.mark.asyncio +async def test_submit_result_error_prefers_info_error_message() -> None: + ctx = HarnessEvalContext(prompt="Do the task") + + result = Trace(isError=True, content="fallback", info={"error": "specific"}) + + await ctx.submit_result(result) + + assert isinstance(ctx.error, Exception) + assert str(ctx.error) == "specific" + + +def test_tool_metadata_accepts_legacy_capabilities_shape() -> None: + ctx = HarnessEvalContext( + prompt="Do the task", + metadata={"capabilities": {"computer": "computer"}}, + ) + + metadata = ctx.tool_metadata_for_run() + + assert metadata == {"capabilities": {"computer": "computer"}} + + +def test_tool_metadata_prefers_environment_capabilities_shape() -> None: + environment_capabilities: dict[str, Any] = {"capabilities": {"computer": {"tool": "computer"}}} + ctx = HarnessEvalContext( + prompt="Do the task", + metadata={"environment_capabilities": environment_capabilities}, + ) + + metadata = ctx.tool_metadata_for_run() + + assert metadata is environment_capabilities + + +def test_prompt_falls_back_to_plain_user_message() -> None: + ctx = HarnessEvalContext(prompt="hello") + + messages = ctx.prompt_messages() + + assert messages == [ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="hello"), + ) + ] diff --git 
a/hud/agents/tests/test_shared_run_loop.py b/hud/agents/tests/test_shared_run_loop.py new file mode 100644 index 000000000..d64bb4e62 --- /dev/null +++ b/hud/agents/tests/test_shared_run_loop.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from hud.agents.base import AgentContext +from hud.agents.tests.conftest import ( + HarnessConfig, + RecordingToolEnvironment, + ScriptedAgent, + mcp_tool, + text_prompt, + text_result, +) +from hud.types import AgentResponse, MCPToolCall + + +@pytest.mark.asyncio +async def test_run_returns_final_response_without_tools() -> None: + agent = ScriptedAgent([AgentResponse(content="done", done=True)]) + + result = await agent.run(AgentContext(messages=[text_prompt("do it")])) + + assert result.done is True + assert result.isError is False + assert result.content == "done" + assert agent.seen_messages == [[{"role": "user", "content": "do it"}]] + + +@pytest.mark.asyncio +async def test_run_executes_tool_call_and_continues_with_tool_result() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("found it")}, + ) + agent = ScriptedAgent( + [ + AgentResponse( + tool_calls=[MCPToolCall(name="lookup", arguments={"query": "thing"})], + done=False, + ), + AgentResponse(content="answer", done=True), + ] + ) + + result = await agent.run( + AgentContext(messages=[text_prompt("find thing")], tool_client=environment.client) + ) + + assert result.content == "answer" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("lookup", {"query": "thing"}) + ] + assert agent.seen_messages[1][-1] == { + "role": "tool", + "name": "lookup", + "content": "found it", + "is_error": False, + } + + +@pytest.mark.asyncio +async def test_run_supports_multiple_tool_steps_before_final_answer() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("first"), mcp_tool("second")], + results={"first": text_result("one"), "second": 
text_result("two")}, + ) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="first", arguments={})]), + AgentResponse(tool_calls=[MCPToolCall(name="second", arguments={"n": 2})]), + AgentResponse(content="finished", done=True), + ] + ) + + result = await agent.run( + AgentContext(messages=[text_prompt("go")], tool_client=environment.client) + ) + + assert result.content == "finished" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("first", {}), + ("second", {"n": 2}), + ] + assert len(agent.seen_messages) == 3 + + +@pytest.mark.asyncio +async def test_run_preserves_same_turn_tool_call_order() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("first"), mcp_tool("second")], + results={"first": text_result("one"), "second": text_result("two")}, + ) + agent = ScriptedAgent( + [ + AgentResponse( + tool_calls=[ + MCPToolCall(name="first", arguments={"order": 1}), + MCPToolCall(name="second", arguments={"order": 2}), + ] + ), + AgentResponse(content="finished", done=True), + ] + ) + + result = await agent.run( + AgentContext(messages=[text_prompt("call both")], tool_client=environment.client) + ) + + assert result.content == "finished" + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("first", {"order": 1}), + ("second", {"order": 2}), + ] + assert agent.seen_messages[1][-2:] == [ + {"role": "tool", "name": "first", "content": "one", "is_error": False}, + {"role": "tool", "name": "second", "content": "two", "is_error": False}, + ] + + +@pytest.mark.asyncio +async def test_unlimited_max_steps_runs_until_final_answer() -> None: + environment = RecordingToolEnvironment([mcp_tool("loop")]) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="loop", arguments={"step": 1})]), + AgentResponse(tool_calls=[MCPToolCall(name="loop", arguments={"step": 2})]), + AgentResponse(content="done", done=True), + ] + ) + + result = await agent.run( + 
AgentContext(messages=[text_prompt("loop")], tool_client=environment.client), + max_steps=-1, + ) + + assert result.content == "done" + assert [call.arguments for call in environment.calls] == [{"step": 1}, {"step": 2}] + + +@pytest.mark.asyncio +async def test_tool_timeout_stops_run_with_error_trace() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("slow")], + results={"slow": TimeoutError("too slow")}, + ) + agent = ScriptedAgent([AgentResponse(tool_calls=[MCPToolCall(name="slow", arguments={})])]) + + result = await agent.run( + AgentContext(messages=[text_prompt("try slow")], tool_client=environment.client) + ) + + assert result.isError is True + assert result.info["error"] == "too slow" + assert [(call.name, call.arguments) for call in environment.calls] == [("slow", {})] + + +@pytest.mark.asyncio +async def test_tool_errors_are_returned_to_the_model_as_error_results() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": RuntimeError("backend exploded")}, + ) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]), + AgentResponse(content="recovered", done=True), + ] + ) + + result = await agent.run( + AgentContext(messages=[text_prompt("try")], tool_client=environment.client) + ) + + assert result.content == "recovered" + assert agent.seen_messages[1][-1]["is_error"] is True + assert agent.seen_messages[1][-1]["content"] == "backend exploded" + + +@pytest.mark.asyncio +async def test_missing_tool_client_turns_tool_call_into_error_trace() -> None: + agent = ScriptedAgent([AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})])]) + + result = await agent.run(AgentContext(messages=[text_prompt("call lookup")])) + + assert result.isError is True + assert result.info["error"] == "call_tool callback is required to execute tool calls" + + +@pytest.mark.asyncio +async def test_max_steps_caps_tool_loop() -> None: + environment = 
RecordingToolEnvironment([mcp_tool("lookup")]) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]), + AgentResponse(content="should not be reached", done=True), + ] + ) + + result = await agent.run( + AgentContext(messages=[text_prompt("loop")], tool_client=environment.client), + max_steps=1, + ) + + assert result.done is True + assert result.content is None + assert len(environment.calls) == 1 + assert len(agent.seen_messages) == 1 + + +@pytest.mark.asyncio +async def test_auto_respond_can_continue_after_a_done_response( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[str | None] = [] + + async def continue_once(content: str | None, *, enabled: bool) -> object: + calls.append(content) + assert enabled is True + if len(calls) > 1: + return None + return text_prompt("continue") + + monkeypatch.setattr("hud.agents.base.auto_respond", continue_once) + agent = ScriptedAgent( + [ + AgentResponse(content="need input", done=True), + AgentResponse(content="final", done=True), + ], + config=HarnessConfig(auto_respond=True), + ) + + result = await agent.run(AgentContext(messages=[text_prompt("start")])) + + assert result.content == "final" + assert calls == ["need input", "final"] + assert agent.seen_messages[1][-1] == {"role": "user", "content": "continue"} + + +@pytest.mark.asyncio +async def test_model_step_exception_returns_error_trace() -> None: + agent = ScriptedAgent([RuntimeError("model failed")]) + + result = await agent.run(AgentContext(messages=[text_prompt("start")])) + + assert result.done is True + assert result.isError is True + assert result.content == "model failed" + + +@pytest.mark.asyncio +async def test_keyboard_interrupt_returns_interrupted_trace() -> None: + agent = ScriptedAgent([KeyboardInterrupt()]) + + result = await agent.run(AgentContext(messages=[text_prompt("start")])) + + assert result.isError is True + assert result.content == "Interrupted by user" + assert result.info["error"] == 
"Interrupted by user" + + +@pytest.mark.asyncio +async def test_cancelled_run_returns_cancelled_trace() -> None: + agent = ScriptedAgent([asyncio.CancelledError()]) + + result = await agent.run(AgentContext(messages=[text_prompt("start")])) + + assert result.isError is True + assert result.content == "Cancelled" + assert result.info["error"] == "Cancelled" + + +@pytest.mark.asyncio +async def test_trace_messages_include_provider_history_before_stop() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("found")}, + ) + agent = ScriptedAgent( + [ + AgentResponse(tool_calls=[MCPToolCall(name="lookup", arguments={})]), + AgentResponse(content="done", done=True), + ] + ) + + result = await agent.run( + AgentContext(messages=[text_prompt("start")], tool_client=environment.client) + ) + + assert result.content == "done" + assert result.messages == [ + {"role": "user", "content": "start"}, + {"role": "tool", "name": "lookup", "content": "found", "is_error": False}, + ] diff --git a/hud/agents/tests/test_shared_tool_registry.py b/hud/agents/tests/test_shared_tool_registry.py new file mode 100644 index 000000000..760af5e7b --- /dev/null +++ b/hud/agents/tests/test_shared_tool_registry.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +import pytest + +from hud.agents.tests.conftest import ( + RecordingToolEnvironment, + RoutingHarnessTools, + mcp_tool, + text_result, +) +from hud.agents.tools.capabilities import discover_environment_capabilities +from hud.types import MCPToolCall + +if TYPE_CHECKING: + from hud.agents.tools import ToolMetadata + + +@pytest.mark.asyncio +async def test_generic_tool_call_routes_to_matching_environment_tool() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": text_result("found")}, + ) + agent_tools = RoutingHarnessTools() + agent_tools.prepare(model="test-model", tools=environment.tools) + + 
outputs = await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="lookup", arguments={"query": "hud"}), + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("lookup", {"query": "hud"}) + ] + assert outputs == [{"role": "tool", "name": "lookup", "content": "found", "is_error": False}] + + +@pytest.mark.asyncio +async def test_capability_metadata_routes_provider_tool_to_environment_tool() -> None: + environment = RecordingToolEnvironment([mcp_tool("run_shell")]) + agent_tools = RoutingHarnessTools() + agent_tools.prepare( + model="test-model", + tools=environment.tools, + tool_metadata={"capabilities": {"shell": "run_shell"}}, + ) + + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="shell", arguments={"command": "pwd"}), + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("run_shell", {"command": "pwd"}) + ] + + +@pytest.mark.asyncio +async def test_name_fallback_routes_native_tool_when_metadata_is_absent() -> None: + environment = RecordingToolEnvironment([mcp_tool("bash")]) + agent_tools = RoutingHarnessTools() + agent_tools.prepare(model="test-model", tools=environment.tools) + + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="shell", arguments={"command": "echo hi"}), + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("bash", {"command": "echo hi"}) + ] + + +@pytest.mark.asyncio +async def test_grouped_capability_metadata_routes_to_the_selected_environment_tool() -> None: + environment = RecordingToolEnvironment([mcp_tool("read"), mcp_tool("grep")]) + agent_tools = RoutingHarnessTools() + agent_tools.prepare( + model="test-model", + tools=environment.tools, + tool_metadata={"capabilities": {"filesystem": {"tools": {"read": "read", "grep": "grep"}}}}, + ) + + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="read_file", arguments={"path": "README.md"}), + ) + + assert [(call.name, 
call.arguments) for call in environment.calls] == [ + ("read", {"path": "README.md"}) + ] + + +@pytest.mark.asyncio +async def test_native_tool_takes_precedence_over_generic_tool_with_same_environment_name() -> None: + environment = RecordingToolEnvironment([mcp_tool("bash"), mcp_tool("lookup")]) + agent_tools = RoutingHarnessTools() + agent_tools.prepare(model="test-model", tools=environment.tools) + + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="shell", arguments={"command": "whoami"}), + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("bash", {"command": "whoami"}) + ] + with pytest.raises(KeyError): + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="bash", arguments={"command": "whoami"}), + ) + + +@pytest.mark.asyncio +async def test_unknown_provider_tool_fails_before_environment_execution() -> None: + environment = RecordingToolEnvironment([mcp_tool("lookup")]) + agent_tools = RoutingHarnessTools() + agent_tools.prepare(model="test-model", tools=environment.tools) + + with pytest.raises(KeyError): + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="missing", arguments={}), + ) + + assert environment.calls == [] + + +@pytest.mark.asyncio +async def test_timeout_error_propagates_to_run_loop_boundary() -> None: + environment = RecordingToolEnvironment( + [mcp_tool("lookup")], + results={"lookup": TimeoutError("tool timed out")}, + ) + agent_tools = RoutingHarnessTools() + agent_tools.prepare(model="test-model", tools=environment.tools) + + with pytest.raises(TimeoutError, match="tool timed out"): + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="lookup", arguments={}), + ) + + +def test_invalid_capability_metadata_fails_at_the_boundary() -> None: + with pytest.raises(ValueError, match="Invalid capability metadata"): + discover_environment_capabilities( + [mcp_tool("lookup")], + tool_metadata=cast( + "ToolMetadata", + {"capabilities": 
{"lookup": {"unexpected": "shape"}}}, + ), + ) + + +@pytest.mark.asyncio +async def test_stale_capability_metadata_falls_back_to_available_tool_names() -> None: + environment = RecordingToolEnvironment([mcp_tool("bash")]) + agent_tools = RoutingHarnessTools() + agent_tools.prepare( + model="test-model", + tools=environment.tools, + tool_metadata={"capabilities": {"shell": "missing_shell"}}, + ) + + await agent_tools.execute( + environment.call_tool, + MCPToolCall(name="shell", arguments={"command": "pwd"}), + ) + + assert [(call.name, call.arguments) for call in environment.calls] == [ + ("bash", {"command": "pwd"}) + ] diff --git a/hud/agents/tools/__init__.py b/hud/agents/tools/__init__.py index 97f4d670d..116387e86 100644 --- a/hud/agents/tools/__init__.py +++ b/hud/agents/tools/__init__.py @@ -2,30 +2,28 @@ from __future__ import annotations -from .base import AgentTool, AgentToolSpec, CallTool, call_agent_tools, call_tool +from .base import ( + AgentTool, + AgentTools, + AgentToolSpec, +) from .capabilities import ( + CapabilityEntry, EnvironmentCapability, GroupedCapabilityMixin, - capabilities_metadata_from_context, + ToolMetadata, discover_environment_capabilities, ) -from .hosted import ( - HostedTool, - select_hosted_tools, -) -from .registry import AgentToolRegistry +from .hosted import HostedTool __all__ = [ "AgentTool", - "AgentToolRegistry", "AgentToolSpec", - "CallTool", + "AgentTools", + "CapabilityEntry", "EnvironmentCapability", "GroupedCapabilityMixin", "HostedTool", - "call_agent_tools", - "call_tool", - "capabilities_metadata_from_context", + "ToolMetadata", "discover_environment_capabilities", - "select_hosted_tools", ] diff --git a/hud/agents/tools/base.py b/hud/agents/tools/base.py index 2ba5ea806..435027c23 100644 --- a/hud/agents/tools/base.py +++ b/hud/agents/tools/base.py @@ -3,19 +3,37 @@ from __future__ import annotations import fnmatch +import logging from abc import ABC, abstractmethod -from collections.abc import Awaitable, 
Callable, Mapping -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, TypeVar +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, TypeVar, cast +import mcp.types as types + +from hud.agents.tools.capabilities import discover_environment_capabilities from hud.types import MCPToolCall, MCPToolResult if TYPE_CHECKING: - from hud.agents.base import MCPAgent - from hud.agents.tools.capabilities import EnvironmentCapability + from collections.abc import Mapping + + from hud.agents.tools.capabilities import EnvironmentCapability, ToolMetadata + from hud.agents.tools.hosted import HostedTool +AgentToolParamT_co = TypeVar("AgentToolParamT_co", covariant=True) ToolParamT = TypeVar("ToolParamT") +AgentToolT = TypeVar("AgentToolT", bound="AgentTool[object]") CallTool = Callable[[MCPToolCall], Awaitable[MCPToolResult]] +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ToolClient: + """MCP tools and execution hook available for one agent run.""" + + tools: list[types.Tool] = field(default_factory=list[types.Tool]) + tool_handler: CallTool | None = None + tool_metadata: ToolMetadata | None = None @dataclass(frozen=True) @@ -24,20 +42,21 @@ class AgentToolSpec: api_type: str api_name: str - beta: str | None = None supported_models: tuple[str, ...] 
| None = None def supports_model(self, model: str | None) -> bool: - if not self.supported_models or not model or model == "unknown": + if not self.supported_models: return True + if not model or model == "unknown": + return False model_lower = model.lower() return any( fnmatch.fnmatch(model_lower, pattern.lower()) for pattern in self.supported_models ) -class AgentTool(ABC, Generic[ToolParamT]): - """Provider-facing tool backed by one environment tool.""" +class AgentTool(ABC, Generic[AgentToolParamT_co]): + """Provider-facing tool owned by an agent harness.""" name: ClassVar[str] capability: ClassVar[str] @@ -46,79 +65,182 @@ def __init__(self, *, env_tool_name: str, spec: AgentToolSpec) -> None: self.env_tool_name = env_tool_name self.spec = spec + @property + def provider_name(self) -> str: + return self.name + + @classmethod + def env_tool_name_for_capability(cls, capability: EnvironmentCapability) -> str | None: + return capability.tool_name + @classmethod def from_capability( cls, capability: EnvironmentCapability, - spec: AgentToolSpec, model: str, - ) -> Self: - del model - return cls(env_tool_name=capability.tool_name, spec=spec) + ) -> Self | None: + spec = cls.default_spec(model) + env_tool_name = cls.env_tool_name_for_capability(capability) + if spec is None or env_tool_name is None: + return None + return cls(env_tool_name=env_tool_name, spec=spec) @classmethod def default_spec(cls, model: str) -> AgentToolSpec | None: """Return the provider spec this agent should use for this capability.""" - del model return None - @property - def required_beta(self) -> str | None: - return self.spec.beta - - async def execute(self, caller: CallTool, arguments: dict[str, Any]) -> MCPToolResult: - """Execute by forwarding to the backing environment tool.""" - return await call_tool(caller, self.env_tool_name, arguments) - - @abstractmethod - def to_params(self) -> ToolParamT: ... 
+ @classmethod + def from_tool(cls, tool: types.Tool) -> Self | None: + """Build a provider tool for a generic environment tool.""" + del tool + return None + async def execute(self, call_tool: CallTool, arguments: dict[str, Any]) -> MCPToolResult: + """Execute an environment-backed tool by forwarding to its MCP tool.""" + return await call_tool(MCPToolCall(name=self.env_tool_name, arguments=arguments)) -async def call_tool( - caller: CallTool, - env_tool_name: str, - arguments: dict[str, Any], -) -> MCPToolResult: - result = await caller(MCPToolCall(name=env_tool_name, arguments=arguments)) - return MCPToolResult(content=result.content, isError=result.isError) + def format_result(self, call: MCPToolCall, result: MCPToolResult) -> Any | None: + """Format a single tool result for the provider continuation turn.""" + del result + logger.warning("Tool '%s' does not implement result formatting.", call.name) + return None + @abstractmethod + def to_params(self) -> AgentToolParamT_co: ... -async def call_agent_tools( - agent: MCPAgent, - agent_tools: Mapping[str, AgentTool[Any]], - tool_call: MCPToolCall | list[MCPToolCall] | None = None, -) -> list[MCPToolResult]: - """Route provider-owned tool calls through adapters, otherwise through MCP.""" - import mcp.types as types - from hud.agents.base import MCPAgent +class AgentTools(dict[str, AgentToolT], Generic[AgentToolT, ToolParamT]): + """Prepared tool state owned by a single agent run.""" - if tool_call is None: - return [] - tool_calls = [tool_call] if isinstance(tool_call, MCPToolCall) else tool_call + native_tool_classes: ClassVar[tuple[type[AgentTool[object]], ...]] = () + function_tool_class: ClassVar[type[AgentTool[object]] | None] = None + name_fallbacks: ClassVar[Mapping[str, tuple[str, ...]]] = {} - async def call_env_tool(call: MCPToolCall) -> MCPToolResult: - return (await MCPAgent.call_tools(agent, call))[0] + def __init__(self) -> None: + super().__init__() + self.params: list[ToolParamT] = [] + 
self.name_map: dict[str, str] = {} + self.hosted_tools: list[HostedTool[object]] = [] - results: list[MCPToolResult] = [] - for tc in tool_calls: - agent_tool = agent_tools.get(tc.name) - if agent_tool is None: - results.extend(await MCPAgent.call_tools(agent, tc)) - continue + def select_tools( + self, + tools: list[types.Tool], + model: str, + *, + tool_metadata: ToolMetadata | None = None, + ) -> tuple[list[AgentToolT], list[types.Tool]]: + """Split MCP tools into provider-owned and user-defined tools.""" + logger.info("Discovered %s tools: %s", len(tools), ", ".join(tool.name for tool in tools)) + + capabilities = discover_environment_capabilities( + tools, + tool_metadata=tool_metadata, + name_fallbacks=self.name_fallbacks, + ) + agent_tools: list[AgentToolT] = [] + for capability in capabilities.values(): + for raw_tool_cls in self.native_tool_classes: + tool_cls = cast("type[AgentToolT]", raw_tool_cls) + if tool_cls.capability != capability.name: + continue + tool = tool_cls.from_capability(capability, model) + if tool is not None: + agent_tools.append(tool) + agent_tool_names = {tool.env_tool_name for tool in agent_tools} + user_tools = [tool for tool in tools if tool.name not in agent_tool_names] + return agent_tools, user_tools + + def generic_tool( + self, + tool: types.Tool, + ) -> ToolParamT | None: + """Convert an environment MCP tool into provider params.""" + del tool + return None - try: + def prepare( + self, + *, + model: str, + tools: list[types.Tool], + hosted_tools: list[HostedTool[object]] | None = None, + tool_metadata: ToolMetadata | None = None, + ) -> None: + """Prepare a generic provider tool map for an agent run.""" + provider_tools, user_tools = self.select_tools( + tools, + model, + tool_metadata=tool_metadata, + ) + tools_by_name = {tool.provider_name: tool for tool in provider_tools} + installed_names = set(tools_by_name) + self.update(tools_by_name) + self.params.extend(cast("ToolParamT", tool.to_params()) for tool in 
provider_tools) + self.name_map.update({name: name for name in tools_by_name}) + + selected_hosted_tools: list[HostedTool[object]] = [] + for tool in hosted_tools or []: + if not tool.supports_model(model): + continue + selected_hosted_tools.append(tool) + self.params.append(cast("ToolParamT", tool.to_params())) + self.hosted_tools = selected_hosted_tools + + for tool in user_tools: + if self.function_tool_class is not None: + function_tool_cls = cast("type[AgentToolT]", self.function_tool_class) + agent_tool = function_tool_cls.from_tool(tool) + if agent_tool is None: + continue + self[agent_tool.provider_name] = agent_tool + installed_names.add(agent_tool.provider_name) + self.name_map[tool.name] = agent_tool.provider_name + self.params.append(cast("ToolParamT", agent_tool.to_params())) + continue + generic_tool = self.generic_tool(tool) + if generic_tool is None: + continue + installed_names.add(tool.name) + self.name_map[tool.name] = tool.name + self.params.append(generic_tool) + + tool_names = sorted(installed_names) + logger.info("Agent initialized with %s tools: %s", len(tool_names), ", ".join(tool_names)) + + async def execute( + self, + call_tool: CallTool | None, + tool_call: MCPToolCall | list[MCPToolCall] | None = None, + ) -> list[Any]: + if tool_call is None: + return [] + + if call_tool is None: + raise ValueError("call_tool callback is required to execute tool calls") + + outputs: list[Any] = [] + tool_calls = [tool_call] if isinstance(tool_call, MCPToolCall) else tool_call + for tc in tool_calls: + agent_tool = self[tc.name] arguments = tc.arguments if isinstance(tc.arguments, dict) else {} - results.append(await agent_tool.execute(call_env_tool, arguments)) - except Exception as exc: - agent.console.error_log(f"Agent tool execution failed: {exc}") - results.append( - MCPToolResult( + try: + result = await agent_tool.execute(call_tool, arguments) + except TimeoutError: + raise + except Exception as exc: + logger.exception("Tool execution failed") + 
result = MCPToolResult( content=[types.TextContent(type="text", text=str(exc))], isError=True, ) - ) - return results + output = agent_tool.format_result(tc, result) + if output is None: + continue + if isinstance(output, list): + outputs.extend(cast("list[Any]", output)) + else: + outputs.append(output) -__all__ = ["AgentTool", "AgentToolSpec", "CallTool", "call_agent_tools", "call_tool"] + return outputs diff --git a/hud/agents/tools/capabilities.py b/hud/agents/tools/capabilities.py index 2dc24d8fc..5c8282e7f 100644 --- a/hud/agents/tools/capabilities.py +++ b/hud/agents/tools/capabilities.py @@ -2,131 +2,108 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Self +from typing import TYPE_CHECKING, ClassVar, TypedDict, cast if TYPE_CHECKING: + from collections.abc import Mapping + from mcp import types as mcp_types - from hud.agents.tools.base import AgentToolSpec + from hud.types import JsonObject, JsonValue +else: + JsonObject = dict[str, object] + JsonValue = object -@dataclass(frozen=True) -class EnvironmentCapability: - """A normalized environment capability bound to one or more MCP tools.""" - name: str +class CapabilityEntry(TypedDict, total=False): + tool: str tool_name: str - tool: mcp_types.Tool - metadata: dict[str, Any] = field(default_factory=dict) + tools: dict[str, str] -def capabilities_metadata_from_context(ctx: Any) -> dict[str, Any] | None: - """Extract an optional env-level capability descriptor from a context.""" - if ctx is None: - return None +class ToolMetadata(TypedDict, total=False): + capabilities: dict[str, str | CapabilityEntry] - direct = getattr(ctx, "environment_capabilities", None) - if isinstance(direct, dict): - return direct - direct = getattr(ctx, "capabilities", None) - if isinstance(direct, dict): - return {"capabilities": direct} - - metadata = getattr(ctx, "metadata", None) - if isinstance(metadata, dict): - for key in 
("environment_capabilities", "capabilities"): - value = metadata.get(key) - if isinstance(value, dict): - return value if key == "environment_capabilities" else {"capabilities": value} +class EnvironmentCapability: + """A normalized environment capability bound to one or more MCP tools.""" - return None + def __init__( + self, + *, + name: str, + tool_name: str, + tool: mcp_types.Tool, + metadata: JsonObject | None = None, + ) -> None: + self.name = name + self.tool_name = tool_name + self.tool = tool + self.metadata: JsonObject = metadata or {} def discover_environment_capabilities( tools: list[mcp_types.Tool], *, - env_metadata: dict[str, Any] | None = None, - name_fallbacks: dict[str, tuple[str, ...]] | None = None, + tool_metadata: ToolMetadata | None = None, + name_fallbacks: Mapping[str, tuple[str, ...]] | None = None, ) -> dict[str, EnvironmentCapability]: """Build a normalized capability map from env metadata and tool inventory.""" tool_by_name = {tool.name: tool for tool in tools} capabilities: dict[str, EnvironmentCapability] = {} - _add_env_capabilities(capabilities, tool_by_name, env_metadata) - _add_name_fallback_capabilities(capabilities, tool_by_name, name_fallbacks or {}) - - return capabilities - - -def _add_env_capabilities( - capabilities: dict[str, EnvironmentCapability], - tool_by_name: dict[str, mcp_types.Tool], - env_metadata: dict[str, Any] | None, -) -> None: - if not env_metadata: - return - - raw = env_metadata.get("capabilities", env_metadata) - if not isinstance(raw, dict): - return - - for name, config in raw.items(): - if not isinstance(name, str) or name in capabilities: - continue - tool_name: str | None = None - metadata: dict[str, Any] = {} - if isinstance(config, str): - tool_name = config - elif isinstance(config, dict): - raw_tool = config.get("tool") or config.get("tool_name") - if isinstance(raw_tool, str): - tool_name = raw_tool - metadata = dict(config) - else: - raw_tools = config.get("tools") - if isinstance(raw_tools, 
dict): - tool_names = { - str(key): value - for key, value in raw_tools.items() - if isinstance(value, str) and value in tool_by_name - } - if tool_names: - tool_name = next(iter(tool_names.values())) - metadata = {**config, "tools": tool_names} - if tool_name is None: - continue - tool = tool_by_name.get(tool_name) - if tool is None: + metadata = tool_metadata or {} + raw_capabilities = cast( + "dict[str, str | CapabilityEntry]", + metadata.get("capabilities", metadata), + ) + for name, config in raw_capabilities.items(): + match config: + case str() as tool_name: + capability_metadata: JsonObject = {} + case {"tool": str() as tool_name}: + capability_metadata = {"tool": tool_name} + case {"tool_name": str() as tool_name}: + capability_metadata = {"tool_name": tool_name} + case {"tools": grouped_tools}: + tool_names: dict[str, JsonValue] = { + str(alias): env_tool_name + for alias, env_tool_name in grouped_tools.items() + if env_tool_name in tool_by_name + } + if not tool_names: + continue + tool_name = str(next(iter(tool_names.values()))) + capability_metadata = {"tools": tool_names} + case _: + raise ValueError(f"Invalid capability metadata for {name!r}: {config!r}") + + if tool_name not in tool_by_name: continue + capabilities[name] = EnvironmentCapability( name=name, - tool_name=tool.name, - tool=tool, - metadata=metadata, + tool_name=tool_name, + tool=tool_by_name[tool_name], + metadata=capability_metadata, ) - -def _add_name_fallback_capabilities( - capabilities: dict[str, EnvironmentCapability], - tool_by_name: dict[str, mcp_types.Tool], - name_fallbacks: dict[str, tuple[str, ...]], -) -> None: - for capability, names in name_fallbacks.items(): + for capability, names in (name_fallbacks or {}).items(): if capability in capabilities: continue matched_tool_names = [name for name in names if name in tool_by_name] - tool_name = matched_tool_names[0] if matched_tool_names else None - if tool_name is None: + if not matched_tool_names: continue - tool = 
tool_by_name[tool_name] + + tool = tool_by_name[matched_tool_names[0]] capabilities[capability] = EnvironmentCapability( name=capability, tool_name=tool.name, tool=tool, metadata={"tools": {name: name for name in matched_tool_names}}, ) + return capabilities class GroupedCapabilityMixin: @@ -134,37 +111,14 @@ class GroupedCapabilityMixin: env_tool_names: ClassVar[tuple[str, ...]] - if TYPE_CHECKING: - - def __init__(self, *, env_tool_name: str, spec: AgentToolSpec) -> None: ... - @classmethod def env_tool_name_for_capability(cls, capability: EnvironmentCapability) -> str | None: - tools = capability.metadata.get("tools") - if isinstance(tools, dict): - return next( - (tools[name] for name in cls.env_tool_names if isinstance(tools.get(name), str)), - None, - ) + tools_obj = capability.metadata.get("tools") + if isinstance(tools_obj, dict): + tools_map = cast("dict[str, object]", tools_obj) + for name in cls.env_tool_names: + if env_tool_name := tools_map.get(name): + return str(env_tool_name) if capability.tool_name in cls.env_tool_names: return capability.tool_name return None - - @classmethod - def from_capability( - cls, - capability: EnvironmentCapability, - spec: AgentToolSpec, - model: str, - ) -> Self: - del model - env_tool_name = cls.env_tool_name_for_capability(capability) or capability.tool_name - return cls(env_tool_name=env_tool_name, spec=spec) - - -__all__ = [ - "EnvironmentCapability", - "GroupedCapabilityMixin", - "capabilities_metadata_from_context", - "discover_environment_capabilities", -] diff --git a/hud/agents/tools/computer.py b/hud/agents/tools/computer.py new file mode 100644 index 000000000..b8e94c6c6 --- /dev/null +++ b/hud/agents/tools/computer.py @@ -0,0 +1,104 @@ +"""Shared helpers for agent-side computer tools.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from mcp.types import ImageContent, 
TextContent + +from hud.types import MCPToolCall, MCPToolResult + +if TYPE_CHECKING: + from mcp import types as mcp_types + +CallTool = Callable[[MCPToolCall], Awaitable[MCPToolResult]] + + +@dataclass(frozen=True) +class ComputerToolInfo: + """Computer MCP tool metadata needed by provider adapters.""" + + display_width: int + display_height: int + coordinate_space: int | None + + +def computer_tool_info( + tool: mcp_types.Tool, + *, + default_width: int, + default_height: int, +) -> ComputerToolInfo: + """Resolve the computer contract advertised by the MCP tool.""" + meta = cast("Mapping[str, object]", tool.meta or {}) + resolution = meta.get("resolution") + display_width = default_width + display_height = default_height + + if isinstance(resolution, Mapping): + resolution = cast("Mapping[str, object]", resolution) + width = resolution.get("width") + height = resolution.get("height") + if type(width) is int: + display_width = width + if type(height) is int: + display_height = height + + coordinate_space_raw = meta.get("coordinate_space") + coordinate_space = coordinate_space_raw if type(coordinate_space_raw) is int else None + + return ComputerToolInfo( + display_width=display_width, + display_height=display_height, + coordinate_space=coordinate_space, + ) + + +def computer_error_result(message: str) -> MCPToolResult: + return MCPToolResult(content=[TextContent(type="text", text=message)], isError=True) + + +def result_has_image(result: MCPToolResult) -> bool: + return any(isinstance(block, ImageContent) for block in result.content) + + +def first_image_data(result: MCPToolResult) -> str | None: + for block in result.content: + if isinstance(block, ImageContent): + return block.data + return None + + +def last_image_data(result: MCPToolResult) -> str | None: + for block in reversed(result.content): + if isinstance(block, ImageContent): + return block.data + return None + + +async def execute_computer_calls( + call_tool: CallTool, + *, + env_tool_name: str, + 
calls: list[dict[str, Any]], + ensure_screenshot: bool, +) -> MCPToolResult: + result = MCPToolResult(content=[], isError=False) + for arguments in calls: + result = await call_tool(MCPToolCall(name=env_tool_name, arguments=arguments)) + if result.isError: + return result + + if ensure_screenshot and not result_has_image(result): + screenshot = await call_tool( + MCPToolCall(name=env_tool_name, arguments={"action": "screenshot"}) + ) + if not screenshot.isError and screenshot.content: + return MCPToolResult( + content=[*result.content, *screenshot.content], + isError=result.isError, + ) + + return result diff --git a/hud/agents/tools/hosted.py b/hud/agents/tools/hosted.py index 160bcab98..e86c3934d 100644 --- a/hud/agents/tools/hosted.py +++ b/hud/agents/tools/hosted.py @@ -2,49 +2,30 @@ from __future__ import annotations +import fnmatch +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar -from .base import AgentToolSpec - -HostedToolParamT = TypeVar("HostedToolParamT") -HostedToolT = TypeVar("HostedToolT", bound="HostedTool[Any]") +HostedToolParamT_co = TypeVar("HostedToolParamT_co", covariant=True) @dataclass(frozen=True, kw_only=True) -class HostedTool(Generic[HostedToolParamT]): +class HostedTool(ABC, Generic[HostedToolParamT_co]): """Provider-side tool activated only through explicit agent config.""" supported_models: tuple[str, ...] 
| None = None def supports_model(self, model: str | None) -> bool: - spec = AgentToolSpec( - api_type="hosted", - api_name=self.__class__.__name__, - supported_models=self.supported_models, + if not self.supported_models: + return True + if not model or model == "unknown": + return False + model_lower = model.lower() + return any( + fnmatch.fnmatch(model_lower, pattern.lower()) for pattern in self.supported_models ) - return spec.supports_model(model) - def to_params(self) -> HostedToolParamT: + @abstractmethod + def to_params(self) -> HostedToolParamT_co: raise NotImplementedError - - -def select_hosted_tools( - hosted_tools: list[Any], - *, - tool_type: type[HostedToolT], - model: str, -) -> list[HostedToolT]: - """Select explicitly configured hosted tools for one provider/model.""" - selected: list[HostedToolT] = [] - for hosted_tool in hosted_tools: - if not isinstance(hosted_tool, tool_type) or not hosted_tool.supports_model(model): - continue - selected.append(hosted_tool) - return selected - - -__all__ = [ - "HostedTool", - "select_hosted_tools", -] diff --git a/hud/agents/tools/registry.py b/hud/agents/tools/registry.py deleted file mode 100644 index 2de27c52c..000000000 --- a/hud/agents/tools/registry.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Registry support for agent-owned tools.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from .base import AgentTool - -if TYPE_CHECKING: - from hud.agents.tools.capabilities import EnvironmentCapability - -ToolT = TypeVar("ToolT", bound=AgentTool[Any]) - - -@dataclass(frozen=True) -class AgentToolRegistry(Generic[ToolT]): - """Declarative registry for a provider or harness tool family.""" - - tool_classes: tuple[type[ToolT], ...] 
- name_fallbacks: dict[str, tuple[str, ...]] = field(default_factory=dict) - - @property - def capabilities(self) -> frozenset[str]: - return frozenset(cls.capability for cls in self.tool_classes) - - def tool_for_capability( - self, - capability: EnvironmentCapability, - model: str, - ) -> ToolT | None: - tools = self.tools_for_capability(capability, model) - return tools[0] if tools else None - - def tools_for_capability( - self, - capability: EnvironmentCapability, - model: str, - ) -> list[ToolT]: - tools: list[ToolT] = [] - for tool_cls in self.tool_classes: - if tool_cls.capability != capability.name: - continue - spec = tool_cls.default_spec(model) - if spec is None: - continue - env_tool_name_for_capability = getattr(tool_cls, "env_tool_name_for_capability", None) - if ( - callable(env_tool_name_for_capability) - and env_tool_name_for_capability(capability) is None - ): - continue - tools.append(tool_cls.from_capability(capability, spec, model)) - return tools - - -__all__ = ["AgentToolRegistry"] diff --git a/hud/agents/types.py b/hud/agents/types.py index cb48ed5d9..718fc7ab0 100644 --- a/hud/agents/types.py +++ b/hud/agents/types.py @@ -10,20 +10,22 @@ from pydantic import AliasChoices, BaseModel, ConfigDict, Field -from hud.types import BaseAgentConfig +from hud.agents.tools.hosted import HostedTool # Alias to accept both 'model' and 'checkpoint_name' (backwards compat) _model_alias = AliasChoices("model", "checkpoint_name") -class BaseCreateParams(BaseModel): - """Runtime parameters for agent creation.""" - +class AgentConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) ctx: Any = None # EvalContext or Environment auto_respond: bool = False - verbose: bool = False + system_prompt: str | None = None + hosted_tools: list[HostedTool[object]] = Field(default_factory=list[HostedTool[object]]) + + model_name: str = "Agent" + model: str = Field(default="unknown", validation_alias=_model_alias) # 
----------------------------------------------------------------------------- @@ -31,9 +33,7 @@ class BaseCreateParams(BaseModel): # ----------------------------------------------------------------------------- -class ClaudeConfig(BaseAgentConfig): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class ClaudeConfig(AgentConfig): model_name: str = "Claude" model: str = Field(default="claude-sonnet-4-6", validation_alias=_model_alias) model_client: Any = None # AsyncAnthropic | AsyncAnthropicBedrock @@ -42,23 +42,17 @@ class ClaudeConfig(BaseAgentConfig): validate_api_key: bool = True -class ClaudeCreateParams(BaseCreateParams, ClaudeConfig): - pass - - # ----------------------------------------------------------------------------- # Gemini # ----------------------------------------------------------------------------- -class GeminiConfig(BaseAgentConfig): +class GeminiConfig(AgentConfig): """Configuration for GeminiAgent.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - model_name: str = "Gemini" model: str = Field(default="gemini-3-pro-preview", validation_alias=_model_alias) - model_client: Any = None # genai.Client + model_client: Any = None # genai.Client temperature: float = 1.0 top_p: float = 0.95 top_k: int = 40 @@ -69,23 +63,17 @@ class GeminiConfig(BaseAgentConfig): include_thoughts: bool = True -class GeminiCreateParams(BaseCreateParams, GeminiConfig): - pass - - # ----------------------------------------------------------------------------- # OpenAI # ----------------------------------------------------------------------------- -class OpenAIConfig(BaseAgentConfig): +class OpenAIConfig(AgentConfig): """Configuration for OpenAIAgent.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - model_name: str = "OpenAI" model: str = Field(default="gpt-5.4", validation_alias=_model_alias) - model_client: Any = None # AsyncOpenAI + model_client: Any = None # AsyncOpenAI
max_output_tokens: int | None = None temperature: float | None = None reasoning: Any = None # openai Reasoning @@ -96,15 +84,9 @@ class OpenAIConfig(BaseAgentConfig): validate_api_key: bool = True -class OpenAICreateParams(BaseCreateParams, OpenAIConfig): - pass - - -class OpenAIChatConfig(BaseAgentConfig): +class OpenAIChatConfig(AgentConfig): """Configuration for OpenAIChatAgent.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - model_name: str = "OpenAI Chat" model: str = Field(default="gpt-5-mini", validation_alias=_model_alias) checkpoint: str | None = Field( @@ -118,7 +100,3 @@ class OpenAIChatConfig(BaseAgentConfig): api_key: str | None = None base_url: str | None = None completion_kwargs: dict[str, Any] = Field(default_factory=dict) - - -class OpenAIChatCreateParams(BaseCreateParams, OpenAIChatConfig): - pass diff --git a/hud/cli/rl.py b/hud/cli/rl.py index a3831e0a5..538d5a65b 100644 --- a/hud/cli/rl.py +++ b/hud/cli/rl.py @@ -24,7 +24,7 @@ # ============================================================================= -async def _fetch_env_metadata(env_name: str, headers: dict[str, str]) -> dict[str, Any] | None: +async def _fetch_tool_metadata(env_name: str, headers: dict[str, str]) -> dict[str, Any] | None: """Fetch env metadata from mcp-config endpoint. 
Returns response dict or None.""" url = f"{settings.hud_api_url}/environments/{env_name}/mcp-config" async with httpx.AsyncClient(timeout=15.0) as client: @@ -116,20 +116,20 @@ async def _preflight_validate(tasks: list[Any]) -> None: hud_console.info(f"Preflight: checking {len(env_names)} environment(s)…") - env_metadata: dict[str, dict[str, Any]] = {} + tool_metadata: dict[str, dict[str, Any]] = {} for name in sorted(env_names): - data = await _fetch_env_metadata(name, headers) + data = await _fetch_tool_metadata(name, headers) if data is None: hud_console.error(f"Environment '{name}' not found on platform") hud_console.hint("Deploy it first with: hud deploy") raise typer.Exit(1) - env_metadata[name] = data + tool_metadata[name] = data hud_console.info(f" ✓ {name}") env_scenarios = _extract_scenarios(tasks) for env_name, scenarios in sorted(env_scenarios.items()): - if env_name in env_metadata: - _check_scenarios(env_name, scenarios, env_metadata[env_name]) + if env_name in tool_metadata: + _check_scenarios(env_name, scenarios, tool_metadata[env_name]) hud_console.success("Preflight passed") diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index e46f5c9ca..4d9320d33 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -44,6 +44,7 @@ def __init__( self.error: BaseException | None = None self.metadata: dict[str, Any] = {} self._is_summary = False + self._scenario_sessions = {} def as_tools(self) -> list[types.Tool]: return self._tools diff --git a/hud/cli/utils/version_check.py b/hud/cli/utils/version_check.py index 301053ac6..5ae9d07df 100644 --- a/hud/cli/utils/version_check.py +++ b/hud/cli/utils/version_check.py @@ -232,7 +232,7 @@ def display_update_prompt(console: HUDConsole | None = None) -> None: console: HUDConsole instance for output. If None, creates a new one. 
""" if console is None: - console = HUDConsole(logger=logger) + console = HUDConsole() try: info = check_for_updates() diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 89b4dc704..49a23e448 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -145,7 +145,7 @@ async def run_dataset( # Create agent using AgentType.cls.create() agent = agent_type.cls.create(**final_agent_params) - await agent.run(ctx, max_steps=max_steps) + await ctx._run(agent, max_steps=max_steps) # Reward is computed by EvalContext.__aexit__ from the scenario evaluate phase. # For parallel execution, results are collected via ctx.results @@ -252,7 +252,7 @@ async def run_single_task( if metadata: ctx.metadata.update(metadata) - result = await agent.run(ctx, max_steps=max_steps) + result = await ctx._run(agent, max_steps=max_steps) # Reward is computed by EvalContext.__aexit__ from the scenario evaluate phase. # Propagate reward from EvalContext (set in __aexit__) to returned Trace diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index b7f064d20..25c75869b 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -38,8 +38,8 @@ class SingleTaskRequest(BaseModel): agent_params: dict[str, Any] = Field( default_factory=dict, description="Agent constructor parameters passed to agent.create(). 
" - "Should include fields from BaseCreateParams (auto_trace, auto_respond, verbose) " - "plus agent-specific config fields (e.g., checkpoint_name for ClaudeConfig).", + "Should include runtime fields (ctx, auto_respond) plus agent-specific " + "config fields (e.g., checkpoint_name for ClaudeConfig).", ) max_steps: int = Field(default=10, description="Maximum steps allowed for the agent.") job_id: str = Field(description="HUD job identifier for telemetry association.") diff --git a/hud/environment/environment.py b/hud/environment/environment.py index abeff5d8f..3a566475e 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -986,7 +986,7 @@ async def checkout(user_id: str): # Single task via hud.eval async with hud.eval(env("checkout", user_id="alice")) as ctx: - await agent.run(ctx.prompt) + await ctx._run(agent) # Multiple tasks with variants tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] diff --git a/hud/environment/scenarios.py b/hud/environment/scenarios.py index 17ed1e062..5849afd93 100644 --- a/hud/environment/scenarios.py +++ b/hud/environment/scenarios.py @@ -246,6 +246,18 @@ def _to_prompt_message(item: Any, default_role: str = "user") -> PromptMessage: role=item.role, # type: ignore[arg-type] content=TextContent(type="text", text=str(item.content)), ) + if hasattr(item, "content"): + role = getattr(item, "role", default_role) + content = item.content + if isinstance(content, str): + content = TextContent(type="text", text=content) + elif isinstance(content, TextContent) or hasattr(content, "type"): + pass + elif hasattr(content, "text"): + content = TextContent(type="text", text=str(content.text)) + else: + content = TextContent(type="text", text=str(content)) + return PromptMessage(role=role, content=content) # type: ignore[arg-type] if isinstance(item, str): return PromptMessage( role=default_role, # type: ignore[arg-type] @@ -294,6 +306,22 @@ def _build_answer_for_generator(session: 
ScenarioSession) -> Any: elif isinstance(raw_answer, str): raw_text = raw_answer raw_citations = [] + text = raw_answer.strip() + if text.startswith("```"): + parts = text.split("```") + if len(parts) >= 3: + text = parts[1].removeprefix("json").strip() + try: + parsed_answer = json.loads(text) + except (json.JSONDecodeError, TypeError): + parsed_answer = None + if isinstance(parsed_answer, dict) and ( + "content" in parsed_answer or "citations" in parsed_answer + ): + content = parsed_answer.get("content", "") + raw_text = content if isinstance(content, str) else json.dumps(content) + citations = parsed_answer.get("citations", []) + raw_citations = [c for c in citations if isinstance(c, dict)] else: raw_text = str(raw_answer) if raw_answer is not None else "" raw_citations = [] @@ -741,10 +769,13 @@ async def run_scenario_setup( # Prompt exists remotely; original setup/rendering error. raise - # Extract prompt text from response + # Extract prompt messages and text from response + prompt_messages = ( + _normalize_prompt_yield(list(result.messages)) if result.messages else None + ) prompt_text: str | None = None - if result.messages: - first_msg = result.messages[0] + if prompt_messages: + first_msg = prompt_messages[0] content = first_msg.content if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr] prompt_text = content.text # type: ignore[union-attr] @@ -793,6 +824,7 @@ async def run_scenario_setup( allowed_tools=allowed_tools_meta, returns_schema=returns_schema_meta, enable_citations=enable_citations_meta, + prompt_messages=prompt_messages, ) self._set_session(session, session_id) diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py index 2133823b8..4e0567940 100644 --- a/hud/environment/tests/test_environment.py +++ b/hud/environment/tests/test_environment.py @@ -321,10 +321,12 @@ async def investigate(issue: str): mock_ctx = AsyncMock() mock_ctx.__aenter__ = 
AsyncMock(return_value=mock_ctx) mock_ctx.__aexit__ = AsyncMock(return_value=None) + from hud.types import Trace + + mock_ctx._run.return_value = Trace(content="subagent output", done=True) mock_run_eval.return_value = mock_ctx mock_agent = MagicMock() - mock_agent.run = AsyncMock(return_value=MagicMock(content="subagent output")) mock_create_agent.return_value = mock_agent req_meta = RequestParams.Meta.model_validate({"_hud_trace_id": "trace-from-meta"}) req_context = RequestContext( diff --git a/hud/environment/tests/test_scenarios.py b/hud/environment/tests/test_scenarios.py index a646a3bd7..ca6256970 100644 --- a/hud/environment/tests/test_scenarios.py +++ b/hud/environment/tests/test_scenarios.py @@ -1355,6 +1355,41 @@ async def typed_scenario(): assert isinstance(prompt.meta, dict) assert prompt.meta.get("enable_citations") is True + @pytest.mark.asyncio + async def test_structured_answer_parses_json_wrapped_content_and_citations(self) -> None: + """Structured scenario parsing unwraps model-emitted content/citations JSON.""" + env = Environment("test-env") + + class Answer(BaseModel): + final: str + + captured = None + + @env.scenario("typed", returns=Answer, enable_citations=True) + async def typed_scenario(): + nonlocal captured + captured = yield "Prompt" + yield 1.0 + + await env.run_scenario_setup("typed", {}) + await env.submit( + "typed", + """```json +{ + "content": {"final": "done"}, + "citations": [ + {"type": "url_citation", "source": "https://example.com", "text": "source"} + ] +} +```""", + ) + result = await env.run_scenario_evaluate("typed") + + assert result.reward == 1.0 + assert captured is not None + assert captured.content.final == "done" + assert captured.citations[0].source == "https://example.com" + @pytest.mark.asyncio async def test_submit_before_setup_raises(self) -> None: """Calling submit() before run_scenario_setup() should raise.""" diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py index 119769fcc..8812ce8c8 100644 --- 
a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -13,12 +13,12 @@ await ctx.call_tool("navigate", url="...") async with env("checkout", user_id="alice") as ctx: - await agent.run(ctx.prompt) + await ctx.submit("answer") # Orchestrated with Task objects tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: - await agent.run(ctx.prompt) + await ctx._run(agent) # Blank eval for manual reward async with hud.eval() as ctx: diff --git a/hud/eval/context.py b/hud/eval/context.py index f767a5bbc..05d7ce8d5 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -13,8 +13,12 @@ import logging import uuid from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Literal, Self, cast +import mcp.types as types + +from hud.agents.base import AgentContext +from hud.agents.tools.base import ToolClient from hud.environment import Environment from hud.settings import settings from hud.shared import make_request @@ -24,15 +28,17 @@ from collections.abc import Generator from types import TracebackType + from hud.agents.tools import CapabilityEntry, ToolMetadata from hud.eval.task import Task from hud.tools.types import EvaluationResult - from hud.types import MCPToolResult + from hud.types import MCPToolResult, Trace from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete logger = logging.getLogger(__name__) + # Contextvar to store current trace headers (for httpx auto-instrumentation) _current_trace_headers: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( "current_trace_headers", default=None @@ -109,7 +115,7 @@ class EvalContext(Environment): # With task (scenario sets reward automatically) tasks = load_tasks("my-org/task:1") async with hud.eval(tasks) as ctx: - await agent.run(ctx) + await ctx._run(agent) # reward set by scenario evaluate phase in __aexit__ # Blank 
eval (manual reward) @@ -174,7 +180,7 @@ def __init__( self.answer: str | dict[str, Any] | None = None # Agent's submitted answer self.system_prompt: str | None = None # From task.agent_config, passed to agent self.scenario_returns_schema: dict[str, Any] | None = None - self.scenario_enable_citations: bool = False + self.enable_citations: bool = False # Error tracking self.error: BaseException | None = None @@ -374,18 +380,9 @@ async def _run_task_scenario_setup(self) -> None: if prompt: self.prompt = prompt - # If scenario yielded multi-turn messages, store as conversation session = self._get_session() self.scenario_returns_schema = session.returns_schema if session else None - self.scenario_enable_citations = bool(session.enable_citations) if session else False - if session and session.prompt_messages and len(session.prompt_messages) > 1: - self.conversation = [ - { - "role": pm.role, - "content": getattr(pm.content, "text", str(pm.content)), - } - for pm in session.prompt_messages - ] + self.enable_citations = bool(session.enable_citations) if session else False async def _run_task_scenario_evaluate(self) -> None: """Run the task's scenario evaluate phase (if scenario provided).""" @@ -511,8 +508,7 @@ async def submit(self, answer: str | dict[str, Any]) -> None: Example: async with env("checkout", product="laptop") as ctx: - response = await agent.run(ctx.prompt) - await ctx.submit(response) + await ctx.submit("answer") # On exit, scenario's evaluate phase receives the answer """ if not self._task or not self._task.scenario: @@ -524,6 +520,90 @@ async def submit(self, answer: str | dict[str, Any]) -> None: # Delegate to Environment.submit() which handles storage + broadcast await super().submit(self._task.scenario, answer) + async def submit_result(self, result: Trace) -> None: + """Record an agent result on the eval context.""" + if result.isError: + error_msg = result.info.get("error") if result.info else result.content + self.error = Exception(str(error_msg)) 
if error_msg else Exception("Agent error") + return + + if not result.content: + return + + if result.citations: + await self.submit({"content": result.content, "citations": result.citations}) + else: + await self.submit(result.content) + + async def _run(self, agent: Any, *, max_steps: int = 10) -> Trace: + """Run an agent against this eval context.""" + await self.list_tools() + initial_messages = self.prompt_messages() + tool_client = ToolClient( + tools=self.as_tools(), + tool_handler=self.call_tool, + tool_metadata=self._tool_metadata(), + ) + + agent.enable_citations = bool(getattr(self, "enable_citations", False)) + result = await agent.run( + AgentContext( + messages=initial_messages, + tool_client=tool_client, + ), + max_steps=max_steps, + ) + await self.submit_result(result) + return result + + def _tool_metadata(self) -> ToolMetadata | None: + if environment_capabilities := self.metadata.get("environment_capabilities"): + return cast("ToolMetadata", environment_capabilities) + if capabilities := self.metadata.get("capabilities"): + return {"capabilities": cast("dict[str, str | CapabilityEntry]", capabilities)} + return None + + def prompt_messages(self) -> list[types.PromptMessage]: + """Return raw MCP prompt messages for an agent run.""" + session = self._get_session() + if session and session.prompt_messages: + return session.prompt_messages + + conversation = getattr(self, "conversation", None) + if conversation: + messages: list[types.PromptMessage] = [] + for msg in conversation: + role = cast("Literal['user', 'assistant']", msg.get("role", "user")) + messages.append( + types.PromptMessage( + role=role, + content=types.TextContent(type="text", text=msg.get("content", "")), + ) + ) + return messages + + prompt = getattr(self, "prompt", None) + if not prompt: + if self.has_scenario: + scenario = self._task.scenario if self._task else "unknown" + raise ValueError( + f"ctx.prompt is not set.\n\n" + f"Scenario '{scenario}' was specified but returned an 
empty prompt.\n" + f"Check that the scenario's setup function returns a non-empty string." + ) + raise ValueError( + "ctx.prompt is not set.\n\n" + "No scenario was specified in your task file.\n" + "Add a 'scenario' field to your task so scenario setup can produce a prompt." + ) + + return [ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=prompt), + ) + ] + async def _eval_enter(self) -> None: """Notify backend that eval has started.""" if not self._trace_enabled: diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 655833e2e..7b627cc4e 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -148,12 +148,12 @@ async def run_eval( env = Environment("my-env").connect_hub("browser") tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: - await agent.run(ctx.prompt) + await ctx._run(agent) # Load tasks from file or API tasks = load_tasks("hud-evals/SheetBench-50") async with hud.eval(tasks) as ctx: - await agent.run(ctx) + await ctx._run(agent) # With variants and group async with hud.eval( @@ -167,7 +167,7 @@ async def run_eval( # With concurrency limit async with hud.eval(tasks, max_concurrent=10) as ctx: - await agent.run(ctx) + await ctx._run(agent) # Access results after parallel run for e in ctx.results: diff --git a/hud/eval/task.py b/hud/eval/task.py index e13159919..fefcbec73 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -15,7 +15,7 @@ # With scenario async with env("checkout", user_id="alice") as ctx: - await agent.run(ctx.prompt) + await ctx.submit("answer") # Orchestrated via hud.eval tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] @@ -279,7 +279,7 @@ async def run( agent = create_agent(agent) async with run_eval(self, trace=trace, quiet=quiet) as ctx: - result = await agent.run(ctx, max_steps=max_steps) + result = await ctx._run(agent, max_steps=max_steps) if ctx.reward is 
not None: result.reward = ctx.reward diff --git a/hud/services/chat.py b/hud/services/chat.py index 50177db6c..bd53111ca 100644 --- a/hud/services/chat.py +++ b/hud/services/chat.py @@ -89,7 +89,7 @@ class Chat(AgentExecutor): Each ``send()`` call: 1. Appends the user message to history 2. Creates a Task copy with the full history as scenario args - 3. Runs ``hud.eval(task)`` -> scenario setup -> ``agent.run(ctx)`` -> evaluate + 3. Runs ``hud.eval(task)`` -> scenario setup -> ``ctx._run(agent)`` -> evaluate 4. Appends the assistant response to history 5. Returns the Trace diff --git a/hud/tests/public_api/test_v5_legacy_aliases.py b/hud/tests/public_api/test_v5_legacy_aliases.py index 8e94cc281..ea8f3e633 100644 --- a/hud/tests/public_api/test_v5_legacy_aliases.py +++ b/hud/tests/public_api/test_v5_legacy_aliases.py @@ -54,12 +54,6 @@ def fake_load_tasks(source: str, *, raw: bool = False) -> list[dict[str, str]]: assert calls == [("local-or-remote-source", True)] -def test_agent_response_aliases_inference_result() -> None: - import hud.types as types - - assert types.AgentResponse is types.InferenceResult - - def test_tool_router_aliases_environment_mcp_router() -> None: import hud.environment as environment diff --git a/hud/tests/public_api/test_v5_surface_imports.py b/hud/tests/public_api/test_v5_surface_imports.py index 15cf0f43f..57b6ab8d6 100644 --- a/hud/tests/public_api/test_v5_surface_imports.py +++ b/hud/tests/public_api/test_v5_surface_imports.py @@ -81,8 +81,8 @@ "PlaywrightTool", ), "hud.types": ( + "AgentResponse", "AgentType", - "InferenceResult", "MCPToolCall", "MCPToolResult", "Trace", @@ -97,12 +97,7 @@ "OpenAIChatAgent", "create_agent", ), - "hud.agents.claude": ( - "ClaudeAgent", - "base64_to_content_block", - "text_to_content_block", - "tool_use_content_block", - ), + "hud.agents.claude": ("ClaudeAgent",), "hud.datasets": ( "display_results", "load_tasks", @@ -215,7 +210,6 @@ "hud.tools.agent": ("AgentTool",), "hud.agents.gemini": 
("GeminiAgent",), "hud.agents.openai": ("OpenAIAgent",), - "hud.agents.openai_chat": ("OpenAIChatAgent",), "hud.tools.coding": ( "ApplyPatchTool", "BashTool", diff --git a/hud/tests/public_api/test_v5_workflow_contracts.py b/hud/tests/public_api/test_v5_workflow_contracts.py index cd9df4819..2491baba8 100644 --- a/hud/tests/public_api/test_v5_workflow_contracts.py +++ b/hud/tests/public_api/test_v5_workflow_contracts.py @@ -650,7 +650,7 @@ def test_native_grader_helpers_keep_basic_semantics() -> None: assert f1_score("hello hud", "hello sdk") == 0.5 -def test_eval_context_user_facing_properties_and_tool_surface() -> None: +def test_eval_context_user_facing_properties_and_tool_helpers() -> None: ctx = EvalContext(trace=False, quiet=True, variants={"model": "test"}) ctx.prompt = "Do the task" diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 8ddf6a9fc..d870c5a62 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -39,6 +39,7 @@ async def test_run_dataset_with_task_list(self): mock_ctx = AsyncMock() mock_ctx.results = None mock_ctx.reward = None + mock_ctx._run.return_value = Trace(reward=1.0, done=True) # Create mock agent class and instance (use MagicMock since create() is sync) mock_agent_instance = AsyncMock() @@ -57,7 +58,7 @@ async def test_run_dataset_with_task_list(self): # Should return list with ctx assert len(results) == 1 - mock_agent_instance.run.assert_called_once() + mock_ctx._run.assert_called_once_with(mock_agent_instance, max_steps=5) @pytest.mark.asyncio async def test_run_dataset_from_source_string(self): @@ -70,6 +71,7 @@ async def test_run_dataset_from_source_string(self): mock_ctx = AsyncMock() mock_ctx.results = None + mock_ctx._run.return_value = Trace(reward=1.0, done=True) # Create mock agent class and instance (use MagicMock since create() is sync) mock_agent_instance = AsyncMock() @@ -101,6 +103,7 @@ async def test_run_dataset_passes_parameters(self): 
mock_ctx = AsyncMock() mock_ctx.results = None + mock_ctx._run.return_value = Trace(reward=1.0, done=True) # Create mock agent class and instance (use MagicMock since create() is sync) mock_agent_instance = AsyncMock() diff --git a/hud/tests/test_types.py b/hud/tests/test_types.py index 55a5c0f89..bc1147ffe 100644 --- a/hud/tests/test_types.py +++ b/hud/tests/test_types.py @@ -4,7 +4,7 @@ from mcp.types import ImageContent, TextContent -from hud.types import InferenceResult, MCPToolCall, MCPToolResult, Trace, TraceStep +from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace, TraceStep def test_mcp_tool_call_str_long_args(): @@ -164,17 +164,17 @@ def test_mcp_tool_result_rich(): mock_console.format_tool_result.assert_called_once() -def test_inference_result_str_with_reasoning(): - """Test InferenceResult __str__ includes reasoning.""" - response = InferenceResult(reasoning="Test reasoning", content="Test content") +def test_agent_response_str_with_reasoning(): + """Test AgentResponse __str__ includes reasoning.""" + response = AgentResponse(reasoning="Test reasoning", content="Test content") output = str(response) assert "Reasoning: Test reasoning" in output assert "Content: Test content" in output -def test_inference_result_str_with_tool_calls(): - """Test InferenceResult __str__ includes tool calls.""" - response = InferenceResult( +def test_agent_response_str_with_tool_calls(): + """Test AgentResponse __str__ includes tool calls.""" + response = AgentResponse( tool_calls=[ MCPToolCall(name="tool1", arguments={"a": 1}), MCPToolCall(name="tool2", arguments={"b": 2}), @@ -186,38 +186,29 @@ def test_inference_result_str_with_tool_calls(): assert "tool2" in output -def test_inference_result_str_with_raw(): - """Test InferenceResult __str__ includes raw.""" - response = InferenceResult(raw={"raw_data": "value"}) +def test_agent_response_str_with_raw(): + """Test AgentResponse __str__ includes raw.""" + response = AgentResponse(raw={"raw_data": "value"}) 
output = str(response) assert "Raw:" in output -def test_inference_result_citations_default_empty(): - """InferenceResult.citations defaults to empty list.""" - result = InferenceResult(content="hello") +def test_agent_response_citations_default_empty(): + """AgentResponse.citations defaults to empty list.""" + result = AgentResponse(content="hello") assert result.citations == [] -def test_inference_result_citations_roundtrip(): +def test_agent_response_citations_roundtrip(): """Citations survive serialize/deserialize.""" cit = {"type": "url_citation", "source": "https://example.com", "title": "Example"} - result = InferenceResult(content="hello", citations=[cit]) + result = AgentResponse(content="hello", citations=[cit]) data = result.model_dump(mode="json") - restored = InferenceResult(**data) + restored = AgentResponse(**data) assert len(restored.citations) == 1 assert restored.citations[0]["source"] == "https://example.com" -def test_agent_response_alias(): - """AgentResponse is a backwards-compatible alias for InferenceResult.""" - from hud.types import AgentResponse - - assert AgentResponse is InferenceResult - r = AgentResponse(content="test", done=True) - assert isinstance(r, InferenceResult) - - def test_trace_citations_default_empty(): """Trace.citations defaults to empty list.""" trace = Trace() diff --git a/hud/tools/agent.py b/hud/tools/agent.py index 0d8743fa4..dd5646015 100644 --- a/hud/tools/agent.py +++ b/hud/tools/agent.py @@ -216,7 +216,7 @@ async def _run_subagent() -> ToolResult: else: agent = self._agent_cls.create(**self._agent_params) # type: ignore - result = await agent.run(ctx) + result = await ctx._run(agent) content = result.content if hasattr(result, "content") and result.content else "" return ToolResult(content=[TextContent(type="text", text=content)]) diff --git a/hud/tools/computer/base.py b/hud/tools/computer/base.py index 9dbe2a27d..cf6770012 100644 --- a/hud/tools/computer/base.py +++ b/hud/tools/computer/base.py @@ -88,12 
+88,14 @@ def __init__( self.height = height or computer_settings.DISPLAY_HEIGHT # Build metadata with resolution info - meta = { + meta: dict[str, object] = { "resolution": { "width": self.width, "height": self.height, } } + if coordinate_space is not None: + meta["coordinate_space"] = coordinate_space # Initialize base tool with executor as env super().__init__( diff --git a/hud/tools/computer/settings.py b/hud/tools/computer/settings.py index 8d3121500..51a0201ce 100644 --- a/hud/tools/computer/settings.py +++ b/hud/tools/computer/settings.py @@ -93,11 +93,6 @@ class ComputerSettings(BaseSettings): description="Whether to rescale images to the agent width and height", validation_alias="GEMINI_RESCALE_IMAGES", ) - GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS: int = Field( - default=3, - description="Maximum number of recent turns to keep screenshots for in Gemini agent", - validation_alias="GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS", - ) GLM_COMPUTER_WIDTH: int = Field( default=1024, description="Width of the display to use for the z-ai/glm4.5v computer tools", diff --git a/hud/tools/tests/test_agent_tool.py b/hud/tools/tests/test_agent_tool.py index de8196c38..d85523801 100644 --- a/hud/tools/tests/test_agent_tool.py +++ b/hud/tools/tests/test_agent_tool.py @@ -1,220 +1,64 @@ -"""Tests for AgentTool - scenario-to-agent composition.""" +"""Tests for AgentTool's public tool schema behavior.""" from __future__ import annotations -import inspect -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest from hud.environment import Environment from hud.eval.task import Task -from hud.tools.agent import AgentTool, _is_eval_only - - -class TestIsEvalOnly: - """Tests for _is_eval_only helper function.""" - - def test_required_param_not_eval_only(self) -> None: - """Required params (no default) are not eval-only.""" - - def fn(x: str) -> None: - pass - - sig = inspect.signature(fn) - param = sig.parameters["x"] - assert not 
_is_eval_only(param) - - def test_optional_with_value_not_eval_only(self) -> None: - """Optional params with non-None default are not eval-only.""" - - def fn(x: str = "default") -> None: - pass - - sig = inspect.signature(fn) - param = sig.parameters["x"] - assert not _is_eval_only(param) - - def test_optional_none_without_union_not_eval_only(self) -> None: - """Optional with None default but no None in type is not eval-only.""" - - def fn(x: str = None) -> None: # type: ignore[assignment] # noqa: RUF013 - pass - - sig = inspect.signature(fn) - param = sig.parameters["x"] - assert not _is_eval_only(param) - - def test_optional_none_with_union_is_eval_only(self) -> None: - """Params with `X | None = None` pattern are eval-only.""" - - def fn(x: str | None = None) -> None: - pass - - sig = inspect.signature(fn) - param = sig.parameters["x"] - assert _is_eval_only(param) - - def test_optional_int_none_is_eval_only(self) -> None: - """Works with int | None = None too.""" - - def fn(x: int | None = None) -> None: - pass - - sig = inspect.signature(fn) - param = sig.parameters["x"] - assert _is_eval_only(param) - - def test_string_annotation_with_none_union(self) -> None: - """Handles string annotations like 'str | None'.""" - # Simulate string annotation - param = inspect.Parameter( - "x", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=None, - annotation="str | None", - ) - assert _is_eval_only(param) - - def test_string_annotation_without_none(self) -> None: - """String annotations without None are not eval-only.""" - param = inspect.Parameter( - "x", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=None, - annotation="str", - ) - assert not _is_eval_only(param) +from hud.tools.agent import AgentTool class TestAgentToolInit: - """Tests for AgentTool initialization.""" - def test_requires_model_or_agent(self) -> None: - """Must provide either model or agent.""" task = Task(args={}) with pytest.raises(ValueError, match="Must provide either"): AgentTool(task) 
def test_cannot_provide_both_model_and_agent(self) -> None: - """Cannot provide both model and agent.""" task = Task(args={}) mock_agent = MagicMock() with pytest.raises(ValueError, match="Cannot provide both"): AgentTool(task, model="claude", agent=mock_agent) # type: ignore[arg-type] - def test_accepts_model_string(self) -> None: - """Can create with model string.""" - task = Task(scenario="test", args={}) - tool = AgentTool(task, model="claude") - - assert tool._model == "claude" - assert tool._agent_cls is None - - def test_accepts_agent_class(self) -> None: - """Can create with custom agent class.""" - task = Task(scenario="test", args={}) - mock_agent_cls = MagicMock() - tool = AgentTool(task, agent=mock_agent_cls) # type: ignore[arg-type] - - assert tool._model is None - assert tool._agent_cls is mock_agent_cls - def test_name_defaults_to_scenario(self) -> None: - """Tool name defaults to scenario name.""" task = Task(scenario="investigate", args={}) tool = AgentTool(task, model="claude") assert tool.name == "investigate" def test_name_can_be_overridden(self) -> None: - """Tool name can be overridden.""" task = Task(scenario="investigate", args={}) tool = AgentTool(task, model="claude", name="custom_name") assert tool.name == "custom_name" -class TestAgentToolParamFiltering: - """Tests for parameter filtering (eval-only params hidden).""" - - def test_filters_eval_only_params(self) -> None: - """Eval-only params (| None = None) are filtered from visible_params.""" - env = Environment("test") - - # Use Union syntax for consistency across Python versions - @env.scenario() - async def investigate( - issue_id: str, - include_traces: bool = True, - expected_cause: str | None = None, # Eval only - ): - yield {"task": f"Investigate {issue_id}"} - - task = env("investigate") - tool = AgentTool(task, model="claude") - - # visible_params should only have issue_id and include_traces - assert "issue_id" in tool._visible_params - assert "include_traces" in 
tool._visible_params - assert "expected_cause" not in tool._visible_params - - def test_all_required_params_visible(self) -> None: - """All required params are visible.""" - env = Environment("test") - - @env.scenario() - async def search(query: str, limit: int): - yield {"task": f"Search: {query}"} - - task = env("search") - tool = AgentTool(task, model="claude") - - assert "query" in tool._visible_params - assert "limit" in tool._visible_params - - def test_optional_with_default_visible(self) -> None: - """Optional params with non-None defaults are visible.""" - env = Environment("test") - - @env.scenario() - async def fetch(url: str, request_timeout: int = 30, retries: int = 3): - yield {"task": f"Fetch {url}"} - - task = env("fetch") - tool = AgentTool(task, model="claude") - - assert "url" in tool._visible_params - assert "request_timeout" in tool._visible_params - assert "retries" in tool._visible_params - - -class TestAgentToolSchema: - """Tests for JSON schema generation.""" - - def test_builds_json_schema(self) -> None: - """Builds proper JSON schema from visible params.""" +class TestAgentToolMCP: + def test_mcp_tool_exposes_required_and_defaulted_scenario_parameters(self) -> None: env = Environment("test") @env.scenario() - async def investigate(issue_id: str, verbose: bool = False): - yield {"task": f"Investigate {issue_id}"} + async def investigate(issue_id: str, verbose: bool = False, limit: int = 10): + yield {"task": f"Investigate {issue_id} {verbose} {limit}"} task = env("investigate") tool = AgentTool(task, model="claude") - schema = tool._param_schema - assert schema is not None + schema = tool.mcp.parameters assert schema["type"] == "object" - assert "issue_id" in schema["properties"] - assert "verbose" in schema["properties"] + assert set(schema["properties"]) == {"issue_id", "verbose", "limit"} assert "issue_id" in schema["required"] assert "verbose" not in schema["required"] # Has default + assert "limit" not in schema["required"] + assert 
schema["properties"]["verbose"]["default"] is False + assert schema["properties"]["limit"]["default"] == 10 - def test_schema_excludes_eval_only(self) -> None: - """Schema excludes eval-only params.""" + def test_mcp_tool_hides_eval_only_parameters(self) -> None: env = Environment("test") @env.scenario() @@ -227,17 +71,11 @@ async def check( task = env("check") tool = AgentTool(task, model="claude") - schema = tool._param_schema - assert schema is not None + schema = tool.mcp.parameters assert "item_id" in schema["properties"] assert "expected_status" not in schema["properties"] - -class TestAgentToolMCP: - """Tests for MCP tool integration.""" - def test_mcp_property_returns_tool(self) -> None: - """The mcp property returns a FastMCP FunctionTool.""" from fastmcp.tools import FunctionTool env = Environment("test") @@ -251,105 +89,3 @@ async def greet(name: str): mcp_tool = tool.mcp assert isinstance(mcp_tool, FunctionTool) - - def test_mcp_has_filtered_parameters(self) -> None: - """MCP tool has filtered parameter schema.""" - env = Environment("test") - - @env.scenario() - async def analyze( - data: str, - expected_result: str | None = None, # Eval only - ): - yield {"task": f"Analyze {data}"} - - task = env("analyze") - tool = AgentTool(task, model="claude") - - mcp_tool = tool.mcp - params = mcp_tool.parameters # FunctionTool uses 'parameters' - - assert "data" in params["properties"] - assert "expected_result" not in params["properties"] - - -class TestAgentToolCall: - """Tests for AgentTool.__call__.""" - - @pytest.mark.asyncio - async def test_filters_kwargs_to_visible_only(self) -> None: - """Call filters kwargs to visible params only.""" - # Import modules first so patches work - import hud.agents - import hud.eval.manager # noqa: F401 - - env = Environment("test") - - @env.scenario() - async def process(item: str, expected: str | None = None): - yield {"task": f"Process {item}"} - - task = env("process") - tool = AgentTool(task, model="claude") - - # Mock 
the eval context and agent - with ( - patch("hud.eval.manager.run_eval") as mock_run_eval, - patch("hud.agents.create_agent") as mock_create_agent, - ): - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) - mock_run_eval.return_value = mock_ctx - - mock_agent = MagicMock() - mock_agent.run = AsyncMock(return_value=MagicMock(content="result")) - mock_create_agent.return_value = mock_agent - - # Call with both visible and eval-only params - await tool(item="test", expected="should_be_filtered") - - # Check that task was created with filtered args - call_args = mock_run_eval.call_args - task_arg = call_args[0][0] - assert "item" in task_arg.args - assert "expected" not in task_arg.args # Filtered out - - @pytest.mark.asyncio - async def test_merges_template_args(self) -> None: - """Call merges kwargs with template args.""" - # Import modules first so patches work - import hud.agents - import hud.eval.manager # noqa: F401 - - env = Environment("test") - - @env.scenario() - async def search(query: str, limit: int = 10): - yield {"task": f"Search {query}"} - - # Create template with some args pre-filled - task = env("search", limit=5) - tool = AgentTool(task, model="claude") - - with ( - patch("hud.eval.manager.run_eval") as mock_run_eval, - patch("hud.agents.create_agent") as mock_create_agent, - ): - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) - mock_run_eval.return_value = mock_ctx - - mock_agent = MagicMock() - mock_agent.run = AsyncMock(return_value=MagicMock(content="result")) - mock_create_agent.return_value = mock_agent - - # Call with additional arg - await tool(query="test query") - - # Check merged args - call_args = mock_run_eval.call_args - task_arg = call_args[0][0] - assert task_arg.args["query"] == "test query" - assert task_arg.args["limit"] == 5 # From template diff --git 
a/hud/tools/tests/test_coding_apply_patch.py b/hud/tools/tests/test_coding_apply_patch.py index 1008c831d..e959dd5cc 100644 --- a/hud/tools/tests/test_coding_apply_patch.py +++ b/hud/tools/tests/test_coding_apply_patch.py @@ -1,4 +1,4 @@ -"""Tests for apply_patch compatibility tool and patch parser helpers.""" +"""Tests for the legacy apply_patch compatibility wrapper.""" from __future__ import annotations @@ -8,15 +8,6 @@ import pytest from mcp.types import TextContent -from hud.agents.openai.tools.apply_patch import ( - ActionType, - DiffError, - Parser, - _apply_commit, - _identify_files_needed, - _patch_to_commit, - _text_to_patch, -) from hud.tools._legacy import ApplyPatchTool from hud.tools.coding import EditTool @@ -42,56 +33,3 @@ async def test_update_file_uses_edit_tool_behavior(self): assert file_path.read_text() == "new\n" assert isinstance(result[0], TextContent) assert "written successfully" in result[0].text - - -class TestPatchParser: - """Focused tests for shared V4A parser helpers used by EditTool.""" - - def test_parse_add_file(self): - lines = [ - "*** Begin Patch", - "*** Add File: new.txt", - "+line 1", - "+line 2", - "*** End Patch", - ] - parser = Parser(current_files={}, lines=lines, index=1) - parser.parse() - - action = parser.patch.actions["new.txt"] - assert action.type == ActionType.ADD - assert action.new_file == "line 1\nline 2" - - def test_parse_update_file(self): - text = "*** Begin Patch\n*** Update File: test.txt\n@@\n-old\n+new\n*** End Patch" - - patch, fuzz = _text_to_patch(text, {"test.txt": "old\n"}) - - assert fuzz == 0 - action = patch.actions["test.txt"] - assert action.type == ActionType.UPDATE - - def test_identify_files_needed(self): - text = "*** Begin Patch\n*** Update File: a.txt\n@@\n-old\n+new\n*** End Patch" - assert _identify_files_needed(text) == ["a.txt"] - - def test_apply_commit_update(self): - patch, _ = _text_to_patch( - "*** Begin Patch\n*** Update File: a.txt\n@@\n-old\n+new\n*** End Patch", - {"a.txt": 
"old\n"}, - ) - commit = _patch_to_commit(patch, {"a.txt": "old\n"}) - files = {"a.txt": "old\n"} - - def write(path: str, content: str | None) -> None: - files[path] = content or "" - - def remove(path: str) -> None: - del files[path] - - _apply_commit(commit, write, remove) - assert files["a.txt"] == "new\n" - - def test_invalid_patch_raises(self): - with pytest.raises(DiffError): - _text_to_patch("not a patch", {}) diff --git a/hud/tools/tests/test_computer.py b/hud/tools/tests/test_computer.py index b4d4c1c7c..4e2fce3d3 100644 --- a/hud/tools/tests/test_computer.py +++ b/hud/tools/tests/test_computer.py @@ -175,6 +175,7 @@ def test_glm_computer_is_legacy_generic_registration(): assert comp.name == "glm_computer" assert "native_tools" not in comp.meta + assert comp.meta["coordinate_space"] == 999 @pytest.mark.asyncio diff --git a/hud/types.py b/hud/types.py index 7dd2e07ab..ae2d18b52 100644 --- a/hud/types.py +++ b/hud/types.py @@ -3,12 +3,29 @@ import json import uuid from enum import Enum -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import mcp.types as types from mcp.types import CallToolRequestParams, CallToolResult from pydantic import BaseModel, ConfigDict, Field +if TYPE_CHECKING: + from hud.agents.claude import ClaudeAgent + from hud.agents.gemini import GeminiAgent + from hud.agents.openai import OpenAIAgent + from hud.agents.openai_compatible import OpenAIChatAgent + from hud.agents.types import ClaudeConfig, GeminiConfig, OpenAIChatConfig, OpenAIConfig + + AgentClass: TypeAlias = type[ClaudeAgent | GeminiAgent | OpenAIAgent | OpenAIChatAgent] + AgentConfigClass: TypeAlias = type[ + ClaudeConfig | GeminiConfig | OpenAIConfig | OpenAIChatConfig + ] + _AgentTypeInfo: TypeAlias = tuple[AgentClass, AgentConfigClass, str] + +# JSON-compatible scalar/container values. 
+JsonValue: TypeAlias = str | int | float | bool | None | list["JsonValue"] | dict[str, "JsonValue"] +JsonObject: TypeAlias = dict[str, JsonValue] + class AgentType(str, Enum): CLAUDE = "claude" @@ -17,29 +34,25 @@ class AgentType(str, Enum): OPENAI_COMPATIBLE = "openai_compatible" @property - def cls(self) -> type: - if self == AgentType.CLAUDE: - from hud.agents.claude import ClaudeAgent - - return ClaudeAgent - elif self == AgentType.OPENAI: - from hud.agents import OpenAIAgent + def cls(self) -> AgentClass: + return self._info[0] - return OpenAIAgent - elif self == AgentType.GEMINI: - from hud.agents.gemini import GeminiAgent - - return GeminiAgent - elif self == AgentType.OPENAI_COMPATIBLE: - from hud.agents.openai_compatible import OpenAIChatAgent + @property + def config_cls(self) -> AgentConfigClass: + """Get config class without importing agent (avoids SDK dependency).""" + return self._info[1] - return OpenAIChatAgent - else: - raise ValueError(f"Unsupported agent type: {self}") + @property + def gateway_provider(self) -> str: + """Default provider client used when this agent type is a gateway shortcut.""" + return self._info[2] @property - def config_cls(self) -> type: - """Get config class without importing agent (avoids SDK dependency).""" + def _info(self) -> _AgentTypeInfo: + from hud.agents import OpenAIAgent + from hud.agents.claude import ClaudeAgent + from hud.agents.gemini import GeminiAgent + from hud.agents.openai_compatible import OpenAIChatAgent from hud.agents.types import ( ClaudeConfig, GeminiConfig, @@ -47,24 +60,15 @@ def config_cls(self) -> type: OpenAIConfig, ) - mapping: dict[AgentType, type] = { - AgentType.CLAUDE: ClaudeConfig, - AgentType.OPENAI: OpenAIConfig, - AgentType.GEMINI: GeminiConfig, - AgentType.OPENAI_COMPATIBLE: OpenAIChatConfig, - } - if self not in mapping: - raise ValueError(f"Unsupported agent type for config: {self}") - return mapping[self] - - -class BaseAgentConfig(BaseModel): - """Agent configuration for 
LLM-specific settings.""" - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True) - - system_prompt: str | None = None - hosted_tools: list[Any] = Field(default_factory=list) + match self: + case AgentType.CLAUDE: + return ClaudeAgent, ClaudeConfig, "anthropic" + case AgentType.OPENAI: + return OpenAIAgent, OpenAIConfig, "openai" + case AgentType.GEMINI: + return GeminiAgent, GeminiConfig, "gemini" + case AgentType.OPENAI_COMPATIBLE: + return OpenAIChatAgent, OpenAIChatConfig, "openai" class MCPToolCall(CallToolRequestParams): @@ -72,6 +76,7 @@ class MCPToolCall(CallToolRequestParams): id: str = Field(default_factory=lambda: str(uuid.uuid4())) # Unique identifier for reference annotation: str | None = None # Optional explanation of why this action is taken + provider_name: str | None = None # Original provider tool name when it differs from MCP name def __str__(self) -> str: """Format tool call as plain text.""" @@ -149,8 +154,8 @@ def __rich__(self) -> str: return hud_console.format_tool_result(content_summary, self.isError) -class InferenceResult(BaseModel): - """Result of a single LLM inference call. +class AgentResponse(BaseModel): + """Result of a single agent inference call. Returned by provider agents' ``get_response()`` methods. Carries the model's text output, any tool calls it wants to make, and provider- @@ -171,7 +176,7 @@ class InferenceResult(BaseModel): # --- RESPONSE METADATA --- # Populated by provider agents when citations are available. - # Uses dict form of Citation (provider-normalized) so InferenceResult + # Uses dict form of Citation (provider-normalized) so AgentResponse # doesn't depend on hud.tools.types at import time. 
citations: list[dict[str, Any]] = Field(default_factory=list) @@ -194,10 +199,6 @@ def __str__(self) -> str: return response -# Backwards-compatible alias (deprecated — use InferenceResult) -AgentResponse = InferenceResult - - class TraceStep(BaseModel): """Canonical data for a single span (shared with telemetry).""" @@ -262,7 +263,7 @@ class Trace(BaseModel): content: str | None = Field(default=None) isError: bool = Field(default=False) - # Response metadata carried from the final InferenceResult + # Response metadata carried from the final AgentResponse citations: list[dict[str, Any]] = Field(default_factory=list) # Metadata @@ -296,7 +297,8 @@ def append(self, step: TraceStep) -> None: "AgentResponse", "AgentType", "HudSpan", - "InferenceResult", + "JsonObject", + "JsonValue", "MCPToolCall", "MCPToolResult", "Task", diff --git a/hud/utils/hud_console.py b/hud/utils/hud_console.py index 041ea6753..17f526aec 100644 --- a/hud/utils/hud_console.py +++ b/hud/utils/hud_console.py @@ -621,81 +621,6 @@ def note(self, message: str, stderr: bool = True) -> None: """Print an important note with asterism symbol.""" self.symbol(Symbols.ITEM, message, GOLD, stderr) - # ------------------------------------------------------------------ - # Agent-facing display methods - # ------------------------------------------------------------------ - - def format_tool_discovery( - self, - tools: list[Any], - skipped: list[tuple[Any, str]] | None = None, - stderr: bool = True, - ) -> None: - """Display a table of discovered tools on agent initialization. 
- - Args: - tools: All available MCP tools - skipped: List of (tool, reason) for skipped tools - stderr: Output to stderr (default True) - """ - console = self._stderr_console if stderr else self._stdout_console - - table = Table( - show_header=True, - box=None, - padding=(0, 1), - title=f"[{GOLD}]Discovered {len(tools)} tools[/{GOLD}]", - title_style="", - ) - table.add_column("Tool", style=TEXT, no_wrap=True) - table.add_column("Available", style=DIM) - - for tool in tools: - name = tool.name if hasattr(tool, "name") else str(tool) - table.add_row(name, f"[{GREEN}]yes[/{GREEN}]") - - console.print(table) - - if skipped: - for tool, reason in skipped: - name = tool.name if hasattr(tool, "name") else str(tool) - console.print(f" [{DIM}]⊘ {escape(name)}: {escape(reason)}[/{DIM}]") - - def format_step( - self, - step: int, - max_steps: int, - tool_calls: list[Any], - tool_results: list[Any], - elapsed: float | None = None, - stderr: bool = True, - ) -> None: - """Display a compact step summary after tool execution. 
- - Args: - step: Current step number - max_steps: Maximum steps (-1 for unlimited) - tool_calls: List of MCPToolCall objects - tool_results: List of MCPToolResult objects - elapsed: Step duration in seconds - stderr: Output to stderr (default True) - """ - console = self._stderr_console if stderr else self._stdout_console - - step_label = f"Step {step}" - if max_steps != -1: - step_label += f"/{max_steps}" - if elapsed is not None: - step_label += f" [{elapsed:.1f}s]" - - console.print(f"\n[bold {GOLD}]{step_label}[/bold {GOLD}]") - - for call, result in zip(tool_calls, tool_results, strict=False): - call_str = str(call) if hasattr(call, "__rich__") else repr(call) - result_str = str(result) if hasattr(result, "__rich__") else repr(result) - console.print(f" {call_str}") - console.print(f" {result_str}") - # Global design instance for convenience class _ProgressContext: diff --git a/pyproject.toml b/pyproject.toml index 17ce20c36..dd20a3664 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -212,6 +212,7 @@ exclude = [ ] pythonVersion = "3.11" typeCheckingMode = "basic" +strict = ["hud/agents"] reportMissingImports = "warning" [tool.coverage.run] @@ -248,4 +249,4 @@ testpaths = ["hud", "examples"] addopts = "" markers = [ "integration: marks tests as integration tests (require HUD_API_KEY, network access)", -] +] \ No newline at end of file