diff --git a/pyproject.toml b/pyproject.toml index 6c47d4d5..b986bed7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ "pydantic>=2.11.3", "openai>=1.3.0", - "mcp>=1.10.1", + "mcp<1.23.4,>=1.10.1", "aiohttp", "httpx>=0.27.0", "httpx-sse>=0.4.0", diff --git a/tests/tools/mcp_tool/test_mcp_toolset.py b/tests/tools/mcp_tool/test_mcp_toolset.py index 2a864ec4..a8257017 100644 --- a/tests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/tools/mcp_tool/test_mcp_toolset.py @@ -4,9 +4,11 @@ # # tRPC-Agent-Python is licensed under Apache-2.0. +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest +from mcp import types as mcp_types from mcp import StdioServerParameters as McpStdioServerParameters from mcp.types import ListToolsResult, Tool as McpBaseTool @@ -26,6 +28,13 @@ def _stdio_conn(): ) +def _server_capabilities(list_changed: bool | None = None): + tools_capability = None + if list_changed is not None: + tools_capability = mcp_types.ToolsCapability(listChanged=list_changed) + return mcp_types.ServerCapabilities(tools=tools_capability) + + # --------------------------------------------------------------------------- # Tests: __init__ # --------------------------------------------------------------------------- @@ -70,6 +79,15 @@ def test_session_group_params_custom(self): ts = MCPToolset(connection_params=_stdio_conn(), session_group_params={"key": "val"}) assert ts._session_group_params == {"key": "val"} + def test_tools_cache_enabled_by_default(self): + ts = MCPToolset(connection_params=_stdio_conn()) + assert ts._cache_tools is True + assert ts._tools_cache_ttl == 60.0 + + def test_rejects_negative_tools_cache_ttl(self): + with pytest.raises(ValueError, match="tools_cache_ttl must be non-negative"): + MCPToolset(connection_params=_stdio_conn(), tools_cache_ttl=-1) + # --------------------------------------------------------------------------- # Tests: _checker_required_params @@ -294,6 +312,185 @@ async def test_get_tools_with_custom_mcp_tool_cls(self): custom_cls.assert_called_once() assert len(tools) == 1 + @pytest.mark.asyncio + async def test_get_tools_reuses_cached_list_tools_response(self): + ts = MCPToolset(connection_params=_stdio_conn()) + + mock_mgr = MagicMock(spec=MCPSessionManager) + mock_session = AsyncMock() + mock_mgr.create_session = AsyncMock(return_value=mock_session) + + mcp_tools = [ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ] + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=mcp_tools)) + + with patch.object(ts, "initialize"): + ts._mcp_session_manager = mock_mgr + first = await ts.get_tools() + second = await ts.get_tools() + + assert [tool.name for tool in first] == ["tool_a"] + assert [tool.name for tool in second] == ["tool_a"] + mock_session.list_tools.assert_awaited_once() + + @pytest.mark.asyncio + async def test_get_tools_can_disable_tools_cache(self): + ts = MCPToolset(connection_params=_stdio_conn(), cache_tools=False) + + mock_mgr = MagicMock(spec=MCPSessionManager) + mock_session = AsyncMock() + mock_mgr.create_session = AsyncMock(return_value=mock_session) + + mock_session.list_tools = AsyncMock( + return_value=ListToolsResult( + tools=[ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ] + )) + + with patch.object(ts, "initialize"): + ts._mcp_session_manager = mock_mgr + await ts.get_tools() + await ts.get_tools() + + assert mock_session.list_tools.await_count == 2 + + @pytest.mark.asyncio + async def test_clear_tools_cache_forces_refresh(self): + ts = MCPToolset(connection_params=_stdio_conn()) + + mock_mgr = MagicMock(spec=MCPSessionManager) + mock_session = AsyncMock() + mock_mgr.create_session = AsyncMock(return_value=mock_session) + + mock_session.list_tools = AsyncMock( + side_effect=[ + ListToolsResult( + tools=[ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ]), + ListToolsResult( + tools=[ + McpBaseTool(name="tool_b", description="desc_b", inputSchema={"type": "object"}), + ]), + ]) + + with patch.object(ts, "initialize"): + ts._mcp_session_manager = mock_mgr + first = await ts.get_tools() + ts.clear_tools_cache() + second = await ts.get_tools() + + assert [tool.name for tool in first] == ["tool_a"] + assert [tool.name for tool in second] == ["tool_b"] + assert mock_session.list_tools.await_count == 2 + + @pytest.mark.asyncio + async def test_tools_cache_ttl_expires(self): + ts = MCPToolset(connection_params=_stdio_conn(), tools_cache_ttl=1) + + mock_mgr = MagicMock(spec=MCPSessionManager) + mock_session = AsyncMock() + mock_mgr.create_session = AsyncMock(return_value=mock_session) + + mock_session.list_tools = AsyncMock( + side_effect=[ + ListToolsResult( + tools=[ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ]), + ListToolsResult( + tools=[ + McpBaseTool(name="tool_b", description="desc_b", inputSchema={"type": "object"}), + ]), + ]) + + with patch.object(ts, "initialize"), patch( + "trpc_agent_sdk.tools.mcp_tool._mcp_toolset.time.monotonic", + side_effect=[100.0, 100.5, 101.1, 101.1, 101.1], + ): + ts._mcp_session_manager = mock_mgr + first = await ts.get_tools() + cached = await ts.get_tools() + refreshed = await ts.get_tools() + + assert [tool.name for tool in first] == ["tool_a"] + assert [tool.name for tool in cached] == ["tool_a"] + assert [tool.name for tool in refreshed] == ["tool_b"] + assert mock_session.list_tools.await_count == 2 + + @pytest.mark.asyncio + async def test_list_changed_capability_uses_notification_driven_cache(self): + ts = MCPToolset(connection_params=_stdio_conn(), tools_cache_ttl=1) + + mock_mgr = MagicMock(spec=MCPSessionManager) + mock_session = AsyncMock() + mock_session.get_server_capabilities = MagicMock(return_value=_server_capabilities(list_changed=True)) + mock_mgr.create_session = AsyncMock(return_value=mock_session) + + mock_session.list_tools = AsyncMock( + return_value=ListToolsResult( + tools=[ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ] + )) + + with patch.object(ts, "initialize"), patch( + "trpc_agent_sdk.tools.mcp_tool._mcp_toolset.time.monotonic", + return_value=100.0, + ): + ts._mcp_session_manager = mock_mgr + first = await ts.get_tools() + second = await ts.get_tools() + + assert [tool.name for tool in first] == ["tool_a"] + assert [tool.name for tool in second] == ["tool_a"] + mock_session.list_tools.assert_awaited_once() + + @pytest.mark.asyncio + async def test_tool_list_changed_notification_clears_cache_and_chains_handler(self): + user_message_handler = AsyncMock() + ts = MCPToolset( + connection_params=_stdio_conn(), + session_group_params={"message_handler": user_message_handler}, + ) + ts._tools_cache = ListToolsResult( + tools=[ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ]) + ts._tools_cache_updated_at = 100.0 + + params = ts._build_session_group_params() + notification = mcp_types.ServerNotification(mcp_types.ToolListChangedNotification()) + await params["message_handler"](notification) + + assert ts._tools_cache is None + assert ts._tools_cache_updated_at is None + user_message_handler.assert_awaited_once_with(notification) + + @pytest.mark.asyncio + async def test_concurrent_get_tools_shares_cache_fill(self): + ts = MCPToolset(connection_params=_stdio_conn()) + + mock_mgr = MagicMock(spec=MCPSessionManager) + mock_session = AsyncMock() + mock_mgr.create_session = AsyncMock(return_value=mock_session) + mock_session.list_tools = AsyncMock( + return_value=ListToolsResult( + tools=[ + McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}), + ] + )) + + with patch.object(ts, "initialize"): + ts._mcp_session_manager = mock_mgr + first, second = await asyncio.gather(ts.get_tools(), ts.get_tools()) + + assert [tool.name for tool in first] == ["tool_a"] + assert [tool.name for tool in second] == ["tool_a"] + mock_session.list_tools.assert_awaited_once() + # --------------------------------------------------------------------------- # Tests: close diff --git a/trpc_agent_sdk/tools/mcp_tool/_mcp_toolset.py b/trpc_agent_sdk/tools/mcp_tool/_mcp_toolset.py index 38afd0a3..36e08e04 100644 --- a/trpc_agent_sdk/tools/mcp_tool/_mcp_toolset.py +++ b/trpc_agent_sdk/tools/mcp_tool/_mcp_toolset.py @@ -25,11 +25,17 @@ from __future__ import annotations +import asyncio +import inspect +import time +from typing import cast from typing import List from typing import Optional from typing import Union from typing_extensions import override +from mcp import ClientSession +from mcp import types as mcp_types from mcp.types import ListToolsResult from trpc_agent_sdk.abc import ToolPredicate @@ -81,7 +87,9 @@ def __init__(self, mcp_tool_cls=MCPTool, filters_name: Optional[list[str]] = None, filters: Optional[list[BaseFilter]] = None, - session_group_params: Optional[dict] = None): + session_group_params: Optional[dict] = None, + cache_tools: bool = True, + tools_cache_ttl: Optional[float] = 60.0): """Initializes the MCPToolset. Args: @@ -103,10 +111,17 @@ def __init__(self, filters_name: List of filter names to apply to the tools filters: List of filter instances to apply to the tools session_group_params: Optional parameters for session group management + cache_tools: Whether to cache the MCP server's list_tools response. + tools_cache_ttl: Cache lifetime in seconds for MCP servers that do not + support tools.listChanged notifications. Servers that support + listChanged use notification-driven invalidation instead. """ super().__init__(tool_filter=tool_filter, is_include_all_tools=is_include_all_tools) + if tools_cache_ttl is not None and tools_cache_ttl < 0: + raise ValueError("tools_cache_ttl must be non-negative.") + self._connection_params = connection_params self._mcp_tool_cls = mcp_tool_cls # Create the session manager that will handle the MCP connection @@ -114,6 +129,11 @@ def __init__(self, self._filters = filters self._filters_name = filters_name self._session_group_params = session_group_params or {} + self._cache_tools = cache_tools + self._tools_cache_ttl = tools_cache_ttl + self._tools_cache_lock = asyncio.Lock() + self._tools_cache: ListToolsResult | None = None + self._tools_cache_updated_at: float | None = None def _checker_required_params(self): """Validates that all required parameters are properly initialized. @@ -126,6 +146,81 @@ def _checker_required_params(self): if not self._mcp_session_manager: raise ValueError("_mcp_session_manager is None.") + def clear_tools_cache(self) -> None: + """Clears the cached MCP tool definitions. + + Call this when the MCP server's tool set is known to have changed and + the next get_tools call should re-query list_tools. + """ + self._tools_cache = None + self._tools_cache_updated_at = None + + def _server_supports_tool_list_changed(self, session: ClientSession) -> bool: + """Returns whether the server can notify client about tool list changes.""" + try: + get_capabilities = getattr(session, "get_server_capabilities", None) + if get_capabilities is None: + return False + capabilities = get_capabilities() + if inspect.isawaitable(capabilities): + close = getattr(capabilities, "close", None) + if close is not None: + close() + return False + except Exception: # pylint: disable=broad-except + return False + + tools_capability = getattr(capabilities, "tools", None) + return getattr(tools_capability, "listChanged", False) is True + + def _is_tools_cache_valid(self, session: ClientSession) -> bool: + """Returns whether the cached list_tools response can be reused.""" + if not self._cache_tools or self._tools_cache is None: + return False + if self._server_supports_tool_list_changed(session): + return True + if self._tools_cache_ttl is None: + return False + if self._tools_cache_updated_at is None: + return False + return time.monotonic() - self._tools_cache_updated_at < self._tools_cache_ttl + + async def _get_tools_response(self, session: ClientSession) -> ListToolsResult: + """Returns MCP tool definitions, using cache when enabled.""" + if not self._cache_tools: + return await session.list_tools() + + if self._is_tools_cache_valid(session): + return cast(ListToolsResult, self._tools_cache) + + async with self._tools_cache_lock: + if self._is_tools_cache_valid(session): + return cast(ListToolsResult, self._tools_cache) + + tools_response: ListToolsResult = await session.list_tools() + self._tools_cache = tools_response + self._tools_cache_updated_at = time.monotonic() + return tools_response + + def _build_session_group_params(self) -> dict: + """Builds ClientSession params with tool-change notification handling.""" + params = dict(self._session_group_params) + if not self._cache_tools: + return params + + user_message_handler = params.get("message_handler") + + async def message_handler(message): + if (isinstance(message, mcp_types.ServerNotification) + and isinstance(message.root, mcp_types.ToolListChangedNotification)): + self.clear_tools_cache() + + if user_message_handler is not None: + await user_message_handler(message) + + params["message_handler"] = message_handler + return params + @override def initialize(self) -> None: """Initialize the toolset.""" @@ -135,7 +230,7 @@ def initialize(self) -> None: self._connection_params = convert_conn_params(self._connection_params) self._mcp_session_manager = MCPSessionManager( connection_params=self._connection_params, - session_group_params=self._session_group_params, + session_group_params=self._build_session_group_params(), ) self._checker_required_params() @@ -159,7 +254,7 @@ async def get_tools( session = await self._mcp_session_manager.create_session() # Fetch available tools from the MCP server - tools_response: ListToolsResult = await session.list_tools() + tools_response = await self._get_tools_response(session) # Apply filtering based on context and tool_filter tools = [] @@ -184,6 +279,7 @@ async def close(self) -> None: gracefully to avoid blocking application shutdown. """ try: + self.clear_tools_cache() if self._mcp_session_manager is None: return await self._mcp_session_manager.close()