From 434f8af9c4f23c3f195078516b8c02a500627195 Mon Sep 17 00:00:00 2001 From: Achuth Narayan Rajagopal Date: Thu, 21 May 2026 21:49:48 +0000 Subject: [PATCH] feat(telemetry): Fix error message telemetry for tool calls The tool error details are not populated correctly in traces for some some tool call categories like REST Tool or MCP Tool. --- src/google/adk/flows/llm_flows/functions.py | 58 +++- src/google/adk/telemetry/_instrumentation.py | 2 + src/google/adk/telemetry/tracing.py | 7 + src/google/adk/tools/bash_tool.py | 6 + .../adk/tools/discovery_engine_search_tool.py | 6 + src/google/adk/tools/environment/_tools.py | 24 ++ src/google/adk/tools/function_tool.py | 6 + src/google/adk/tools/google_tool.py | 6 + src/google/adk/tools/mcp_tool/mcp_tool.py | 6 + .../openapi_spec_parser/rest_api_tool.py | 6 + src/google/adk/tools/skill_toolset.py | 22 ++ .../flows/llm_flows/test_functions_simple.py | 265 ++++++++++++++++++ tests/unittests/telemetry/test_spans.py | 242 ++++++++++++++++ 13 files changed, 654 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 682a53fc94..259d40b6b6 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -68,6 +68,48 @@ _TOOL_THREAD_POOL_LOCK = threading.Lock() +def _detect_error_type_for_telemetry( + tool: BaseTool, + tool_context: ToolContext, + function_response: Any, +) -> Optional[str]: + """Detects an error type from a tool response for telemetry purposes. + + This does not modify the response. `_detect_error_in_response` is an optional + per-tool hook accessed via `getattr` to avoid adding a public API on + `BaseTool`. Any exception raised by the detector is logged and swallowed so + that telemetry logic never breaks tool execution. + + Args: + tool: The tool whose response is being inspected. + tool_context: The tool context for the current invocation. Detection is + skipped when the tool is requesting auth or confirmation. + function_response: The raw response returned by the tool. + + Returns: + The error type string reported by the tool's `_detect_error_in_response` + hook, or `None` if no error was detected, no hook is defined, or the hook + raised an exception. + """ + try: + if ( + tool_context.actions.requested_auth_configs + or tool_context.actions.requested_tool_confirmations + ): + return None + detector = getattr(tool, '_detect_error_in_response', None) + if detector is None: + return None + return detector(function_response) + except Exception: # pylint: disable=broad-exception-caught + # Never let telemetry logic break tool execution. + logger.exception( + 'Error while detecting error type for telemetry from tool %r.', + getattr(tool, 'name', tool), + ) + return None + + def _is_live_request_queue_annotation(param: inspect.Parameter) -> bool: """Check whether a parameter is annotated as LiveRequestQueue. @@ -482,6 +524,7 @@ async def _run_on_tool_error_callbacks( function_args = ( copy.deepcopy(function_call.args) if function_call.args else {} ) + detected_error_type: Optional[str] = None tool_context = _create_tool_context( invocation_context, function_call, tool_confirmation @@ -505,7 +548,7 @@ async def _run_on_tool_error_callbacks( raise tool_error async def _run_with_trace(): - nonlocal function_args + nonlocal function_args, detected_error_type # Step 1: Check if plugin before_tool_callback overrides the function # response. @@ -586,6 +629,10 @@ async def _run_with_trace(): # the tool returned nothing. return None + detected_error_type = _detect_error_type_for_telemetry( + tool, tool_context, function_response + ) + # Note: State deltas are not applied here - they are collected in # tool_context.actions.state_delta and applied later when the session # service processes the events @@ -600,6 +647,7 @@ async def _run_with_trace(): tool, agent, function_args ) as tel_ctx: tel_ctx.function_response_event = await _run_with_trace() + tel_ctx.error_type = detected_error_type return tel_ctx.function_response_event @@ -718,6 +766,7 @@ async def _run_on_tool_error_callbacks( function_args = ( copy.deepcopy(function_call.args) if function_call.args else {} ) + detected_error_type: Optional[str] = None tool_context = _create_tool_context(invocation_context, function_call) @@ -738,7 +787,7 @@ async def _run_on_tool_error_callbacks( raise tool_error async def _run_with_trace(): - nonlocal function_args + nonlocal function_args, detected_error_type # Do not use "args" as the variable name, because it is a reserved keyword # in python debugger. @@ -827,6 +876,10 @@ async def _run_with_trace(): # build when the tool returned nothing. return None + detected_error_type = _detect_error_type_for_telemetry( + tool, tool_context, function_response + ) + # Note: State deltas are not applied here - they are collected in # tool_context.actions.state_delta and applied later when the session # service processes the events @@ -841,6 +894,7 @@ async def _run_with_trace(): tool, agent, function_args ) as tel_ctx: tel_ctx.function_response_event = await _run_with_trace() + tel_ctx.error_type = detected_error_type return tel_ctx.function_response_event diff --git a/src/google/adk/telemetry/_instrumentation.py b/src/google/adk/telemetry/_instrumentation.py index 975b553c7a..1a3e6a8c38 100644 --- a/src/google/adk/telemetry/_instrumentation.py +++ b/src/google/adk/telemetry/_instrumentation.py @@ -68,6 +68,7 @@ class TelemetryContext: otel_context: context_api.Context function_response_event: event_lib.Event | None = None + error_type: str | None = None def _record_agent_metrics( @@ -148,6 +149,7 @@ async def record_tool_execution( args=function_args, function_response_event=response_event, error=caught_error, + error_type=tel_ctx.error_type, ) finally: try: diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 7b813be16d..5b0e6b477d 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -171,6 +171,7 @@ def trace_tool_call( function_response_event: Event | None, error: Exception | None = None, span: Span | None = None, + error_type: str | None = None, ): """Traces tool call. @@ -180,6 +181,10 @@ def trace_tool_call( function_response_event: The event with the function response details. error: The exception raised during tool execution, if any. span: The span to record attributes on. If None, uses current span. + error_type: An error type string detected from the tool's response dict + (e.g., "HTTP_ERROR", "MCP_TOOL_ERROR"). Used when the tool returned an + error as a dict rather than raising an exception. Ignored if `error` is + also set (exception takes precedence). """ span = span or trace.get_current_span() @@ -196,6 +201,8 @@ def trace_tool_call( span.set_attribute(ERROR_TYPE, str(error.error_type)) else: span.set_attribute(ERROR_TYPE, type(error).__name__) + elif error_type is not None: + span.set_attribute(ERROR_TYPE, error_type) # Special case for client side association with a remote tool call if ( diff --git a/src/google/adk/tools/bash_tool.py b/src/google/adk/tools/bash_tool.py index 89de3bf34f..fa1925205d 100644 --- a/src/google/adk/tools/bash_tool.py +++ b/src/google/adk/tools/bash_tool.py @@ -247,3 +247,9 @@ async def run_async( "stdout": stdout_res, "stderr": stderr_res, } + + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("error"): + return "TOOL_ERROR" + return None diff --git a/src/google/adk/tools/discovery_engine_search_tool.py b/src/google/adk/tools/discovery_engine_search_tool.py index 54603e6a83..eea843c35f 100644 --- a/src/google/adk/tools/discovery_engine_search_tool.py +++ b/src/google/adk/tools/discovery_engine_search_tool.py @@ -220,6 +220,12 @@ def discovery_engine_search( except GoogleAPICallError as e: return {'status': 'error', 'error_message': str(e)} + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get('status') == 'error': + return 'TOOL_ERROR' + return None + def _do_search( self, query: str, diff --git a/src/google/adk/tools/environment/_tools.py b/src/google/adk/tools/environment/_tools.py index 67baa8f40d..f62e8e2729 100644 --- a/src/google/adk/tools/environment/_tools.py +++ b/src/google/adk/tools/environment/_tools.py @@ -136,6 +136,12 @@ async def run_async( result['error'] = f'Command timed out after {DEFAULT_TIMEOUT}s.' return result + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get('status') == 'error': + return 'TOOL_ERROR' + return None + @experimental class ReadFileTool(BaseTool): @@ -260,6 +266,12 @@ async def run_async( except Exception as e: return {'status': 'error', 'error': str(e)} + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get('status') == 'error': + return 'TOOL_ERROR' + return None + @experimental class WriteFileTool(BaseTool): @@ -311,6 +323,12 @@ async def run_async( return {'status': 'error', 'error': str(e)} return {'status': 'ok', 'message': f'Wrote {path}'} + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get('status') == 'error': + return 'TOOL_ERROR' + return None + @experimental class EditFileTool(BaseTool): @@ -402,3 +420,9 @@ async def run_async( new_content = content.replace(old_string, new_string, 1) await self._environment.write_file(path, new_content) return {'status': 'ok', 'message': f'Edited {path}'} + + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get('status') == 'error': + return 'TOOL_ERROR' + return None diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index c5db706d8b..6477fd6894 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -255,6 +255,12 @@ async def run_async( return await self._invoke_callable(self.func, args_to_call) + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get('error'): + return 'TOOL_ERROR' + return None + async def _invoke_callable( self, target: Callable[..., Any], args_to_call: dict[str, Any] ) -> Any: diff --git a/src/google/adk/tools/google_tool.py b/src/google/adk/tools/google_tool.py index c6ab67bb9c..f4294b76cf 100644 --- a/src/google/adk/tools/google_tool.py +++ b/src/google/adk/tools/google_tool.py @@ -106,6 +106,12 @@ async def run_async( "error_details": str(ex), } + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("status") == "ERROR": + return "TOOL_ERROR" + return None + async def _run_async_with_credential( self, credentials: Credentials, diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 6a24651f92..4acc4ff847 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -461,6 +461,12 @@ async def _run_async_impl( ) return result + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("isError"): + return "MCP_TOOL_ERROR" + return None + def _resolve_progress_callback( self, tool_context: ToolContext ) -> Optional[ProgressFnT]: diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index fa32ce932a..9516536a94 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -554,6 +554,12 @@ async def call( self._logger.debug("API Response (non-JSON): %s", response.text) return {"text": response.text} # Return text if not JSON + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("error"): + return "HTTP_ERROR" + return None + def __str__(self): return ( f'RestApiTool(name="{self.name}", description="{self.description}",' diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index ef579d8256..7d8e06368e 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -24,6 +24,7 @@ import logging import mimetypes from typing import Any +from typing import Optional from typing import TYPE_CHECKING import warnings @@ -249,6 +250,13 @@ async def run_async( "frontmatter": skill.frontmatter.model_dump(), } + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("error"): + error_code = response.get("error_code") + return error_code if error_code else "TOOL_ERROR" + return None + @experimental(FeatureName.SKILL_TOOLSET) class LoadSkillResourceTool(BaseTool): @@ -361,6 +369,13 @@ async def run_async( "content": content, } + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("error"): + error_code = response.get("error_code") + return error_code if error_code else "TOOL_ERROR" + return None + @override async def process_llm_request( self, *, tool_context: ToolContext, llm_request: Any @@ -873,6 +888,13 @@ async def run_async( positional_args, # pylint: disable=protected-access ) + def _detect_error_in_response(self, response: Any) -> Optional[str]: + """Telemetry hook: returns an error type if the response indicates an error.""" + if isinstance(response, dict) and response.get("error"): + error_code = response.get("error_code") + return error_code if error_code else "TOOL_ERROR" + return None + @experimental(FeatureName.SKILL_TOOLSET) class SkillToolset(BaseToolset): diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index be638de44c..71834d2c01 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -15,8 +15,11 @@ import asyncio from typing import Any from typing import Callable +from unittest import mock +from fastapi.openapi.models import HTTPBearer from google.adk.agents.llm_agent import Agent +from google.adk.auth.auth_tool import AuthConfig from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.events.ui_widget import UiWidget @@ -24,8 +27,10 @@ from google.adk.flows.llm_flows.functions import handle_function_calls_async from google.adk.flows.llm_flows.functions import handle_function_calls_live from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events +from google.adk.tools.base_tool import BaseTool from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_confirmation import ToolConfirmation from google.adk.tools.tool_context import ToolContext from google.genai import types import pytest @@ -1319,3 +1324,263 @@ def simple_fn() -> dict[str, str]: assert result_parallel is not None assert result_parallel.live_session_id == 'test-live-session-id-parallel' + + +class _MockControlSignalTool(BaseTool): + """A tool that simulates requesting confirmation or OAuth authentication.""" + + def __init__(self, name: str, behavior: str): + super().__init__(name=name, description='Simulated control tool') + self.behavior = behavior + + async def run_async(self, *, args, tool_context): + if self.behavior == 'confirm': + tool_context.actions.requested_tool_confirmations = { + 'fc_test_confirm': ToolConfirmation(hint='Authorize execution?') + } + return {'error': 'This tool requires user approval.'} + elif self.behavior == 'auth': + tool_context.actions.requested_auth_configs = { + 'fc_test_auth': AuthConfig(auth_scheme=HTTPBearer()) + } + return {'error': 'Please complete OAuth setup.'} + + def _detect_error_in_response(self, response: Any) -> str | None: + if isinstance(response, dict) and 'error' in response: + return 'TOOL_ERROR' + return None + + +class _ErrorDetectingTool(BaseTool): + """A test tool whose _detect_error_in_response raises an exception.""" + + async def run_async(self, *, args, tool_context): + return {'result': 'result'} + + def _detect_error_in_response(self, response: Any) -> str | None: + raise RuntimeError('detection exploded') + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +@pytest.mark.parametrize( + 'mock_response,expected_error_type', + [ + ({'error': 'Internal component timeout'}, 'TOOL_ERROR'), + ({'result': 'Execution succeeded'}, None), + ], + ids=['dict_error_recorded', 'success_dict_ignored'], +) +async def test_e2e_telemetry_error_classification( + monkeypatch, handle_function_calls, mock_response, expected_error_type +): + """E2E: asserts that tool outputs successfully translate to targeted OTel span error attributes.""" + recorded_calls = [] + + # Intercept trace_tool_call to capture final telemetry state + monkeypatch.setattr( + 'google.adk.telemetry._instrumentation.tracing.trace_tool_call', + lambda **kw: recorded_calls.append(kw), + ) + + tool = FunctionTool(func=lambda: mock_response) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent(name='test_agent', model=model, tools=[tool]) + + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + function_call = types.FunctionCall(name=tool.name, args={}, id='fc_test') + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=types.Content(parts=[types.Part(function_call=function_call)]), + ) + + await handle_function_calls(invocation_context, event, {tool.name: tool}) + + assert len(recorded_calls) == 1 + assert recorded_calls[0]['error_type'] == expected_error_type + assert recorded_calls[0]['error'] is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_exception_takes_precedence_over_dict_error( + monkeypatch, handle_function_calls +): + """End-to-end integration: exception takes strict precedence over manual dict error_type.""" + recorded_calls = [] + monkeypatch.setattr( + 'google.adk.telemetry._instrumentation.tracing.trace_tool_call', + lambda **kw: recorded_calls.append(kw), + ) + + def mock_crashing_func(): + raise ValueError('Fatal arithmetic error') + + tool = FunctionTool(func=mock_crashing_func) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent(name='test_agent', model=model, tools=[tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + function_call = types.FunctionCall( + name=tool.name, args={}, id='fc_test_exception' + ) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=types.Content(parts=[types.Part(function_call=function_call)]), + ) + + with pytest.raises(ValueError, match='Fatal arithmetic error'): + await handle_function_calls(invocation_context, event, {tool.name: tool}) + + assert len(recorded_calls) == 1 + assert isinstance(recorded_calls[0]['error'], ValueError) + assert recorded_calls[0]['error_type'] is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_detection_skipped_when_confirmation_requested( + monkeypatch, handle_function_calls +): + """E2E confirmation verification: control prompt avoids polluting telemetry with TOOL_ERROR.""" + recorded_calls = [] + monkeypatch.setattr( + 'google.adk.telemetry._instrumentation.tracing.trace_tool_call', + lambda **kw: recorded_calls.append(kw), + ) + + tool = _MockControlSignalTool(name='confirm_tool', behavior='confirm') + model = testing_utils.MockModel.create(responses=[]) + agent = Agent(name='test_agent', model=model, tools=[tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + function_call = types.FunctionCall( + name=tool.name, args={}, id='fc_test_confirm' + ) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=types.Content(parts=[types.Part(function_call=function_call)]), + ) + + await handle_function_calls(invocation_context, event, {tool.name: tool}) + + assert len(recorded_calls) == 1 + assert recorded_calls[0]['error_type'] is None + assert recorded_calls[0]['error'] is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_detection_skipped_when_auth_requested( + monkeypatch, handle_function_calls +): + """E2E OAuth verification: authenticate control prompt avoids polluting telemetry with TOOL_ERROR.""" + recorded_calls = [] + monkeypatch.setattr( + 'google.adk.telemetry._instrumentation.tracing.trace_tool_call', + lambda **kw: recorded_calls.append(kw), + ) + + tool = _MockControlSignalTool(name='auth_tool', behavior='auth') + model = testing_utils.MockModel.create(responses=[]) + agent = Agent(name='test_agent', model=model, tools=[tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + function_call = types.FunctionCall(name=tool.name, args={}, id='fc_test_auth') + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=types.Content(parts=[types.Part(function_call=function_call)]), + ) + + await handle_function_calls(invocation_context, event, {tool.name: tool}) + + assert len(recorded_calls) == 1 + assert recorded_calls[0]['error_type'] is None + assert recorded_calls[0]['error'] is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_detection_exception_does_not_break_tool_call( + monkeypatch, handle_function_calls +): + """Safety Verification: telemetry errors during error parsing are caught cleanly, not crashing tool calls.""" + recorded_calls = [] + monkeypatch.setattr( + 'google.adk.telemetry._instrumentation.tracing.trace_tool_call', + lambda **kw: recorded_calls.append(kw), + ) + + tool = _ErrorDetectingTool( + name='buggy_telemetry_tool', description='raises on tel' + ) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent(name='test_agent', model=model, tools=[tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + function_call = types.FunctionCall( + name=tool.name, args={}, id='fc_test_buggy' + ) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=types.Content(parts=[types.Part(function_call=function_call)]), + ) + + result_event = await handle_function_calls( + invocation_context, event, {tool.name: tool} + ) + + assert result_event is not None + assert result_event.content.parts[0].function_response.response == { + 'result': 'result' + } + + assert len(recorded_calls) == 1 + assert recorded_calls[0]['error_type'] is None + assert recorded_calls[0]['error'] is None diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index c0e4cc20b9..b497a5f0b6 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -1397,3 +1397,245 @@ def test_safe_json_serialize_no_whitespaces_circular_dict_returns_not_serializab obj = {} obj['self'] = obj assert _safe_json_serialize_no_whitespaces(obj) == '' + + +# --------------------------------------------------------------------------- +# Tests for _detect_error_in_response +# --------------------------------------------------------------------------- + + +class _ErrorDetectingTool(BaseTool): + """A test tool whose _detect_error_in_response raises.""" + + async def run_async(self, *, args, tool_context): + return 'result' + + def _detect_error_in_response(self, response: Any) -> Optional[str]: + raise RuntimeError('detection exploded') + + +def test_base_tool_does_not_define_detect_error_in_response(): + """BaseTool intentionally does not expose _detect_error_in_response as a public hook.""" + tool = SimpleTestTool(name='t', description='d') + # The hook is opt-in per subclass; BaseTool itself must not declare it so + # that telemetry callers can use getattr(...) to skip detection. + assert not hasattr(tool, '_detect_error_in_response') + + +def test_detect_error_function_tool_error(): + from google.adk.tools.function_tool import FunctionTool + + tool = FunctionTool(func=lambda: None) + assert ( + tool._detect_error_in_response({'error': 'missing arg'}) == 'TOOL_ERROR' + ) + + +def test_detect_error_function_tool_no_error(): + from google.adk.tools.function_tool import FunctionTool + + tool = FunctionTool(func=lambda: None) + assert tool._detect_error_in_response({'result': 'ok'}) is None + assert tool._detect_error_in_response('plain string') is None + assert tool._detect_error_in_response(None) is None + + +def test_detect_error_rest_api_tool(): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool + + tool = RestApiTool.__new__(RestApiTool) + assert ( + tool._detect_error_in_response({'error': 'Status Code: 404'}) + == 'HTTP_ERROR' + ) + assert tool._detect_error_in_response({'result': 'ok'}) is None + assert tool._detect_error_in_response({'text': 'html response'}) is None + + +def test_detect_error_mcp_tool(): + from google.adk.tools.mcp_tool.mcp_tool import McpTool as AdkMcpTool + + tool = AdkMcpTool.__new__(AdkMcpTool) + assert ( + tool._detect_error_in_response({'isError': True, 'content': []}) + == 'MCP_TOOL_ERROR' + ) + assert ( + tool._detect_error_in_response({'isError': False, 'content': []}) is None + ) + assert tool._detect_error_in_response({'content': [{'text': 'ok'}]}) is None + + +def test_detect_error_google_tool(): + from google.adk.tools.google_tool import GoogleTool + + tool = GoogleTool.__new__(GoogleTool) + assert ( + tool._detect_error_in_response( + {'status': 'ERROR', 'error_details': 'fail'} + ) + == 'TOOL_ERROR' + ) + assert tool._detect_error_in_response({'status': 'OK', 'data': []}) is None + assert ( + tool._detect_error_in_response({'error': 'something'}) is None + ) # GoogleTool checks status, not error key + + +def test_detect_error_bash_tool(): + from google.adk.tools.bash_tool import ExecuteBashTool + + tool = ExecuteBashTool.__new__(ExecuteBashTool) + assert ( + tool._detect_error_in_response({'error': 'Execution failed'}) + == 'TOOL_ERROR' + ) + assert ( + tool._detect_error_in_response( + {'error': 'timeout', 'stdout': '', 'stderr': ''} + ) + == 'TOOL_ERROR' + ) + assert ( + tool._detect_error_in_response({'stdout': 'ok', 'returncode': 0}) is None + ) + + +def _environment_tool_classes(): + from google.adk.tools.environment._tools import EditFileTool + from google.adk.tools.environment._tools import ExecuteTool + from google.adk.tools.environment._tools import ReadFileTool + from google.adk.tools.environment._tools import WriteFileTool + + return [ExecuteTool, ReadFileTool, WriteFileTool, EditFileTool] + + +@pytest.mark.parametrize( + 'cls', + _environment_tool_classes(), + ids=lambda c: c.__name__, +) +@pytest.mark.parametrize( + 'response,expected', + [ + ({'status': 'error', 'error': 'fail'}, 'TOOL_ERROR'), + ({'status': 'ok', 'message': 'done'}, None), + # Environment tools check status, not the error key. + ({'error': 'something'}, None), + ], + ids=['status_error', 'status_ok', 'error_key_only'], +) +def test_detect_error_environment_tools(cls, response, expected): + tool = cls.__new__(cls) + assert tool._detect_error_in_response(response) == expected + + +@pytest.mark.parametrize( + 'cls_name', + ['LoadSkillTool', 'LoadSkillResourceTool', 'RunSkillScriptTool'], +) +@pytest.mark.parametrize( + 'response,expected', + [ + ( + {'error': 'missing', 'error_code': 'INVALID_ARGUMENTS'}, + 'INVALID_ARGUMENTS', + ), + ({'error': 'generic'}, 'TOOL_ERROR'), + ({'skill_name': 'x', 'instructions': 'y'}, None), + ], + ids=['with_error_code', 'error_no_code', 'no_error'], +) +def test_detect_error_skill_tools(cls_name, response, expected): + skill_toolset = pytest.importorskip('google.adk.tools.skill_toolset') + cls = getattr(skill_toolset, cls_name) + tool = cls.__new__(cls) + assert tool._detect_error_in_response(response) == expected + + +def test_detect_error_discovery_engine_search_tool(): + mod = pytest.importorskip('google.adk.tools.discovery_engine_search_tool') + DiscoveryEngineSearchTool = mod.DiscoveryEngineSearchTool + + tool = DiscoveryEngineSearchTool.__new__(DiscoveryEngineSearchTool) + assert ( + tool._detect_error_in_response( + {'status': 'error', 'error_message': 'fail'} + ) + == 'TOOL_ERROR' + ) + assert tool._detect_error_in_response({'status': 'ok', 'results': []}) is None + + +# --------------------------------------------------------------------------- +# Tests for trace_tool_call with error_type parameter +# --------------------------------------------------------------------------- + + +def test_trace_tool_call_with_error_type( + monkeypatch, mock_span_fixture, mock_tool_fixture +): + """error_type sets the span error.type attribute when no exception.""" + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + trace_tool_call( + tool=mock_tool_fixture, + args={'x': 1}, + function_response_event=None, + error=None, + error_type='HTTP_ERROR', + ) + + mock_span_fixture.set_attribute.assert_any_call('error.type', 'HTTP_ERROR') + + +def test_trace_tool_call_error_takes_precedence_over_error_type( + monkeypatch, mock_span_fixture, mock_tool_fixture +): + """When both error and error_type are provided, error takes precedence.""" + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + trace_tool_call( + tool=mock_tool_fixture, + args={'x': 1}, + function_response_event=None, + error=ValueError('boom'), + error_type='HTTP_ERROR', + ) + + # ValueError should be set, not HTTP_ERROR. + mock_span_fixture.set_attribute.assert_any_call('error.type', 'ValueError') + error_type_calls = [ + c + for c in mock_span_fixture.set_attribute.call_args_list + if c == mock.call('error.type', mock.ANY) + ] + assert len(error_type_calls) == 1 + + +def test_trace_tool_call_no_error_no_error_type( + monkeypatch, mock_span_fixture, mock_tool_fixture +): + """When neither error nor error_type is set, no error.type attribute.""" + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + trace_tool_call( + tool=mock_tool_fixture, + args={'x': 1}, + function_response_event=None, + error=None, + error_type=None, + ) + + error_type_calls = [ + c + for c in mock_span_fixture.set_attribute.call_args_list + if c == mock.call('error.type', mock.ANY) + ] + assert len(error_type_calls) == 0