Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
dependencies = [
"pydantic>=2.11.3",
"openai>=1.3.0",
"mcp>=1.10.1",
"mcp<1.23.4,>=1.10.1",
"aiohttp",
"httpx>=0.27.0",
"httpx-sse>=0.4.0",
Expand Down
197 changes: 197 additions & 0 deletions tests/tools/mcp_tool/test_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#
# tRPC-Agent-Python is licensed under Apache-2.0.

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from mcp import types as mcp_types
from mcp import StdioServerParameters as McpStdioServerParameters
from mcp.types import ListToolsResult, Tool as McpBaseTool

Expand All @@ -26,6 +28,13 @@ def _stdio_conn():
)


def _server_capabilities(list_changed: bool | None = None):
tools_capability = None
if list_changed is not None:
tools_capability = mcp_types.ToolsCapability(listChanged=list_changed)
return mcp_types.ServerCapabilities(tools=tools_capability)


# ---------------------------------------------------------------------------
# Tests: __init__
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -70,6 +79,15 @@ def test_session_group_params_custom(self):
ts = MCPToolset(connection_params=_stdio_conn(), session_group_params={"key": "val"})
assert ts._session_group_params == {"key": "val"}

def test_tools_cache_enabled_by_default(self):
ts = MCPToolset(connection_params=_stdio_conn())
assert ts._cache_tools is True
assert ts._tools_cache_ttl == 60.0

def test_rejects_negative_tools_cache_ttl(self):
with pytest.raises(ValueError, match="tools_cache_ttl must be non-negative"):
MCPToolset(connection_params=_stdio_conn(), tools_cache_ttl=-1)


# ---------------------------------------------------------------------------
# Tests: _checker_required_params
Expand Down Expand Up @@ -294,6 +312,185 @@ async def test_get_tools_with_custom_mcp_tool_cls(self):
custom_cls.assert_called_once()
assert len(tools) == 1

@pytest.mark.asyncio
async def test_get_tools_reuses_cached_list_tools_response(self):
ts = MCPToolset(connection_params=_stdio_conn())

mock_mgr = MagicMock(spec=MCPSessionManager)
mock_session = AsyncMock()
mock_mgr.create_session = AsyncMock(return_value=mock_session)

mcp_tools = [
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
]
mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=mcp_tools))

with patch.object(ts, "initialize"):
ts._mcp_session_manager = mock_mgr
first = await ts.get_tools()
second = await ts.get_tools()

assert [tool.name for tool in first] == ["tool_a"]
assert [tool.name for tool in second] == ["tool_a"]
mock_session.list_tools.assert_awaited_once()

@pytest.mark.asyncio
async def test_get_tools_can_disable_tools_cache(self):
ts = MCPToolset(connection_params=_stdio_conn(), cache_tools=False)

mock_mgr = MagicMock(spec=MCPSessionManager)
mock_session = AsyncMock()
mock_mgr.create_session = AsyncMock(return_value=mock_session)

mock_session.list_tools = AsyncMock(
return_value=ListToolsResult(
tools=[
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
]
))

with patch.object(ts, "initialize"):
ts._mcp_session_manager = mock_mgr
await ts.get_tools()
await ts.get_tools()

assert mock_session.list_tools.await_count == 2

@pytest.mark.asyncio
async def test_clear_tools_cache_forces_refresh(self):
ts = MCPToolset(connection_params=_stdio_conn())

mock_mgr = MagicMock(spec=MCPSessionManager)
mock_session = AsyncMock()
mock_mgr.create_session = AsyncMock(return_value=mock_session)

mock_session.list_tools = AsyncMock(
side_effect=[
ListToolsResult(
tools=[
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
]),
ListToolsResult(
tools=[
McpBaseTool(name="tool_b", description="desc_b", inputSchema={"type": "object"}),
]),
])

with patch.object(ts, "initialize"):
ts._mcp_session_manager = mock_mgr
first = await ts.get_tools()
ts.clear_tools_cache()
second = await ts.get_tools()

assert [tool.name for tool in first] == ["tool_a"]
assert [tool.name for tool in second] == ["tool_b"]
assert mock_session.list_tools.await_count == 2

@pytest.mark.asyncio
async def test_tools_cache_ttl_expires(self):
ts = MCPToolset(connection_params=_stdio_conn(), tools_cache_ttl=1)

mock_mgr = MagicMock(spec=MCPSessionManager)
mock_session = AsyncMock()
mock_mgr.create_session = AsyncMock(return_value=mock_session)

mock_session.list_tools = AsyncMock(
side_effect=[
ListToolsResult(
tools=[
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
]),
ListToolsResult(
tools=[
McpBaseTool(name="tool_b", description="desc_b", inputSchema={"type": "object"}),
]),
])

with patch.object(ts, "initialize"), patch(
"trpc_agent_sdk.tools.mcp_tool._mcp_toolset.time.monotonic",
side_effect=[100.0, 100.5, 101.1, 101.1, 101.1],
):
ts._mcp_session_manager = mock_mgr
first = await ts.get_tools()
cached = await ts.get_tools()
refreshed = await ts.get_tools()

assert [tool.name for tool in first] == ["tool_a"]
assert [tool.name for tool in cached] == ["tool_a"]
assert [tool.name for tool in refreshed] == ["tool_b"]
assert mock_session.list_tools.await_count == 2

@pytest.mark.asyncio
async def test_list_changed_capability_uses_notification_driven_cache(self):
ts = MCPToolset(connection_params=_stdio_conn(), tools_cache_ttl=1)

mock_mgr = MagicMock(spec=MCPSessionManager)
mock_session = AsyncMock()
mock_session.get_server_capabilities = MagicMock(return_value=_server_capabilities(list_changed=True))
mock_mgr.create_session = AsyncMock(return_value=mock_session)

mock_session.list_tools = AsyncMock(
return_value=ListToolsResult(
tools=[
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
]
))

with patch.object(ts, "initialize"), patch(
"trpc_agent_sdk.tools.mcp_tool._mcp_toolset.time.monotonic",
return_value=100.0,
):
ts._mcp_session_manager = mock_mgr
first = await ts.get_tools()
second = await ts.get_tools()

assert [tool.name for tool in first] == ["tool_a"]
assert [tool.name for tool in second] == ["tool_a"]
mock_session.list_tools.assert_awaited_once()

@pytest.mark.asyncio
async def test_tool_list_changed_notification_clears_cache_and_chains_handler(self):
user_message_handler = AsyncMock()
ts = MCPToolset(
connection_params=_stdio_conn(),
session_group_params={"message_handler": user_message_handler},
)
ts._tools_cache = ListToolsResult(
tools=[
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
])
ts._tools_cache_updated_at = 100.0

params = ts._build_session_group_params()
notification = mcp_types.ServerNotification(mcp_types.ToolListChangedNotification())
await params["message_handler"](notification)

assert ts._tools_cache is None
assert ts._tools_cache_updated_at is None
user_message_handler.assert_awaited_once_with(notification)

@pytest.mark.asyncio
async def test_concurrent_get_tools_shares_cache_fill(self):
ts = MCPToolset(connection_params=_stdio_conn())

mock_mgr = MagicMock(spec=MCPSessionManager)
mock_session = AsyncMock()
mock_mgr.create_session = AsyncMock(return_value=mock_session)
mock_session.list_tools = AsyncMock(
return_value=ListToolsResult(
tools=[
McpBaseTool(name="tool_a", description="desc_a", inputSchema={"type": "object"}),
]
))

with patch.object(ts, "initialize"):
ts._mcp_session_manager = mock_mgr
first, second = await asyncio.gather(ts.get_tools(), ts.get_tools())

assert [tool.name for tool in first] == ["tool_a"]
assert [tool.name for tool in second] == ["tool_a"]
mock_session.list_tools.assert_awaited_once()


# ---------------------------------------------------------------------------
# Tests: close
Expand Down
Loading
Loading