From 78adb726dd19e2a004785bcdc97ed353228c7f20 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 13:31:23 +0000 Subject: [PATCH 01/19] backport: copy tests/interaction/ verbatim from main (phase 0) Excludes tests/interaction from pyright until the backport restores type correctness phase by phase. --- pyproject.toml | 2 + tests/interaction/README.md | 228 ++ tests/interaction/__init__.py | 0 tests/interaction/_connect.py | 360 +++ tests/interaction/_helpers.py | 107 + tests/interaction/_requirements.py | 2816 +++++++++++++++++ tests/interaction/auth/__init__.py | 0 tests/interaction/auth/_harness.py | 465 +++ tests/interaction/auth/_provider.py | 186 ++ tests/interaction/auth/test_as_handlers.py | 300 ++ .../interaction/auth/test_authorize_token.py | 399 +++ tests/interaction/auth/test_bearer.py | 189 ++ tests/interaction/auth/test_discovery.py | 333 ++ tests/interaction/auth/test_flow.py | 239 ++ tests/interaction/auth/test_lifecycle.py | 445 +++ tests/interaction/conftest.py | 23 + tests/interaction/lowlevel/__init__.py | 0 .../interaction/lowlevel/test_cancellation.py | 234 ++ tests/interaction/lowlevel/test_completion.py | 131 + .../interaction/lowlevel/test_elicitation.py | 662 ++++ tests/interaction/lowlevel/test_flows.py | 203 ++ tests/interaction/lowlevel/test_initialize.py | 384 +++ .../interaction/lowlevel/test_list_changed.py | 136 + tests/interaction/lowlevel/test_logging.py | 127 + tests/interaction/lowlevel/test_meta.py | 63 + tests/interaction/lowlevel/test_pagination.py | 242 ++ tests/interaction/lowlevel/test_ping.py | 53 + tests/interaction/lowlevel/test_progress.py | 301 ++ tests/interaction/lowlevel/test_prompts.py | 209 ++ tests/interaction/lowlevel/test_resources.py | 309 ++ tests/interaction/lowlevel/test_roots.py | 166 + tests/interaction/lowlevel/test_sampling.py | 687 ++++ tests/interaction/lowlevel/test_timeouts.py | 114 + tests/interaction/lowlevel/test_tools.py | 512 +++ tests/interaction/lowlevel/test_wire.py | 309 ++ tests/interaction/mcpserver/__init__.py | 0 .../interaction/mcpserver/test_completion.py | 38 + tests/interaction/mcpserver/test_context.py | 271 ++ tests/interaction/mcpserver/test_prompts.py | 195 ++ tests/interaction/mcpserver/test_resources.py | 183 ++ tests/interaction/mcpserver/test_tools.py | 432 +++ tests/interaction/test_coverage.py | 105 + tests/interaction/transports/__init__.py | 0 tests/interaction/transports/_bridge.py | 169 + tests/interaction/transports/_event_store.py | 55 + tests/interaction/transports/_stdio_server.py | 63 + tests/interaction/transports/test_bridge.py | 94 + .../transports/test_client_transport_http.py | 247 ++ tests/interaction/transports/test_flows.py | 129 + .../transports/test_hosting_http.py | 344 ++ .../transports/test_hosting_resume.py | 372 +++ .../transports/test_hosting_session.py | 202 ++ tests/interaction/transports/test_sse.py | 90 + tests/interaction/transports/test_stdio.py | 143 + .../transports/test_streamable_http.py | 168 + 55 files changed, 14234 insertions(+) create mode 100644 tests/interaction/README.md create mode 100644 tests/interaction/__init__.py create mode 100644 tests/interaction/_connect.py create mode 100644 tests/interaction/_helpers.py create mode 100644 tests/interaction/_requirements.py create mode 100644 tests/interaction/auth/__init__.py create mode 100644 tests/interaction/auth/_harness.py create mode 100644 tests/interaction/auth/_provider.py create mode 100644 tests/interaction/auth/test_as_handlers.py create mode 100644 tests/interaction/auth/test_authorize_token.py create mode 100644 tests/interaction/auth/test_bearer.py create mode 100644 tests/interaction/auth/test_discovery.py create mode 100644 tests/interaction/auth/test_flow.py create mode 100644 tests/interaction/auth/test_lifecycle.py create mode 100644 tests/interaction/conftest.py create mode 100644 tests/interaction/lowlevel/__init__.py create mode 100644 tests/interaction/lowlevel/test_cancellation.py create mode 100644 tests/interaction/lowlevel/test_completion.py create mode 100644 tests/interaction/lowlevel/test_elicitation.py create mode 100644 tests/interaction/lowlevel/test_flows.py create mode 100644 tests/interaction/lowlevel/test_initialize.py create mode 100644 tests/interaction/lowlevel/test_list_changed.py create mode 100644 tests/interaction/lowlevel/test_logging.py create mode 100644 tests/interaction/lowlevel/test_meta.py create mode 100644 tests/interaction/lowlevel/test_pagination.py create mode 100644 tests/interaction/lowlevel/test_ping.py create mode 100644 tests/interaction/lowlevel/test_progress.py create mode 100644 tests/interaction/lowlevel/test_prompts.py create mode 100644 tests/interaction/lowlevel/test_resources.py create mode 100644 tests/interaction/lowlevel/test_roots.py create mode 100644 tests/interaction/lowlevel/test_sampling.py create mode 100644 tests/interaction/lowlevel/test_timeouts.py create mode 100644 tests/interaction/lowlevel/test_tools.py create mode 100644 tests/interaction/lowlevel/test_wire.py create mode 100644 tests/interaction/mcpserver/__init__.py create mode 100644 tests/interaction/mcpserver/test_completion.py create mode 100644 tests/interaction/mcpserver/test_context.py create mode 100644 tests/interaction/mcpserver/test_prompts.py create mode 100644 tests/interaction/mcpserver/test_resources.py create mode 100644 tests/interaction/mcpserver/test_tools.py create mode 100644 tests/interaction/test_coverage.py create mode 100644 tests/interaction/transports/__init__.py create mode 100644 tests/interaction/transports/_bridge.py create mode 100644 tests/interaction/transports/_event_store.py create mode 100644 tests/interaction/transports/_stdio_server.py create mode 100644 tests/interaction/transports/test_bridge.py create mode 100644 tests/interaction/transports/test_client_transport_http.py create mode 100644 tests/interaction/transports/test_flows.py create mode 100644 tests/interaction/transports/test_hosting_http.py create mode 100644 tests/interaction/transports/test_hosting_resume.py create mode 100644 tests/interaction/transports/test_hosting_session.py create mode 100644 tests/interaction/transports/test_sse.py create mode 100644 tests/interaction/transports/test_stdio.py create mode 100644 tests/interaction/transports/test_streamable_http.py diff --git a/pyproject.toml b/pyproject.toml index 6c88c8e789..0d424a2841 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,8 @@ packages = ["src/mcp"] [tool.pyright] typeCheckingMode = "strict" include = ["src/mcp", "tests", "examples/servers", "examples/snippets"] +# tests/interaction is mid-backport from main; type-checking is restored phase by phase. +exclude = ["tests/interaction"] venvPath = "." venv = ".venv" # The FastAPI style of using decorators in tests gives a `reportUnusedFunction` error. diff --git a/tests/interaction/README.md b/tests/interaction/README.md new file mode 100644 index 0000000000..be68c3b0f1 --- /dev/null +++ b/tests/interaction/README.md @@ -0,0 +1,228 @@ +# Interaction-model test suite + +This suite enumerates the MCP interaction model as end-to-end tests: one test per piece of +functionality, asserting the full client↔server round trip through the public API. It exists to +pin the SDK's observable behaviour — every request type, every notification direction, every +error plane — so that internal rewrites of the send/receive path can be proven equivalent by +running the suite before and after. + +```bash +uv run --frozen pytest tests/interaction/ +``` + +The whole suite is in-process and event-driven — including the streamable HTTP, SSE, and OAuth +flows — with a single subprocess test for stdio. + +## Ground rules + +- **Public API only.** Tests drive a `Client` connected to a `Server` or `MCPServer`. Nothing + reaches into session internals, so the suite keeps working when those internals change. + `ClientSession` is used directly only for behaviours `Client` cannot express (skipping + initialization, requesting a non-default protocol version). +- **Pin current behaviour.** Every test passes against the current `main`, including behaviours + that diverge from the specification. A failing or xfailed test proves nothing about whether a + rewrite preserved behaviour; a passing test that pins the wrong output exactly does. Known + divergences are recorded as data on the requirement (see below), not worked around in the test. +- **Spec-mandated assertions, not implementation quirks.** Error *codes* are asserted against + the constants in `mcp.types`; error *message strings* are pinned only where they are the + SDK's own deliberate output. +- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that + could hang is bounded by `anyio.fail_after(5)`. The HTTP and OAuth tests drive the Starlette + app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which + delivers each response chunk as the server produces it — full duplex, but still no sockets, + threads, or subprocesses anywhere outside the one stdio test. + +## Layout + +```text +tests/interaction/ + _requirements.py the requirements manifest (see below) + _helpers.py shared type aliases + the wire-recording transport + _connect.py the transport-parametrized connection factories + conftest.py the connect fixture (the transport matrix) + test_coverage.py enforces the manifest ↔ test contract + lowlevel/ one file per feature area, against the low-level Server + mcpserver/ the same feature areas in MCPServer's natural idiom + transports/ behaviour specific to one transport (sessions, resumability, framing) + auth/ OAuth flows against an in-process authorization server +``` + +The two server APIs produce genuinely different wire output for the same conceptual feature +(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured +content), so they get parallel directories with mirrored file names rather than one parametrized +test body — each directory pins its flavour's true output exactly. + +### The transport matrix + +Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` +directly, and therefore run once per transport: over the in-memory transport, over the server's +real streamable HTTP app driven in-process through the streaming bridge, and over the legacy SSE +transport the same way. A test connects with `async with connect(server, ...) as client:` and +asserts the same output on every leg, because the transport is not supposed to change observable +behaviour. Tests that are tied to one transport do not use the fixture: the wire-recording tests +(their seam is the in-memory stream pair), the bare-`ClientSession` lifecycle tests, the +real-clock timeout tests (the timeout machinery is transport-independent and must not race +transport latency), and everything under `transports/`, which pins behaviour only observable on +that transport. + +A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app +**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, +response headers, SSE event fields, which stream a message travels on. Any other behaviour is +asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several +clients can share one session manager. + +## The requirements manifest + +`_requirements.py` maps every behaviour the suite covers to the reason it must hold: + +```python +"tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content.", +), +``` + +- **`source`** is a deep link into the MCP specification for externally mandated behaviour, + the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or + `"issue:#n"` for a regression lock. +- **`behavior`** describes the *required* behaviour — what the specification (or the SDK's own + contract) says should happen. Tests always pin the SDK's current behaviour; where that falls + short of `behavior`, the gap is recorded as data rather than hidden in the test. +- **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. +- **`deferred`** marks a behaviour that is tracked but has no test in this suite, with a precise + reason: the SDK does not implement it, the negative cannot be observed, the assertion is + schema-level rather than interaction-level, the feature is experimental (tasks), or the test + would require real-time waits the suite refuses. +- **`transports`** names the transports a behaviour applies to; omitted means transport-independent. +- **`issue`** carries the tracking link for a recorded gap once one is filed. + +Tests link themselves to the manifest with a decorator: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: ... +``` + +`test_coverage.py` enforces the contract in both directions: every non-deferred requirement must +be exercised by at least one test, every deferred requirement by none, and an unknown ID fails at +import time. A behaviour without a manifest entry cannot be silently half-tested, and a manifest +entry without a test cannot be silently aspirational. + +### The divergence lifecycle + +1. A test reveals that the SDK does not do what the spec says. The test pins what the SDK + *actually does* and a `Divergence(note=..., issue=...)` goes on the requirement. +2. When the behaviour is eventually fixed, the pinned test fails. Whoever makes the change finds + the divergence note explaining that the old behaviour was a known gap, re-pins the test to the + spec-correct output, and deletes the `Divergence`. +3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. + +A requirement may carry both `divergence` and `deferred`: the divergence records that the SDK falls +short of the spec, and the deferral records why no test pins it (typically because the divergent +behaviour cannot be driven through the public API). Divergence alone implies a test pins the +divergent behaviour; divergence plus deferred means the gap is known but unpinned. + +This is also the triage key for any rewrite: a test that fails on the new code path either has a +divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or +it does not (the rewrite broke something that was correct — fix the rewrite). + +### When a new spec revision is released + +1. Update `SPEC_REVISION` and walk the new revision's changelog. +2. For each changed interaction, find its requirements (the IDs use the wire method strings the + changelog speaks in), re-audit the tests against the new text, and update `source` links and + assertions where behaviour legitimately changed. +3. New interactions get new requirements and new tests; removed interactions get their + requirements deleted along with their tests. +4. A behaviour that is correct under both revisions needs no change beyond the `source` link. + +## Writing a test + +The shortest complete example of the conventions: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) +``` + +- **The server is defined inside the test** (or in a small fixture at the top of the file when + several tests genuinely share it). The whole observable behaviour fits on one screen. +- **Test names are behaviour sentences** — they state the observable outcome, not the feature + being poked. Docstrings add the one or two sentences of context a reviewer needs, including + whether the assertion is spec-mandated, SDK-defined, or a known divergence. +- **Handlers assert their dispatch identity first** (`assert params.name == "add"`), proving the + request that arrived is the request the test sent. +- **The result proves the round trip.** Server-side observations travel back to the test through + the protocol itself (a tool returns what it saw) or through a closure-captured list; the test + asserts after the call returns. +- **Order within a test**: server handlers → server construction → client callbacks → connect → + act → assert. The test reads in the order the conversation happens. +- A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body + so it cannot silently become load-bearing. +- A test that needs a peer no real `Server` or `Client` can play (a server that answers initialize + with an unsupported version, a client that sends malformed params) plays that side of the wire by + hand over `create_client_server_memory_streams()`. This scripted-peer pattern is the suite's only + way to drive behaviour the typed API cannot produce, and the docstring of every such test says so. + +Stack a second `@requirement` decorator only when a test's natural assertions incidentally prove +another behaviour — one capabilities snapshot proving four `*:capability:declared` entries, one +input-schema identity check proving each preserved keyword. Do not build a test around covering +many requirements at once; if the assertions would be separate, write separate tests. + +### Choosing an assertion + +| The property under test is… | Assert with | +|---|---| +| the result of a transformation (arguments → output, exception → error result) | `result == snapshot(...)` of the full object, so any field the implementation adds or drops fails the test | +| pass-through of an opaque value (`_meta`, cursors) | identity against the same variable that was sent — a snapshot of a pass-through value only matches the input because a human checked two literals correspond | +| an error | `pytest.raises(MCPError)` and a snapshot of `exc.value.error` when the message is the SDK's own; a plain `==` on `.code` against the `mcp.types` constant when it is not | +| third-party output embedded in a result (validation messages) | the stable prefix only — never pin text that changes with a dependency upgrade | + +### Notifications and concurrency + +The client's receive loop dispatches each incoming message to completion before reading the next, +and the in-memory transport delivers everything on one ordered stream. Together these guarantee +that every notification a server handler emits before its response reaches the client callback +before the originating request returns — so tests collect notifications into a plain list and +assert after the call, with no synchronisation. The exceptions: + +- a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in + the receiving handler and awaited under `anyio.fail_after(5)`; +- the ordering guarantee does not survive transports that split messages across streams (the + streamable HTTP standalone GET stream) — see `transports/test_streamable_http.py`. + +### Coverage + +CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the +build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a +pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# type: ignore` or +`# noqa` comments; restructure instead. Two pragmas are sanctioned in this suite's test code, both +for known-upstream tracer bugs and only after restructuring has been tried: `# pragma: no branch` +on a `with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested +async context (reserve it for shapes that cannot collapse — a sync `with` adjacent to an +`async with`); and `# pragma: lax no cover` on a single statement that 3.11's tracer drops because +the preceding `async with` unwinds via `coro.throw()` (python/cpython#106749, wontfix on 3.11) — +this hits any test that must run statements after a `ClientSession`/`streamable_http_client` exits +but still inside an outer `async with`, and no restructure can avoid it. + +A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose +execution is timing-dependent under the in-process HTTP bridge — the POST-stream and +stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` +check in `message_router`, and the response-stream double-close guard in +`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to +strict `no cover` without first making the teardown ordering deterministic. The suite also relies +on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the +SSE leg would otherwise leak. diff --git a/tests/interaction/__init__.py b/tests/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py new file mode 100644 index 0000000000..1faf4aa8d6 --- /dev/null +++ b/tests/interaction/_connect.py @@ -0,0 +1,360 @@ +"""Transport-parametrized connection factories for the interaction suite. + +The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body +runs over each transport without naming any of them: the factory is a drop-in replacement for +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the +server's real Starlette app through the in-process streaming bridge, so the full transport layer +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Protocol + +import httpx +from httpx_sse import ServerSentEvent, aconnect_sse +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.client.client import Client +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier +from mcp.server.auth.settings import AuthSettings +from mcp.server.mcpserver import MCPServer +from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) +from tests.interaction.transports._bridge import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot +# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the +# protection itself pass explicit settings (or transport_security=None to get the localhost +# auto-enable behaviour). +NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) + + +class Connect(Protocol): + """Connect a Client to a server over the transport selected by the `connect` fixture. + + Accepts the same keyword arguments as `Client` and yields the connected client. + """ + + def __call__( + self, + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> AbstractAsyncContextManager[Client]: ... + + +@asynccontextmanager +async def connect_in_memory( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server over the in-memory transport.""" + async with Client( + server, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def connect_over_streamable_http( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's streamable HTTP app, entirely in process. + + With the defaults this is the matrix leg (stateful sessions, SSE responses); the + transport-specific tests pass `stateless_http` or `json_response` to select the other + server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so + the client's reconnection wait is a no-op). + """ + app = server.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client, + ): + yield client + + +@asynccontextmanager +async def mounted_app( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, + headers: dict[str, str] | None = None, + auth: AuthSettings | None = None, + token_verifier: TokenVerifier | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, +) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: + """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. + + Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests + use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test + speaks HTTP through the yielded client directly; for client-driven assertions the test wraps + that client in `client_via_http(http)`, which lets several `Client`s share the one mounted + session manager. `on_request` records every outgoing HTTP request before it leaves the + yielded client. + + DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the + localhost auto-enable behaviour) to test the protection itself. + """ + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + app = lowlevel.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=transport_security, + auth=auth, + token_verifier=token_verifier, + auth_server_provider=auth_server_provider, + ) + event_hooks = {"request": [on_request]} if on_request is not None else None + async with ( + server.session_manager.run(), + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers + ) as http_client, + ): + yield http_client, server.session_manager + + +@asynccontextmanager +async def client_via_http( + http_client: httpx.AsyncClient, + *, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Connect a `Client` over an already-mounted streamable HTTP app. + + Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a + client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the `Client` exits. + """ + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: + """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + + +async def post_jsonrpc( + http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None +) -> tuple[httpx.Response, list[JSONRPCMessage]]: + """POST a JSON-RPC body and read its SSE response stream to completion. + + Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages + that arrived on the response's SSE stream. Only meaningful for requests the server answers + with `text/event-stream`; for error responses or 202 notification acknowledgements, use + `httpx.AsyncClient.post` directly and assert on the response. + """ + async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source: + events = [event async for event in source.aiter_sse()] + return source.response, parse_sse_messages(events) + + +def base_headers(*, session_id: str | None = None) -> dict[str, str]: + """Standard request headers for raw-httpx streamable-HTTP tests. + + Every well-formed request carries these (Accept covering both response representations, + Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + ID once one exists), so a test that wants to assert a specific rejection only varies the one + header under test. + """ + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + } + if session_id is not None: + headers["mcp-session-id"] = session_id + return headers + + +def initialize_body(request_id: int = 1) -> dict[str, object]: + """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" + params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="raw", version="0.0.0"), + ) + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) + + +async def initialize_via_http(http: httpx.AsyncClient) -> str: + """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID. + + Validates the SSE response and sends the `notifications/initialized` follow-up, so the server + is fully ready for subsequent feature requests when this returns. + """ + async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source: + assert source.response.status_code == 200 + # An event-store-backed server opens the stream with a priming event (empty data); skip it. + events = [event async for event in source.aiter_sse() if event.data] + assert len(events) == 1 + assert JSONRPCResponse.model_validate_json(events[0].data).id == 1 + session_id = source.response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. + + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + the SSE-specific tests need; building the app explicitly here gives both server flavours the + same routing while keeping that handle. + """ + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + await lowlevel.run(read, write, lowlevel.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + return app, sse + + +@asynccontextmanager +async def connect_over_sse( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" + app, _ = build_sse_app(server) + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE server transport's connect_sse runs the entire MCP session inside the GET + # request and only releases its streams after that request observes a disconnect, so the + # bridge must let the application drain rather than cancelling at close. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py new file mode 100644 index 0000000000..25833b0ca5 --- /dev/null +++ b/tests/interaction/_helpers.py @@ -0,0 +1,107 @@ +"""Shared helpers for the interaction suite. + +Keep this module small: it exists only for (a) types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses +them. +""" + +from types import TracebackType + +import anyio +from typing_extensions import Self + +from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest + +# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), +# but the SDK does not export a name for it -- writing a correctly-typed handler requires +# importing RequestResponder from mcp.shared.session and assembling the union by hand. It +# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is +# for the request callbacks), at which point this alias can be deleted. +IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" + + +class _RecordingReadStream: + """Delegates to a read stream, appending every received message to a log.""" + + def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + self._inner = inner + self._log = log + + async def receive(self) -> SessionMessage | Exception: + item = await self._inner.receive() + self._log.append(item) + return item + + async def aclose(self) -> None: + await self._inner.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> SessionMessage | Exception: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration from None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class _RecordingWriteStream: + """Delegates to a write stream, appending every sent message to a log.""" + + def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + self._inner = inner + self._log = log + + async def send(self, item: SessionMessage, /) -> None: + self._log.append(item) + await self._inner.send(item) + + async def aclose(self) -> None: + await self._inner.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class RecordingTransport: + """Wraps a Transport and records every message crossing the client's transport boundary. + + `sent` holds everything the client wrote towards the server; `received` holds everything the + server delivered to the client. The recording sits at the transport seam -- the exact payloads + a real transport would serialise -- and never touches the session, so wire-level assertions + written against it survive changes to the receive path. + """ + + def __init__(self, inner: Transport) -> None: + self.inner = inner + self.sent: list[SessionMessage] = [] + self.received: list[SessionMessage | Exception] = [] + + async def __aenter__(self) -> TransportStreams: + read_stream, write_stream = await self.inner.__aenter__() + return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + return await self.inner.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py new file mode 100644 index 0000000000..109b30fc77 --- /dev/null +++ b/tests/interaction/_requirements.py @@ -0,0 +1,2816 @@ +"""Requirements manifest for the interaction-model test suite. + +Every user-facing behaviour the SDK must satisfy, keyed by a stable `:[:]` +ID. Each entry owns the tests that exercise it: tests declare `@requirement("")` (a test that +proves several behaviours stacks several decorators) and `test_coverage.py` enforces the contract +in both directions: every non-deferred requirement has at least one test, and every test carries +at least one requirement. + +Sources: + spec URL -- externally mandated by the MCP specification (deep link to the section) + `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated + `issue:#n` -- regression lock-in for a previously fixed bug + +The `behavior` sentence describes the REQUIRED behaviour -- what the specification (or the SDK's +own contract) says should happen. Tests always pin the SDK's current behaviour. Where current +behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose +tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered +by a test in this suite. An entry may carry both: `divergence` records the spec-compliance gap +(issue-able) and `deferred` records why no test exists; `divergence` alone implies a test pins +the divergent behaviour. `issue` carries the tracking link for a recorded gap once one is filed. + +`deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this +repo the reason names the covering test path; where the SDK does not implement the behaviour at +all the reason starts with "Not implemented in the SDK"; and where an interaction-level test is +planned but not yet written the reason starts with "Not yet covered here". + +`transports` records which transports a behaviour applies to (or is observable on); None means +the behaviour is transport-independent. + +The ID vocabulary and entry granularity are aligned with the TypeScript SDK's end-to-end +requirements suite, so coverage and recorded divergences can be compared across the two SDKs +entry by entry; IDs that exist in only one SDK reflect genuinely different API surface. +""" + +import re +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypeVar + +import pytest + +SPEC_REVISION = "2025-11-25" +SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" + +Transport = Literal["in-memory", "stdio", "streamable-http", "sse"] + +_TestFn = TypeVar("_TestFn", bound=Callable[..., object]) + +_SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") + +_TASKS_DEFERRAL = ( + "Tasks are experimental and the spec is being substantially revised; python task behaviour is " + "covered by tests/experimental/tasks/ until the next spec revision settles." +) + + +@dataclass(frozen=True, kw_only=True) +class Divergence: + """A documented gap between the SDK behaviour this suite pins and what `source` mandates.""" + + note: str + issue: str | None = None + + +@dataclass(frozen=True, kw_only=True) +class Requirement: + """A single testable behaviour and the provenance of why it must hold.""" + + source: str + behavior: str + transports: tuple[Transport, ...] | None = None + divergence: Divergence | None = None + deferred: str | None = None + issue: str | None = None + + def __post_init__(self) -> None: + if not _SOURCE_PATTERN.fullmatch(self.source): + raise ValueError(f"source must be a specification URL, 'sdk', or 'issue:#n', got {self.source!r}") + + +REQUIREMENTS: dict[str, Requirement] = { + # ═══════════════════════════════════════════════════════════════════════════ + # Lifecycle & version negotiation + # ═══════════════════════════════════════════════════════════════════════════ + "lifecycle:capability:client-not-declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects sending notifications or registering handlers for capabilities it did not declare." + ), + divergence=Divergence( + note=( + "The client does not check its own declared capabilities before sending notifications or " + "serving callbacks; nothing prevents a caller from violating the spec's MUST." + ), + ), + deferred=( + "Not implemented in the SDK: the client does not check its own declared capabilities before " + "sending notifications or serving callbacks." + ), + ), + "lifecycle:capability:server-not-advertised": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." + ), + divergence=Divergence( + note=( + "The client sends any request regardless of the server's advertised capabilities and " + "surfaces whatever the server answers; the spec's MUST is not enforced." + ), + ), + deferred=( + "Not implemented in the SDK: the client sends any request regardless of the server's " + "advertised capabilities and surfaces whatever the server answers." + ), + ), + "lifecycle:initialize:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Connecting sends initialize with the protocol version, client capabilities, and client " + "info; the server responds with its own and the connection is established." + ), + ), + "lifecycle:initialize:server-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The initialize result identifies the server: name and version, plus title when declared.", + ), + "lifecycle:initialize:instructions": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="A server may include an instructions string in the initialize result; the client exposes it.", + ), + "lifecycle:initialize:capabilities:from-handlers": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The server advertises a capability for each feature area it has a registered handler for, " + "and omits the capability for areas it does not." + ), + ), + "lifecycle:initialize:capabilities:minimal": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with no feature handlers advertises no feature capabilities.", + ), + "lifecycle:initialize:client-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The client's name, version, and title are visible to server handlers after initialization.", + ), + "lifecycle:initialize:client-capabilities": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The client capabilities visible to the server reflect which client callbacks are configured " + "(sampling, elicitation, roots)." + ), + ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "After successful initialization, the client sends exactly one initialized notification, " + "before any non-ping request." + ), + ), + "lifecycle:ping": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="ping in either direction returns an empty result.", + ), + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + "lifecycle:requests-before-initialized": Requirement( + source="sdk", + behavior=( + "A request other than ping sent before the initialization handshake completes is rejected with an error." + ), + ), + "lifecycle:pre-initialization-ordering": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Before initialization completes, the client sends no requests other than pings, and the " + "server sends no requests other than pings and logging." + ), + divergence=Divergence( + note=( + "The server's send methods (create_message / elicit_form / list_roots) do not check " + "initialization state before sending; on the client side, Client always completes the " + "handshake before any caller code runs." + ), + ), + deferred=( + "Not implemented in the SDK: neither side enforces sender-side restraint. The server's send " + "methods (create_message / elicit_form / list_roots) do not check initialization state before " + "sending, and there is no natural hook to issue a server-to-client request between the " + "initialize response and the initialized notification through the public API; on the client " + "side, Client always completes the handshake before any caller code runs." + ), + ), + "lifecycle:version:downgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server returns an older supported protocol version, the client downgrades to it " + "and the connection succeeds at that version." + ), + ), + "lifecycle:version:match": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server supports the requested protocol version it echoes that version in the " + "initialize result, and the connection proceeds at that version." + ), + ), + "lifecycle:version:server-fallback-latest": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "An initialize request carrying a protocol version the server does not support is answered " + "with another version the server supports — the latest one — rather than an error." + ), + ), + "lifecycle:version:reject-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "A client that receives an initialize response carrying a protocol version it does not " + "support fails initialization with an error rather than proceeding with the session." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives: cancellation, timeout, progress, errors, _meta + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null string or integer id; ids are " + "never reused within the session." + ), + ), + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", + behavior=( + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." + ), + ), + "protocol:cancel:abort-signal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#cancellation-flow", + behavior=( + "Cancelling an in-flight request through the client API sends notifications/cancelled with " + "the request id and fails the local call." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side API to cancel an in-flight " + "request; cancellation requires hand-constructing the notification (which is how " + "protocol:cancel:in-flight exercises the receiving side)." + ), + ), + "protocol:cancel:handler-abort-propagates": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="On the receiving side, a cancellation notification stops the running request handler.", + ), + "protocol:cancel:in-flight": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification for an in-flight request stops the server-side handler, and the " + "receiver does not send a response for the cancelled request." + ), + divergence=Divergence( + note=( + "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " + "request; the server sends an error response (code 0, 'Request cancelled'), which is what " + "unblocks the SDK client's pending call." + ), + ), + ), + "protocol:cancel:initialize-not-cancellable": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="The client never sends notifications/cancelled for the initialize request.", + deferred=( + "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " + "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." + ), + ), + "protocol:cancel:late-response-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A response that arrives after the sender issued notifications/cancelled is ignored; the " + "request stays failed and no error is raised." + ), + divergence=Divergence( + note=( + "A response whose id matches no in-flight request is delivered to the message handler " + "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " + "same code path; tested in its unknown-id form because that is deterministic without the " + "client-side cancellation API the SDK does not yet provide." + ), + ), + ), + "protocol:cancel:server-survives": Requirement( + source="sdk", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), + "protocol:cancel:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " + "cancels it, and the client stops processing the cancelled request." + ), + divergence=Divergence( + note=( + "Abandoning a server-side send_request emits no cancellation notification, and the client " + "could not act on one anyway: client callbacks run inline in the receive loop, so a " + "cancellation is not even read until the callback has finished." + ), + ), + deferred=( + "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " + "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." + ), + ), + "protocol:cancel:unknown-id-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", + behavior=( + "The receiver silently ignores a cancellation notification referencing an unknown or " + "already-completed request id; no error response is sent and no exception is raised." + ), + ), + "protocol:cancel:sender-targeting": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Cancellation notifications reference only requests that were previously issued in the same " + "direction and are believed to still be in flight." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin." + ), + ), + "protocol:error:connection-closed": Requirement( + source="sdk", + behavior="Closing the transport fails all in-flight requests with a connection-closed error.", + ), + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior=( + "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " + "-32603 Internal error." + ), + divergence=Divergence( + note=( + "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " + "leaks str(exc) as the error message." + ), + ), + ), + "protocol:error:invalid-params": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", + ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), + "protocol:meta:related-task": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="Messages may carry related-task _meta associating them with a task.", + deferred=_TASKS_DEFERRAL, + ), + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), + "protocol:progress:callback": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a handler during a request are delivered to the caller's " + "progress callback, in order, with their progress, total, and message." + ), + ), + "protocol:progress:token-injected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Supplying a progress callback attaches a progress token to the outgoing request, which the " + "server-side handler can observe in its request metadata." + ), + ), + "protocol:progress:token-unique": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), + ), + "protocol:progress:monotonic": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "The progress value increases with each notification for a given token, even when the total is unknown." + ), + divergence=Divergence( + note=( + "The spec MUST is not enforced: progress values are not validated on either side, so a " + "handler that emits non-increasing values has them forwarded to the callback unchanged." + ), + ), + ), + "protocol:progress:stops-after-completion": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", + behavior="Progress notifications for a token stop once the associated request completes.", + divergence=Divergence( + note=( + "send_progress_notification does not check whether the token's request has already " + "completed; the late notification is sent and reaches the client." + ), + ), + ), + "protocol:progress:late-dropped-by-client": Requirement( + source="sdk", + behavior=( + "A progress notification that arrives after its request has completed is not delivered to the " + "original progress callback." + ), + ), + "protocol:progress:no-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="Without a progress callback the request carries no progress token.", + ), + "protocol:progress:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="A progress notification sent by the client is delivered to the server's progress handler.", + ), + "protocol:timeout:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "A request that exceeds its read timeout fails with a request-timeout error instead of " + "waiting forever for the response." + ), + ), + "protocol:timeout:max-total": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + divergence=Divergence( + note=( + "There is no maximum-total-timeout option; only the per-request read timeout exists, so the " + "spec's SHOULD that an overall maximum is always enforced cannot be satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " + "read timeout exists." + ), + ), + "protocol:timeout:reset-on-progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="When configured to do so, each progress notification resets the request's read timeout.", + deferred=( + "Not implemented in the SDK: progress notifications do not reset the request read timeout and " + "no option exists to enable that." + ), + ), + "protocol:timeout:sends-cancellation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "When a request times out, the sender issues notifications/cancelled for that request before " + "failing the local call." + ), + divergence=Divergence( + note=( + "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." + ), + ), + ), + "protocol:timeout:session-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="The session continues to serve new requests after an earlier request timed out.", + ), + "protocol:timeout:session-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A session-level read timeout applies to every request that does not override it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools + # ═══════════════════════════════════════════════════════════════════════════ + "tools:call:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#audio-content", + behavior="A tool result can carry audio content: base64 data with a mimeType.", + ), + "tools:call:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", + behavior="A tool result can carry an embedded resource with full text or blob contents.", + ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), + "tools:call:content:mixed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-result", + behavior="A tool result can carry multiple content blocks of different types; order is preserved.", + ), + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", + ), + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", + ), + "tools:call:concurrent": Requirement( + source="sdk", + behavior=( + "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " + "receives the response to its own request." + ), + ), + "tools:call:elicitation-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#user-interaction-model", + behavior=( + "A tool handler that issues an elicitation receives the client's result and can embed it in " + "the tool call result." + ), + ), + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." + ), + ), + "tools:call:logging-mid-execution": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "Log notifications emitted by a tool handler during execution reach the client's logging " + "callback before the tool result returns." + ), + ), + "tools:call:progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a tool handler reach the caller's progress callback before " + "the tool result returns." + ), + ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." + ), + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#capabilities", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), + "tools:input-schema:json-schema-2020-12": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " + "is discoverable and callable." + ), + ), + "tools:input-schema:preserve-additional-properties": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema additionalProperties as registered.", + ), + "tools:input-schema:preserve-defs": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema $defs as registered.", + ), + "tools:input-schema:preserve-schema-dialect": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", + ), + "tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior=( + "When the tool set changes, the server sends notifications/tools/list_changed and it reaches " + "the client's handler." + ), + ), + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:metadata": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." + ), + ), + "tools:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "tools/list supports cursor pagination: the nextCursor returned by a list handler round-trips " + "back to the handler as an opaque cursor until the listing is exhausted." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "client:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="The client skips structured-content validation when the tool result has isError true.", + ), + "client:output-schema:validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "A tool result whose structuredContent does not conform to the tool's declared outputSchema " + "is rejected by the client: the call raises instead of returning the invalid result." + ), + ), + "client:output-schema:missing-structured": Requirement( + source="sdk", + behavior="A tool that declares an output schema but returns no structuredContent fails client-side validation.", + ), + "client:output-schema:auto-list": Requirement( + source="sdk", + behavior=( + "Calling a tool whose output schema is not yet cached issues an implicit tools/list to " + "populate the cache; subsequent calls of the same tool do not." + ), + divergence=Divergence( + note=( + "Design concern rather than spec violation: the implicit request is invisible to the " + "caller, and against a server that registers only on_call_tool a successful call surfaces " + "as METHOD_NOT_FOUND from a tools/list the caller never asked for." + ), + ), + ), + "mcpserver:output-schema:missing-structured": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior="A tool with an output schema whose function returns no structured content produces a server error.", + ), + "mcpserver:output-schema:server-validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "MCPServer validates structured content against the tool's output schema before returning; a " + "mismatch produces a server error." + ), + ), + "mcpserver:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="Server-side output schema validation is skipped when the tool returns an isError result.", + ), + "mcpserver:tool:duplicate-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-names", + behavior="Registering a tool with a name already in use is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; " + "warn_on_duplicate_tools defaults to True and warning is the only effect -- there is " + "no rejection mode." + ), + ), + ), + "mcpserver:tool:extra": Requirement( + source="sdk", + behavior=( + "Tool functions can access request metadata (request id, client params, session) through the " + "Context parameter." + ), + ), + "mcpserver:tool:handler-throws": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." + ), + ), + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content) without invoking the function." + ), + ), + "mcpserver:tool:naming-validation": Requirement( + source="sdk", + behavior=( + "Registering a tool whose name violates the spec's tool-naming conventions emits a warning; " + "registration still succeeds." + ), + ), + "mcpserver:tool:output-schema:model": Requirement( + source="sdk", + behavior=( + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." + ), + ), + "mcpserver:tool:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." + ), + ), + "mcpserver:tool:schema-variants": Requirement( + source="sdk", + behavior=( + "Tool input schemas generated from complex parameter types (unions, nested models, " + "constrained types) validate and coerce arguments before the function runs." + ), + ), + "mcpserver:tool:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name that was never registered returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error (isError true) instead. The low-level " + "path follows the spec example (see tools:call:unknown-name)." + ), + ), + ), + "mcpserver:tool:url-elicitation-error": Requirement( + source="sdk", + behavior=( + "A tool function that raises the URL-elicitation-required error surfaces to the caller as " + "error -32042 with the elicitation parameters intact." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer: Context helpers (SDK) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." + ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." + ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources + # ═══════════════════════════════════════════════════════════════════════════ + "resources:annotations": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#annotations", + behavior="Resource annotations supplied by the server round-trip to the client in the list result.", + divergence=Divergence( + note=( + "The SDK Annotations model is missing the schema's lastModified field; MCPModel uses the " + "pydantic default extra='ignore', so the value is silently dropped on parse." + ), + ), + ), + "resources:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#capabilities", + behavior=( + "A server with resource handlers advertises the resources capability, including the subscribe " + "sub-flag when a subscribe handler is registered." + ), + ), + "resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior=( + "When the resource set changes, the server sends notifications/resources/list_changed and it " + "reaches the client's handler." + ), + ), + "resources:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#listing-resources", + behavior=( + "resources/list returns the registered resources with uri, name, and the optional descriptive " + "fields supplied by the server." + ), + ), + "resources:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", + ), + "resources:read:blob": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns binary contents base64-encoded in blob.", + ), + "resources:read:template-vars": Requirement( + source="sdk", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), + "resources:read:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for an unknown URI returns JSON-RPC error -32002 (resource not found).", + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", + ), + "resources:subscribe:capability-required": Requirement( + source="sdk", + behavior=( + "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." + ), + ), + "resources:subscribe:updated": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/subscribe, changes to that resource send notifications/resources/updated.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state linking subscribe to " + "updated notifications; emitting updates is entirely handler code. The two halves are pinned " + "separately by resources:subscribe and resources:updated-notification." + ), + ), + "resources:templates:list": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior=( + "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." + ), + ), + "resources:templates:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", + ), + "resources:unsubscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." + ), + ), + "resources:unsubscribe:stops-updates": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state, so whether updated " + "notifications stop after unsubscribe is entirely handler code; there is no SDK behaviour to " + "pin beyond the unsubscribe request reaching the handler (covered by resources:unsubscribe)." + ), + ), + "resources:updated-notification": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "A resources/updated notification sent by the server reaches the client carrying the URI of " + "the changed resource." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:resource:duplicate-name": Requirement( + source="sdk", + behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name). " + "Templates differ: a duplicate uri_template silently replaces the first with no warning." + ), + ), + ), + "mcpserver:resource:read-throws-surfaced": Requirement( + source="sdk", + behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", + ), + "mcpserver:resource:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), + ), + "mcpserver:resource:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." + ), + ), + "mcpserver:resource:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for a URI matching no registered resource returns JSON-RPC error -32002.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#capabilities", + behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", + ), + "prompts:get:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#audio-content", + behavior="Prompt messages may contain audio content with base64 data and a mimeType.", + ), + "prompts:get:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", + behavior="Prompt messages may contain embedded resource content.", + ), + "prompts:get:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#image-content", + behavior="Prompt messages may contain image content.", + ), + "prompts:get:missing-required-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get omitting a required argument returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." + ), + ), + ), + "prompts:get:multi-message": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="A prompt can return multiple messages mixing user and assistant roles; order is preserved.", + ), + "prompts:get:no-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get with no arguments returns the prompt's messages.", + ), + "prompts:get:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for an unknown prompt name returns JSON-RPC error -32602 (Invalid params).", + ), + "prompts:get:with-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior=( + "When the prompt set changes, the server sends notifications/prompts/list_changed and it " + "reaches the client's handler." + ), + ), + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:prompt:args-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", + behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", + ), + "mcpserver:prompt:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." + ), + ), + "mcpserver:prompt:duplicate-name": Requirement( + source="sdk", + behavior="Registering a duplicate prompt name is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." + ), + ), + ), + "mcpserver:prompt:optional-args": Requirement( + source="sdk", + behavior="A prompt with optional arguments can be fetched without supplying them.", + ), + "mcpserver:prompt:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for a name that was never registered returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." + ), + ), + "completion:context-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:error:invalid-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#error-handling", + behavior=( + "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " + "JSON-RPC error -32602 (Invalid params)." + ), + ), + "completion:prompt-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:resource-template-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:result-shape": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#completion-results", + behavior="The completion result carries values (at most 100), an optional total, and an optional hasMore flag.", + ), + "mcpserver:completion:capability-auto": Requirement( + source="sdk", + behavior=( + "MCPServer advertises the completions capability when at least one completion source is " + "registered, and omits it otherwise." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "A server that emits log message notifications declares the logging capability in its initialize result." + ), + divergence=Divergence( + note=( + "MCPServer registers no setLevel handler, so capability derivation leaves logging unset " + "even though the Context helpers send log message notifications." + ), + ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), + "logging:message:fields": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data." + ), + ), + "logging:message:filtered": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="After logging/setLevel, log messages below the configured level are not sent.", + divergence=Divergence( + note=( + "Neither MCPServer (which rejects logging/setLevel with method-not-found) nor the " + "low-level Server (which leaves the handler entirely to the author) implements any " + "filtering; messages are delivered at every severity regardless of the requested level." + ), + ), + ), + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:set-level:invalid-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", + behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Sampling (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "sampling:capability:declare": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A client that handles sampling requests advertises the sampling capability in its initialize request." + ), + ), + "sampling:create:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A sampling/createMessage request from a server handler is answered by the client's sampling " + "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." + ), + ), + "sampling:create:include-context": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior="The includeContext value supplied by the server reaches the client callback intact.", + ), + "sampling:context:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "The server does not use includeContext values thisServer or allServers unless the client " + "declared the sampling.context capability." + ), + divergence=Divergence( + note=( + "include_context is forwarded regardless of the client's declared sampling.context " + "capability; the server-side validator only checks tools/tool_choice." + ), + ), + ), + "sampling:create:model-preferences": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", + behavior=( + "The model preferences supplied by the server (hints and the cost, speed, and intelligence " + "priorities) reach the client callback intact." + ), + ), + "sampling:create:system-prompt": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The system prompt supplied by the server reaches the client callback intact.", + ), + "sampling:create:tools": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." + ), + deferred=( + "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " + "parameter, so a client can never declare sampling.tools and the server-side validator " + "rejects every tool-enabled request before it is sent." + ), + ), + "sampling:create-message:audio-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#audio-content", + behavior="Sampling messages can carry audio content: base64 data with a mimeType.", + ), + "sampling:create-message:image-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#image-content", + behavior="Sampling messages can carry image content: base64 data with a mimeType.", + ), + "sampling:create-message:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped; the spec names no error code for this case." + ), + ), + "sampling:error:user-rejected": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", + behavior=( + "A sampling request the user rejects is answered with a JSON-RPC error (the spec's code for " + "this case is -1, 'User rejected sampling request'), surfaced to the requesting handler as an MCPError." + ), + ), + "sampling:message:content-cardinality": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling", + behavior="A sampling message's content may be a single block or an array of blocks.", + ), + "sampling:result:no-tools-single-content": Requirement( + source="sdk", + behavior=( + "When the request carries no tools, a sampling callback result whose content is an array is " + "rejected by the client." + ), + divergence=Divergence( + note=( + "The client does not validate the callback result against the request shape; an array-content " + "result for a tool-free request is accepted client-side and surfaces as a raw " + "pydantic.ValidationError from the server's response parsing (send_request) instead." + ), + ), + ), + "sampling:result:with-tools-array-content": Requirement( + source="sdk", + behavior=( + "When the request includes tools, the client accepts a callback result whose content is an " + "array including tool_use blocks." + ), + deferred=( + "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " + "cannot do (see sampling:create:tools)." + ), + ), + "sampling:tool-result:no-mixed-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", + behavior=( + "A user sampling message that carries tool_result content contains only tool_result blocks; " + "mixing tool_result with text, image, or audio content is rejected as invalid." + ), + ), + "sampling:tool-use:result-balance": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", + behavior=( + "In a sampling/createMessage request, every assistant tool_use block in messages MUST be " + "matched by a tool_result with the same toolUseId in the immediately-following user message; " + "an unmatched tool_use is rejected with -32602 Invalid params." + ), + divergence=Divergence( + note=( + "The client does not validate inbound tool_use/tool_result balance; the SDK enforces " + "the rule server-side instead, before the request leaves the server (see " + "sampling:tool-use:server-preflight)." + ), + ), + deferred=( + "Not implemented on the client receive path: validation runs only on the server send path " + "(pinned by sampling:tool-use:server-preflight)." + ), + ), + "sampling:tool-use:server-preflight": Requirement( + source="sdk", + behavior=( + "The server validates tool_use/tool_result balance before sending a sampling/createMessage " + "request; an unmatched tool_use raises ValueError and the request never reaches the wire." + ), + ), + "sampling:tools:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire (the SDK surfaces this as an Invalid params error)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Elicitation (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "elicitation:capability:empty-is-form": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior="A client advertising an empty elicitation capability accepts form-mode elicitation requests.", + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares explicit " + "form and url sub-capabilities, so an empty elicitation capability cannot be produced through " + "the public API." + ), + ), + "elicitation:capability:mode-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "The client answers elicitation requests for a mode it did not advertise with JSON-RPC error " + "-32602 (Invalid params)." + ), + deferred=( + "Not implemented in the SDK: a client cannot be configured form-only or url-only, so the " + "per-mode mismatch error cannot arise (see elicitation:url:not-supported)." + ), + ), + "elicitation:capability:server-respects-mode": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "The server refuses to send an elicitation request with a mode the connected client did not " + "declare in its capabilities." + ), + divergence=Divergence( + note=( + "The server does not check the client's declared elicitation modes before sending " + "elicitation/create; the spec's MUST NOT is not enforced." + ), + ), + ), + "elicitation:form:action:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A form-mode elicitation answered with action 'accept' returns the user's content to the " + "requesting handler." + ), + ), + "elicitation:form:action:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", + ), + "elicitation:form:action:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior=( + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." + ), + ), + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Optional default values declared in a form-mode requested schema are pre-populated into the " + "form presented to the user." + ), + deferred=( + "Not implemented in the SDK: there is no form-rendering layer that could pre-populate " + "defaults; client callbacks receive the requested schema as-is." + ), + ), + "elicitation:form:mode-omitted-default": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", + behavior="An elicitation request with no mode field is treated as form mode by the client.", + ), + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability is " + "answered with -32602 Invalid params." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32602.", + ), + ), + "elicitation:form:schema:enum-variants": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Requested-schema enum fields (including titled and multi-select variants) reach the client " + "callback as sent." + ), + ), + "elicitation:form:schema:primitives": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", + ), + "elicitation:form:schema:restricted-subset": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " + "structures and arrays of objects are not used." + ), + divergence=Divergence( + note=( + "ServerSession.elicit_form forwards an arbitrary dict[str, Any] schema unchanged; no shape " + "validation at the low-level session layer (the high-level Context.elicit / " + "elicit_with_validation helper enforces primitive-only fields before generating the schema)." + ), + ), + ), + "elicitation:form:response-validation": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", + behavior=( + "Accepted form-mode content is validated against the requested schema: the client validates " + "the response before sending and the server validates the content it receives." + ), + divergence=Divergence( + note=( + "The client never validates outbound content; ServerSession.elicit_form returns received " + "content unvalidated (the high-level Context.elicit / elicit_with_validation helper " + "validates server-side, but the low-level session API does not)." + ), + ), + ), + "elicitation:url:action:accept-no-content": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." + ), + ), + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." + ), + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:url:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." + ), + ), + "elicitation:url:complete-unknown-ignored": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "The client ignores an elicitation/complete notification referencing an unknown or " + "already-completed elicitationId without error." + ), + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares both the " + "form and url sub-capabilities, so a form-only client cannot be constructed." + ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + "roots:list-changed:client-emits": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior=( + "A client that declared roots.listChanged sends notifications/roots/list_changed when its set " + "of roots changes." + ), + deferred=( + "Not implemented in the SDK: the client does not own the root set (it calls back to the host " + "via list_roots_callback), so there is no mutation it could observe to auto-emit on; the SDK " + "provides send_roots_list_changed() for the host to call when its roots change, and that " + "emission path is covered by roots:list-changed." + ), + ), + "roots:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." + ), + ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior=( + "A roots/list request to a client that did not declare the roots capability is answered with " + "-32601 Method not found." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32601.", + ), + ), + "roots:uri:file-scheme": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root", + behavior="Every root returned by the client identifies itself with a file:// URI.", + deferred=( + "Schema-level validation: the FileUrl type on Root.uri rejects any non-file:// scheme at " + "construction and at parse, so a non-conforming root cannot reach the wire from either side; " + "type-level coverage belongs in tests/test_types.py rather than this interaction suite." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # list_changed & dynamic registration + # ═══════════════════════════════════════════════════════════════════════════ + "client:list-changed:auto-refresh": Requirement( + source="sdk", + behavior=( + "A client configured to react to list_changed notifications automatically re-fetches the " + "corresponding list and delivers the fresh result to its callback." + ), + deferred=( + "Not implemented in the SDK: the client has no list-changed auto-refresh mechanism; " + "notifications are only delivered to the message handler." + ), + ), + "client:list-changed:capability-gated": Requirement( + source="sdk", + behavior=( + "The client does not activate list-changed handling for a kind the server did not advertise " + "with listChanged true." + ), + deferred="Not implemented in the SDK: no client-side list-changed handling exists to gate.", + ), + "client:list-changed:signal-only": Requirement( + source="sdk", + behavior="A client configured for signal-only list-changed handling is notified without auto-refreshing.", + deferred="Not implemented in the SDK: no client-side list-changed handling exists.", + ), + "mcpserver:list-changed:debounce": Requirement( + source="sdk", + behavior=( + "Bursts of registration changes on MCPServer are debounced into one list_changed notification per kind." + ), + deferred=( + "Not implemented in the SDK: MCPServer does not send list_changed notifications on " + "registration changes at all (see mcpserver:register:post-connect), so there is nothing to " + "debounce." + ), + ), + "mcpserver:register:post-connect": Requirement( + source="sdk", + behavior=( + "A tool, resource, or prompt registered or removed after the client connected appears in (or " + "disappears from) the corresponding list results, and the change is announced with a " + "list_changed notification." + ), + divergence=Divergence( + note=( + "MCPServer never sends list_changed notifications on registration changes, so a connected " + "client cannot learn that the set changed without polling." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." + ), + ), + "pagination:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", + behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", + ), + "pagination:client:cursor-handling": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", + behavior=( + "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " + "and does not assume a fixed page size." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:auth:context-isolation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", + behavior=( + "When an authorization context is available, task operations are scoped to the context that " + "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " + "tasks/list." + ), + transports=("streamable-http",), + deferred=_TASKS_DEFERRAL, + ), + "tasks:bidirectional": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", + behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:no-handler-abort": Requirement( + source="sdk", + behavior=( + "tasks/cancel marks the task cancelled without aborting the originating request handler " + "(the spec says receivers SHOULD attempt to stop execution)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:remains-cancelled": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior=( + "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " + "completes or fails." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:terminal-rejected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:ttl-honored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", + behavior=( + "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " + "the create-task result carries the same value." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:via-tool-call": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", + behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:get": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", + behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:initial-working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", + behavior="A newly created task has status 'working'.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:input-required": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "While a task awaits a side-channel client response its status is input_required; once the " + "response arrives the task leaves input_required (typically returning to working)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", + behavior="tasks/list returns created tasks and supports cursor pagination.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:no-capability:ignore-task-param": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", + behavior=( + "A receiver that did not declare task capability for a request type processes the request " + "normally and returns the ordinary result, ignoring the task augmentation." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:progress:after-create": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", + behavior=( + "After the create-task result, progress notifications keyed to the original progress token " + "continue to reach the caller until the task is terminal." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:request-cancel:no-task-cancel": Requirement( + source="sdk", + behavior="A cancellation notification for the originating request does not auto-cancel the created task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:failed": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", + behavior="tasks/result for a failed task returns the failure result (isError true).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:related-task-meta": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="The tasks/result response carries related-task _meta naming the requested task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:terminal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", + behavior="tasks/result for a completed task returns the stored result of the original request type.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drain-fifo": Requirement( + source="sdk", + behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drop-on-cancel": Requirement( + source="sdk", + behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:elicitation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " + "client's response routes back to the handler." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:queue": Requirement( + source="sdk", + behavior=( + "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:sampling": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " + "the client's response routes back to the task." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:stream": Requirement( + source="sdk", + behavior=( + "Calling tasks/result while the task is working streams related-task messages as they are " + "produced, then returns the result." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:status-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", + behavior="Task status notifications deliver status updates carrying the full task fields.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:forbidden-with-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=( + "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:required-no-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), + deferred=_TASKS_DEFERRAL, + ), + "tasks:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Transports (in-suite coverage) + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source="sdk", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream and delivered to the client listening on it, not to any request's " + "own stream." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", + transports=("streamable-http",), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." + ), + transports=("sse",), + ), + "transport:sse:endpoint-event": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior="Opening the SSE stream delivers an `endpoint` event naming the message-POST URL as the first event.", + transports=("sse",), + ), + "transport:sse:post:session-routing": Requirement( + source="sdk", + behavior=( + "The endpoint URL carries a fresh session identifier; the server registers the session before " + "the endpoint event is sent and releases it when the stream disconnects, and a POST that names " + "no session id, a malformed session id, or an unknown session id is rejected (400/400/404)." + ), + transports=("sse",), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "A Client connected to a real SDK Server over stdio initializes, calls a tool with arguments, " + "and receives notifications and results over the child process's stdin/stdout." + ), + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: session lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:session:cors-expose": Requirement( + source="sdk", + behavior="CORS configuration exposes the Mcp-Session-Id header so browser clients can read it.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: CORS configuration is left to the hosting ASGI application.", + ), + "hosting:session:create": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "An initialize POST without a session id creates a session and returns Mcp-Session-Id in the " + "response headers." + ), + transports=("streamable-http",), + ), + "hosting:session:delete": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="DELETE with a valid Mcp-Session-Id terminates the session.", + transports=("streamable-http",), + ), + "hosting:session:id-charset": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", + transports=("streamable-http",), + ), + "hosting:session:isolation": Requirement( + source="sdk", + behavior="Each session gets its own server instance; closing one session does not affect others.", + transports=("streamable-http",), + ), + "hosting:session:missing-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", + transports=("streamable-http",), + ), + "hosting:session:post-termination-404": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "After a session is terminated, any further request carrying that session ID is answered with " + "404 Not Found." + ), + transports=("streamable-http",), + ), + "hosting:session:reinitialize": Requirement( + source="sdk", + behavior="A second initialize on an already-initialized session transport is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport forwards a second initialize carrying the existing session ID to the running " + "server, which answers it as a fresh handshake; nothing rejects re-initialization." + ), + ), + ), + "hosting:session:reuse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", + transports=("streamable-http",), + ), + "hosting:session:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", + transports=("streamable-http",), + ), + "hosting:stateless:concurrent-clients": Requirement( + source="sdk", + behavior="Multiple independent clients can connect to a stateless server concurrently.", + transports=("streamable-http",), + ), + "hosting:stateless:no-reuse": Requirement( + source="sdk", + behavior="A stateless per-request transport cannot be reused for a second request.", + transports=("streamable-http",), + ), + "hosting:stateless:no-session-id": Requirement( + source="sdk", + behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: auth + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:auth:as-router": Requirement( + source="sdk", + behavior=( + "The authorization-server routes expose the authorize, token, and registration endpoints " + "(and revocation when supported)." + ), + transports=("streamable-http",), + ), + "hosting:auth:aud-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", + behavior="The resource server validates that the token audience matches its resource identifier.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "BearerAuthBackend never inspects AccessToken.resource; a token issued for a different " + "resource is accepted. Spec MUST." + ), + ), + ), + "hosting:auth:authinfo-propagates": Requirement( + source="sdk", + behavior="A valid token's auth info is exposed to request handlers.", + transports=("streamable-http",), + ), + "hosting:auth:expired-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="An expired token returns 401 invalid_token.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:invalid-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:metadata-endpoints": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The MCP server publishes protected-resource metadata at its well-known endpoint, and the " + "authorization server (which the SDK can also host) publishes authorization-server metadata " + "at its own." + ), + transports=("streamable-http",), + ), + "hosting:auth:missing-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "A request without an Authorization header is rejected with 401; the WWW-Authenticate header " + "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK never emits a `scope` parameter in any WWW-Authenticate challenge — neither the " + "discovery-time 401 (#protected-resource-metadata-discovery-requirements SHOULD) nor the " + "runtime 403 (#runtime-insufficient-scope-errors SHOULD); and for the no-credentials case " + 'it emits error="invalid_token", which RFC 6750 Section 3.1 says SHOULD NOT appear when no ' + "authentication information was presented." + ), + ), + ), + "hosting:auth:prm:authorization-servers-field": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The protected-resource metadata document includes an authorization_servers array with at least one entry." + ), + transports=("streamable-http",), + ), + "hosting:auth:query-token-ignored": Requirement( + source="sdk", + behavior=( + "An access token presented in the URI query string is not accepted; the request is treated as " + "unauthenticated." + ), + transports=("streamable-http",), + ), + "hosting:auth:scope-403": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", + behavior=( + "A token lacking a required scope returns 403 with WWW-Authenticate carrying " + "insufficient_scope, the required scope, and resource_metadata." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + 'The SDK emits error="insufficient_scope" and error_description but never the `scope` ' + "parameter the spec SHOULD include; the SDK client reads `scope` from this header to drive " + "step-up (utils.py extract_scope_from_www_auth) — a resource-server/client asymmetry." + ), + ), + ), + "hosting:auth:as:authorize-requires-pkce": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled authorization endpoint rejects an authorize request that omits " + "`code_challenge` with `invalid_request`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:verifier-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `code_verifier` " + "does not hash to the stored `code_challenge` with `invalid_grant`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:code-single-use": Requirement( + source="sdk", + behavior=( + "An authorization code can be exchanged exactly once; a second exchange of the same code " + "is rejected with `invalid_grant`. Enforced by the provider deleting the code on first use; " + "the handler relies on `load_authorization_code` returning None." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:redirect-uri-binding": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `redirect_uri` " + "differs from the one used at authorize; the bundled authorize endpoint rejects a " + "`redirect_uri` not in the client's registered list without redirecting to it." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "RFC 6749 §5.2 assigns redirect_uri mismatch at the token endpoint to invalid_grant; " + "the SDK's TokenHandler returns invalid_request (src/mcp/server/auth/handlers/token.py:157). " + "The rejection itself is the security-relevant property and is correct." + ), + ), + ), + "hosting:auth:as:redirect-uri-scheme": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#communication-security", + behavior=( + "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback host." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "Not enforced: the registration handler models redirect_uris as AnyUrl with no scheme or " + "host check, so http://evil.example/callback is accepted and registered. The spec's " + "localhost-or-HTTPS rule is left to the provider implementation." + ), + ), + ), + "hosting:auth:as:token-cache-headers": Requirement( + source="sdk", + behavior=("Every token-endpoint response carries `Cache-Control: no-store` and `Pragma: no-cache`."), + transports=("streamable-http",), + ), + "hosting:auth:as:register-error-response": Requirement( + source="sdk", + behavior=( + "The bundled registration endpoint answers invalid client metadata with HTTP 400 and an " + "RFC 7591 error body." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: resumability + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:resume:bad-event-id": Requirement( + source="sdk", + behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The replay path returns an empty SSE stream rather than rejecting an unknown " + "Last-Event-ID; the client cannot tell an unknown ID apart from a stream with no missed " + "events." + ), + ), + ), + "hosting:resume:buffered-replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", + transports=("streamable-http",), + ), + "hosting:resume:close-stream": Requirement( + source="sdk", + behavior="Handlers can close an SSE stream cleanly when an event store is configured.", + transports=("streamable-http",), + ), + "hosting:resume:event-ids": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="With an event store configured, every SSE event carries an id field.", + transports=("streamable-http",), + ), + "hosting:resume:priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A server-initiated SSE stream begins with a priming event carrying an event ID and an empty " + "data field; a server that closes the connection before terminating the stream sends an SSE " + "retry field first." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The retry hint is attached to the priming event itself rather than sent as a separate " + "event before the connection closes, and a priming event is only sent when an event store " + "is configured and the negotiated protocol version is at least 2025-11-25." + ), + ), + ), + "hosting:resume:replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", + transports=("streamable-http",), + ), + "hosting:resume:stream-scoped": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: HTTP semantics + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:http:accept-406": Requirement( + source="sdk", + behavior="A request whose Accept header does not allow the response representation returns 406.", + transports=("streamable-http",), + ), + "hosting:http:batch": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body is a single JSON-RPC message; batched arrays are rejected for protocol revisions " + "that forbid them." + ), + transports=("streamable-http",), + ), + "hosting:http:content-type-415": Requirement( + source="sdk", + behavior="A POST with a Content-Type other than application/json returns 415.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport-security middleware rejects a non-JSON Content-Type with 400 'Invalid " + "Content-Type header' before the request reaches the transport, so the transport's own 415 " + "path is unreachable through any public entry point." + ), + ), + ), + "hosting:http:disconnect-not-cancel": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A client connection drop during an in-flight request does not cancel the server-side " + "handler; the request continues and its result remains retrievable." + ), + transports=("streamable-http",), + ), + "hosting:http:dns-rebinding": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior=( + "The Origin header is validated on every incoming connection; a request with an invalid " + "Origin is rejected with 403 Forbidden." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The spec's Origin validation is an unconditional MUST; the SDK enables it only when the " + "host is a localhost address or explicit TransportSecuritySettings are passed (with no " + "settings, no Origin validation runs), and additionally validates the Host header " + "(returning 421 on mismatch), which the spec does not require." + ), + ), + ), + "hosting:http:json-response-mode": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", + transports=("streamable-http",), + ), + "hosting:http:method-405": Requirement( + source="sdk", + behavior="An unsupported HTTP method on the MCP endpoint returns 405.", + transports=("streamable-http",), + ), + "hosting:http:no-broadcast": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", + behavior=( + "When multiple SSE streams are open for a session, each server-originated message is sent on " + "exactly one stream, never duplicated." + ), + transports=("streamable-http",), + ), + "hosting:http:notifications-202": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A POST containing only notifications or responses returns 202 with no body.", + transports=("streamable-http",), + ), + "hosting:http:onerror": Requirement( + source="sdk", + behavior="Transport-level rejections are reported through an error callback on the server transport.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: the server transport has no error callback; rejections are logged.", + ), + "hosting:http:parse-error-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body that is not valid JSON or not a valid JSON-RPC message is rejected with HTTP 400; " + "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." + ), + transports=("streamable-http",), + ), + "hosting:http:protocol-version-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", + transports=("streamable-http",), + ), + "hosting:http:protocol-version-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "When no MCP-Protocol-Version header is received and the version cannot be determined another " + "way, the server assumes protocol version 2025-03-26." + ), + transports=("streamable-http",), + ), + "hosting:http:response-same-connection": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A response is delivered on the SSE stream opened by the POST that carried its request (or " + "that stream's resumed continuation), not on an unrelated stream." + ), + transports=("streamable-http",), + ), + "hosting:http:second-sse-rejected": Requirement( + source="sdk", + behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", + transports=("streamable-http",), + ), + "hosting:http:sse-close-after-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="GET opens a standalone SSE stream that receives server-initiated messages.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse-no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior=( + "The standalone GET SSE stream carries server requests and notifications but never a JSON-RPC " + "response, except when resuming a prior request stream." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client transport: streamable HTTP + # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:404-surfaces": Requirement( + source="sdk", + behavior="A 404 (session expired) on a request surfaces as an error to the caller.", + transports=("streamable-http",), + ), + "client-transport:http:session-404-reinitialize": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "A 404 in response to a request carrying a session ID makes the client start a new session " + "with a fresh InitializeRequest and no session ID attached." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client surfaces the 404 as an error to the caller instead of re-initializing a new " + "session; the spec's MUST is not satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: the client surfaces a Session terminated error instead of " + "re-initializing (the surfaced error is pinned by client-transport:http:404-surfaces)." + ), + ), + "client-transport:http:accept-header-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", + transports=("streamable-http",), + ), + "client-transport:http:accept-header-post": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "Every client POST to the MCP endpoint includes an Accept header listing both application/json " + "and text/event-stream." + ), + transports=("streamable-http",), + ), + "client-transport:http:concurrent-streams": Requirement( + source="sdk", + behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", + transports=("streamable-http",), + ), + "client-transport:http:custom-client": Requirement( + source="sdk", + behavior=( + "A caller-supplied HTTP client (and its event hooks and headers) is used for all MCP traffic, " + "including auth flows." + ), + transports=("streamable-http",), + ), + "client-transport:http:custom-headers": Requirement( + source="sdk", + behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", + transports=("streamable-http",), + ), + "client-transport:http:json-response-parsed": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-close": Requirement( + source="sdk", + behavior="After the transport is closed, no further reconnection attempts are scheduled.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-response": Requirement( + source="sdk", + behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-header": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "After initialization, the client sends the negotiated MCP-Protocol-Version header on every " + "subsequent HTTP request." + ), + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-stored": Requirement( + source="sdk", + behavior=( + "The client transport stores the negotiated protocol version and sends it on every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." + ), + transports=("streamable-http",), + deferred=( + "The server's standalone GET stream emits no priming event or retry hint, so the client's " + "reconnection path always sleeps the hard-coded 1 s default; a deterministic in-process test " + "would require accepting that real-time wait. The POST-stream reconnection path is covered " + "by client-transport:http:reconnect-post-priming." + ), + ), + "client-transport:http:reconnect-post-priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST-initiated SSE stream that errors before delivering its response is reconnected only " + "if a priming event (an event carrying an ID) was received on it." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-retry-value": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", + transports=("streamable-http",), + ), + "client-transport:http:resume-stream-api": Requirement( + source="sdk", + behavior=( + "The client can capture a resumption token, reconnect with the same session id, and receive " + "the notifications it missed." + ), + transports=("streamable-http",), + ), + "client-transport:http:session-stored": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "The Mcp-Session-Id returned by initialize is stored by the client transport and sent on " + "every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:sse-405-tolerated": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", + transports=("streamable-http",), + ), + "client-transport:http:terminate-405-ok": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client auth + # ═══════════════════════════════════════════════════════════════════════════ + "client-auth:401-after-auth-throws": Requirement( + source="sdk", + behavior=( + "If the server still returns 401 after a successful authorization, the client fails instead of looping." + ), + transports=("streamable-http",), + ), + "client-auth:401-triggers-flow": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior="A 401 on a request triggers the OAuth authorization flow once.", + transports=("streamable-http",), + ), + "client-auth:403-scope-upgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", + behavior=( + "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." + ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:priority-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client discovers authorization-server metadata by trying, in order, the OAuth " + "path-inserted, OIDC path-inserted, and OIDC path-appended well-known URLs (with the " + "root-path forms when the issuer URL has no path)." + ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:issuer-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client rejects authorization-server metadata whose issuer does not match the URL the " + "metadata was retrieved from (RFC 8414 section 3.3)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK parses authorization-server metadata without comparing issuer to the discovery " + "URL; a mismatched issuer is accepted and the flow proceeds." + ), + ), + ), + "client-auth:authorize:error-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "An OAuth error redirect from the authorize endpoint aborts the flow before any token " + "request is issued, surfacing as an error to the caller." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The callback contract has no error form, so the client surfaces 'No authorization code " + "received' rather than the redirect's `error`/`error_description` values." + ), + ), + ), + "client-auth:authorize:offline-access-consent": Requirement( + source="sdk", + behavior=( + "When the authorization server's metadata advertises offline_access in scopes_supported and " + "the client uses the refresh_token grant, offline_access is appended to the requested scope " + "and prompt=consent is added to the authorize request." + ), + transports=("streamable-http",), + ), + "client-auth:bearer-header:every-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", + behavior=( + "Once authorized, the client sends the bearer token in the Authorization header on every HTTP " + "request to the MCP server, never in the query string." + ), + transports=("streamable-http",), + ), + "client-auth:cimd": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", + behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", + transports=("streamable-http",), + ), + "client-auth:client-credentials": Requirement( + source="sdk", + behavior=( + "A client-credentials provider obtains a token without user interaction and the resulting " + "bearer token authorizes subsequent requests." + ), + transports=("streamable-http",), + ), + "client-auth:dcr:registration-error-surfaces": Requirement( + source="sdk", + behavior=( + "A 400 from the registration endpoint surfaces to the caller as an OAuthRegistrationError " + "carrying the status and the server's RFC 7591 error body." + ), + transports=("streamable-http",), + ), + "client-auth:dcr": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", + behavior=( + "The client performs dynamic client registration against the authorization server when no " + "client_id is preconfigured." + ), + transports=("streamable-http",), + ), + "client-auth:invalid-client-clears-all": Requirement( + source="sdk", + behavior=( + "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The token-response handlers do not parse the error body; an invalid_client or " + "unauthorized_client response leaves stored client_info untouched. The TypeScript SDK " + "clears it." + ), + ), + deferred=( + "Not implemented in the SDK: no token-response path inspects the error code to decide " + "whether to clear client_info." + ), + ), + "client-auth:invalid-grant-clears-tokens": Requirement( + source="sdk", + behavior="An invalid-grant error during authorization invalidates only the stored tokens.", + transports=("streamable-http",), + ), + "client-auth:pkce:refuse-if-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The client refuses to proceed when the authorization server's metadata does not include " + "code_challenge_methods_supported, since PKCE support cannot be verified." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client never inspects code_challenge_methods_supported and proceeds with PKCE S256 " + "regardless; the spec MUST is not enforced." + ), + ), + ), + "client-auth:pkce:s256": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The authorization request includes a PKCE S256 code challenge and the token request includes " + "the matching verifier." + ), + transports=("streamable-http",), + ), + "client-auth:pre-registration": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", + behavior=( + "A client with statically preconfigured credentials skips dynamic registration and uses them directly." + ), + transports=("streamable-http",), + ), + "client-auth:private-key-jwt": Requirement( + source="sdk", + behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", + transports=("streamable-http",), + ), + "client-auth:prm-discovery:fallback-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "The client uses resource_metadata from WWW-Authenticate when present, then falls back to the " + "well-known protected-resource locations in the documented order." + ), + transports=("streamable-http",), + ), + "client-auth:prm-discovery:no-prm-fallback": Requirement( + source="sdk", + behavior=( + "When every protected-resource metadata probe fails, the client falls back to discovering " + "authorization-server metadata directly at the MCP server's origin (the legacy 2025-03-26 path) " + "rather than aborting." + ), + transports=("streamable-http",), + ), + "client-auth:prm-resource-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The client refuses to proceed when the protected-resource metadata's resource field does not " + "match the server URL it is connecting to." + ), + transports=("streamable-http",), + ), + "client-auth:refresh:transparent": Requirement( + source="sdk", + behavior=( + "An access token the client considers expired is transparently refreshed before the next " + "request, using the stored refresh token; the refresh request includes the resource indicator " + "and the new token is persisted." + ), + transports=("streamable-http",), + ), + "client-auth:resource-parameter": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", + behavior=( + "The client includes the canonical server URI as the resource parameter in both the " + "authorization request and the token request." + ), + transports=("streamable-http",), + ), + "client-auth:scope-selection:priority": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", + behavior=( + "Client selects requested scope from the WWW-Authenticate scope param if present; otherwise " + "uses scopes_supported from the PRM document; otherwise omits scope." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK inserts an extra fallback step between PRM and omit: if the authorization " + "server metadata advertises scopes_supported, that list is used (client/auth/utils.py). " + "This is beyond the spec's two-step chain." + ), + ), + ), + "client-auth:state:verify": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "A state parameter is included in the authorization URL, and authorization results with a " + "missing or mismatched state are discarded." + ), + transports=("streamable-http",), + ), + "client-auth:token-endpoint-auth-method": Requirement( + source="sdk", + behavior="The client authenticates to the token endpoint using the auth method established at registration.", + transports=("streamable-http",), + ), + "client-auth:token-provenance": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior=( + "The client sends the MCP server only tokens issued by that server's authorization server, " + "never tokens obtained elsewhere." + ), + transports=("streamable-http",), + deferred=( + "Untestable negative through the public API: there is no path to inject a token obtained " + "elsewhere into the auth provider's state, so the absence cannot be observed end to end." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # stdio transport + # ═══════════════════════════════════════════════════════════════════════════ + "transport:stdio:clean-shutdown": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", + behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", + transports=("stdio",), + ), + "transport:stdio:stream-purity": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "Nothing that is not a valid MCP message is written to the server's stdout, and nothing that " + "is not a valid MCP message is written to its stdin." + ), + transports=("stdio",), + divergence=Divergence( + note=( + "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " + "handler code that calls print() writes directly to the protocol stream and corrupts the " + "framing. The spec MUST is satisfied only as long as application code behaves." + ), + ), + ), + "transport:stdio:no-embedded-newlines": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", + transports=("stdio",), + ), + "transport:stdio:shutdown-escalation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", + behavior=( + "If the server process does not exit after stdin is closed, the client transport terminates " + "it (and kills it if still alive) after a grace period." + ), + transports=("stdio",), + deferred=( + "A server that ignores stdin close takes the full PROCESS_TERMINATION_TIMEOUT (2.0 s) grace " + "period plus up to a further 2.0 s for SIGTERM/SIGKILL escalation; testing that path is " + "real-time-bound (the constant is module-level with no public override) and so is deliberately " + "excluded from this suite. Covered by tests/client/test_stdio.py." + ), + ), + "transport:stdio:stderr-passthrough": Requirement( + source="sdk", + behavior="Server stderr is available to the client and is not consumed by the transport.", + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Composite end-to-end flows + # ═══════════════════════════════════════════════════════════════════════════ + "flow:compat:dual-transport-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A single server instance can serve streamable HTTP and the legacy SSE transport " + "concurrently; clients on either transport can call the same tools." + ), + transports=("streamable-http", "sse"), + ), + "flow:compat:streamable-then-sse-fallback": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "When a streamable HTTP initialize fails with 400, 404, or 405, falling back to the legacy " + "SSE client transport against the same server connects successfully." + ), + transports=("streamable-http", "sse"), + divergence=Divergence( + note=( + "The SDK provides no automatic streamable-HTTP-to-SSE client fallback; the spec's " + "client-side SHOULD is left to the application to compose from streamable_http_client " + "and sse_client. Both halves are independently proven by the matrix." + ), + ), + deferred=( + "A demonstration test would only re-prove what the matrix already covers (an SSE-only " + "server is reachable via sse_client; an unmounted route returns 404), with the application " + "doing the fallback in between rather than the SDK." + ), + ), + "flow:elicitation:multi-step-form": Requirement( + source="sdk", + behavior=( + "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " + "and a decline or cancel at any step short-circuits to a final result." + ), + ), + "flow:elicitation:url-at-session-init": Requirement( + source="sdk", + behavior=( + "The server can issue a URL-mode elicitation over the standalone GET stream immediately after " + "session initialization, before any client request." + ), + transports=("streamable-http",), + deferred=( + "Not implemented in the SDK: no public per-session post-initialization hook exists on either " + "server flavour (Server.lifespan runs at server startup, not per session; ServerSession " + "handles the initialized notification internally with no callback). Driving 'before any " + "client request' deterministically would also require knowing the standalone GET stream is " + "established, which has no synchronization signal." + ), + ), + "flow:elicitation:url-required-then-retry": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A tool call rejected with the URL-elicitation-required error can be retried successfully " + "after the client completes the URL flow and the server announces completion." + ), + ), + "flow:multi-client:stateful-isolation": Requirement( + source="sdk", + behavior=( + "Independent clients connected to one stateful server each receive a distinct session and " + "only the notifications produced by their own requests." + ), + transports=("streamable-http",), + ), + "flow:oauth:authorization-code-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "Connecting to a protected server walks the authorization-code flow end to end: the first " + "attempt requires authorization, the code is exchanged, and a subsequent connection succeeds." + ), + transports=("streamable-http",), + ), + "flow:resume:tool-call-resumption-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A tool call interrupted mid-stream is transparently resumed by the client transport using " + "the last-seen event id, delivering only the remaining notifications and the final result." + ), + transports=("streamable-http",), + ), + "flow:session:terminate-then-reconnect": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=("After terminating a session, a fresh connection obtains a new session id and operations succeed."), + transports=("streamable-http",), + ), + "flow:tool-result:resource-link-follow": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior=( + "A resource_link returned by a tool call can be followed with resources/read on the linked " + "URI to retrieve the referenced contents." + ), + ), +} + + +def requirement(requirement_id: str) -> Callable[[_TestFn], _TestFn]: + """Mark a test as exercising a requirement from :data:`REQUIREMENTS`. + + Applies the `requirement` pytest marker and records the coverage link checked by + `test_coverage.py`. Unknown IDs fail at import time so a typo surfaces as a collection + error on the offending test, not as a missing-coverage report later. + """ + if requirement_id not in REQUIREMENTS: + raise KeyError(f"Unknown requirement id {requirement_id!r}: add it to REQUIREMENTS in {__name__}") + + def apply(test_fn: _TestFn) -> _TestFn: + covered_by(requirement_id).append(f"{test_fn.__module__}.{test_fn.__qualname__}") + return pytest.mark.requirement(requirement_id)(test_fn) + + return apply + + +_COVERAGE: dict[str, list[str]] = {} + + +def covered_by(requirement_id: str) -> list[str]: + """Return the (mutable) list of test names recorded as exercising `requirement_id`.""" + return _COVERAGE.setdefault(requirement_id, []) diff --git a/tests/interaction/auth/__init__.py b/tests/interaction/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py new file mode 100644 index 0000000000..d013364f33 --- /dev/null +++ b/tests/interaction/auth/_harness.py @@ -0,0 +1,465 @@ +"""In-process harness for the auth interaction tests. + +Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the +bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=..., +token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge +on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the +authorize redirect headlessly by GETing the URL through the same bridge and parsing the code +from the 302 `Location`. The whole authorization-code flow runs in one event loop with no +sockets, no threads, and no real time. +""" + +import json +from collections.abc import AsyncIterator, Callable, Mapping, Sequence +from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import parse_qs, parse_qsl, urlsplit + +import httpx +from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from starlette.types import ASGIApp, Receive, Scope, Send + +from mcp.client.auth import OAuthClientProvider +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +REDIRECT_URI = f"{BASE_URL}/oauth/callback" + +AppShim = Callable[[ASGIApp], ASGIApp] + + +@dataclass +class RecordedRequest: + """A snapshot of an `httpx.Request` at the moment it was sent. + + The auth flow re-yields the same `httpx.Request` object after mutating its headers in + place for the retry, so tests that need to assert on the first attempt's headers must + capture a copy rather than a live reference. `record_requests` produces these. + """ + + method: str + url: httpx.URL + headers: dict[str, str] + content: bytes + + @property + def path(self) -> str: + return self.url.path + + +def record_requests() -> tuple[list[RecordedRequest], Callable[[httpx.Request], None]]: + """Build an `on_request` callback that snapshots each request, and the list it appends to.""" + recorded: list[RecordedRequest] = [] + + def on_request(request: httpx.Request) -> None: + recorded.append( + RecordedRequest( + method=request.method, + url=request.url, + headers=dict(request.headers), + content=bytes(request.content), + ) + ) + + return recorded, on_request + + +def metadata_body(model: BaseModel, **extra: object) -> bytes: + """Serialize a metadata model to a JSON body for `shimmed_app(serve=...)`. + + `extra` keys are merged into the serialized object so a test can inject fields the model + does not declare (e.g. an unknown extension field, to prove the client's parser tolerates + unrecognized members per RFC 8414/9728 §3.2). The model itself would silently drop such + fields at construction, so they have to be added after serialization. + """ + document = model.model_dump(by_alias=True, mode="json", exclude_none=True) + document.update(extra) + return json.dumps(document).encode() + + +class StaticTokenVerifier: + """A `TokenVerifier` backed by a fixed token→`AccessToken` mapping. + + Any token string not in the mapping verifies to `None`, which the bearer middleware treats + as an unrecognized token. Tests seed the mapping with the exact token shapes (valid, expired, + wrong scope, wrong audience) they need so the resource-server gate's behaviour is asserted in + isolation from the authorization-server provider. + """ + + def __init__(self, tokens: Mapping[str, AccessToken]) -> None: + self._tokens = dict(tokens) + + async def verify_token(self, token: str) -> AccessToken | None: + return self._tokens.get(token) + + +class InMemoryTokenStorage: + """A `TokenStorage` that holds tokens and client info as instance attributes. + + Tests pre-seed `client_info` (via the constructor or by assignment) to drive the + pre-registered path, and read both attributes after the flow to assert what the SDK + persisted. + """ + + def __init__(self, *, client_info: OAuthClientInformationFull | None = None) -> None: + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = client_info + + async def get_tokens(self) -> OAuthToken | None: + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self.client_info = client_info + + +class HeadlessOAuth: + """Completes the authorize step in-process by following the redirect through the bridge. + + `redirect_handler` GETs the authorize URL on the bound client (with `auth=None` so the + request does not re-enter the locked auth flow), parses `code` and `state` from the 302 + `Location`, and stashes them; `callback_handler` returns the stashed pair. Tests inspect + `authorize_url` to assert what the SDK put on the authorize request. + + `state_override`: when set, `callback_handler` returns this value as the state instead of + the one parsed from the redirect, so tests can drive the state-mismatch path. + """ + + def __init__(self, *, state_override: str | None = None) -> None: + self.authorize_url: str | None = None + self.authorize_urls: list[str] = [] + self.error: str | None = None + self._state_override = state_override + self._http: httpx.AsyncClient | None = None + self._code: str = "" + self._state: str | None = None + + def bind(self, http_client: httpx.AsyncClient) -> None: + self._http = http_client + + async def redirect_handler(self, authorization_url: str) -> None: + assert self._http is not None + self.authorize_url = authorization_url + self.authorize_urls.append(authorization_url) + # auth=None is load-bearing: without it the GET re-enters OAuthClientProvider.async_auth_flow + # through its context lock and the flow deadlocks. + response = await self._http.get(authorization_url, follow_redirects=False, auth=None) + assert response.status_code == 302, f"authorize endpoint returned {response.status_code}: {response.text}" + params = parse_qs(urlsplit(response.headers["location"]).query) + self._code = params.get("code", [""])[0] + self._state = params.get("state", [None])[0] + self.error = params.get("error", [None])[0] + + async def callback_handler(self) -> tuple[str, str | None]: + return self._code, self._state_override if self._state_override is not None else self._state + + +def auth_settings( + *, required_scopes: Sequence[str] = ("mcp",), valid_scopes: Sequence[str] | None = None +) -> AuthSettings: + """Build `AuthSettings` for the co-hosted authorization + resource server. + + The issuer and resource URLs use the suite's loopback origin, which `validate_issuer_url` + accepts in lieu of HTTPS. Dynamic client registration is enabled. `valid_scopes` defaults + to `required_scopes` so a client requesting exactly those passes registration scope + validation; tests pass a wider set when they need the protected-resource metadata's + `scopes_supported` (which mirrors `required_scopes`) to differ from what the client may + register or when AS metadata should advertise additional scopes such as `offline_access`. + """ + required = list(required_scopes) + valid = list(valid_scopes) if valid_scopes is not None else required + return AuthSettings( + issuer_url=AnyHttpUrl(BASE_URL), + resource_server_url=AnyHttpUrl(f"{BASE_URL}/mcp"), + required_scopes=required, + client_registration_options=ClientRegistrationOptions( + enabled=True, valid_scopes=valid, default_scopes=required + ), + revocation_options=RevocationOptions(enabled=False), + ) + + +def oauth_client_metadata() -> OAuthClientMetadata: + """Build the client's registration metadata. + + `scope` is left unset so the SDK's scope-selection strategy chooses one from the server's + metadata before registration. + """ + return OAuthClientMetadata( + client_name="interaction-suite", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + ) + + +def shimmed_app( + app: ASGIApp, + *, + not_found: frozenset[str] = frozenset(), + serve: Mapping[str, bytes | tuple[int, bytes]] | None = None, +) -> ASGIApp: + """Wrap an ASGI app so specific paths return canned responses before reaching the real app. + + Paths in `serve` return the given body as `application/json` (status 200, or the supplied + status when the value is a `(status, body)` pair); paths in `not_found` return 404; + everything else reaches the wrapped app unchanged. Used by the discovery tests to make a + well-known endpoint 404 or return alternate metadata while keeping the real authorization + and MCP endpoints behind it. + """ + overrides: dict[str, tuple[int, bytes]] = { + path: value if isinstance(value, tuple) else (200, value) for path, value in (serve or {}).items() + } + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + path = scope["path"] + if path in overrides: + status, body = overrides[path] + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + return + if path in not_found: + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + +def shim( + *, not_found: frozenset[str] = frozenset(), serve: Mapping[str, bytes | tuple[int, bytes]] | None = None +) -> AppShim: + """Build an `app_shim` for `connect_with_oauth` that applies `shimmed_app` with these overrides.""" + return lambda app: shimmed_app(app, not_found=not_found, serve=serve) + + +@dataclass +class _FirstChallenge: + """ASGI shim that answers the first request to a path with 401 + a given WWW-Authenticate. + + Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry + parameters (such as `scope=`) that the SDK's own bearer middleware cannot be configured + to emit, so client behaviour driven by those parameters is reachable end to end. Reserve + this pattern for behaviour the real server cannot be made to produce. + """ + + app: ASGIApp + path: str + www_authenticate: str + _seen: set[str] = field(default_factory=set[str]) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == self.path and self.path not in self._seen: + self._seen.add(self.path) + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [(b"www-authenticate", self.www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await self.app(scope, receive, send) + + +def first_challenge_shim(www_authenticate: str, *, path: str = "/mcp") -> Callable[[ASGIApp], ASGIApp]: + """Build an `app_shim` that 401s the first request to `path` with the given header value.""" + return lambda app: _FirstChallenge(app, path, www_authenticate) + + +def step_up_shim(www_authenticate: str, *, on_nth_authenticated_post: int = 2) -> AppShim: + """Build an `app_shim` that 403s the Nth authenticated POST to `/mcp` with the given challenge. + + Subsequent requests pass through. Used to drive the client's `insufficient_scope` step-up + handling: the SDK's bearer middleware never emits `scope=` in its 403 challenge (see the + divergence on `hosting:auth:scope-403`), so the test supplies the 403 itself. Reserve this + pattern for behaviour the real server cannot be made to produce. + + The default `on_nth_authenticated_post=2` targets the `notifications/initialized` POST: the + first authenticated POST is the auth flow's retry of the original initialize request (yielded + after the 401 branch, where the generator ends without inspecting the response), so a 403 + there would not reach the step-up handler. + """ + seen = 0 + fired = False + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + nonlocal seen, fired + if ( + not fired + and scope["type"] == "http" + and scope["path"] == "/mcp" + and scope["method"] == "POST" + and any(name == b"authorization" for name, _ in scope["headers"]) + ): + seen += 1 + if seen < on_nth_authenticated_post: + await app(scope, receive, send) + return + fired = True + await send( + { + "type": "http.response.start", + "status": 403, + "headers": [(b"www-authenticate", www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +def m2m_token_shim(provider: InMemoryAuthorizationServerProvider, *, scopes: list[str]) -> AppShim: + """Build an `app_shim` that handles `grant_type=client_credentials` at `/token`. + + The SDK server's `TokenHandler` only routes `authorization_code` and `refresh_token`, so a + `client_credentials` request would fail discriminator validation. This shim mints a token via + `provider.mint_access_token` so the M2M client providers can complete e2e against the real + bearer middleware. The shim is harness; the SDK-under-test is the client provider, whose + outbound `/token` body the test asserts. The shim does not authenticate the client (no + credential check) because the test asserts the credentials on the recorded request, not on + the server's acceptance. + """ + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == "/token" and scope["method"] == "POST": + # The streaming bridge buffers the request body and delivers it in a single + # http.request event, so one receive is sufficient. + message = await receive() + assert not message.get("more_body", False) + form = dict(parse_qsl(message.get("body", b"").decode())) + assert form.get("grant_type") == "client_credentials", ( + f"m2m_token_shim only handles client_credentials; got {form.get('grant_type')!r}" + ) + access = provider.mint_access_token(client_id="m2m", scopes=scopes, resource=form.get("resource")) + token = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=" ".join(scopes)) + response_body = token.model_dump_json(exclude_none=True).encode() + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_body)).encode()), + (b"cache-control", b"no-store"), + ], + } + ) + await send({"type": "http.response.body", "body": response_body}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +@asynccontextmanager +async def connect_with_oauth( + server: Server, + *, + provider: InMemoryAuthorizationServerProvider, + settings: AuthSettings | None = None, + storage: InMemoryTokenStorage | None = None, + client_metadata: OAuthClientMetadata | None = None, + client_metadata_url: str | None = None, + headless: HeadlessOAuth | None = None, + auth: httpx.Auth | None = None, + verify_tokens: bool = True, + app_shim: Callable[[ASGIApp], ASGIApp] | None = None, + on_request: Callable[[httpx.Request], None] | None = None, +) -> AsyncIterator[tuple[Client, HeadlessOAuth]]: + """Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. + + Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the + SDK put on the authorize request. `on_request` records every HTTP request the underlying + `httpx.AsyncClient` issues, including those yielded from inside the auth flow. + + `headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour + (state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without + the bearer middleware so a flow driven by a shimmed 401 completes regardless of the granted + scopes. `app_shim` wraps the built Starlette app before it reaches the bridge transport, + for tests that need to intercept or rewrite specific server responses. + + `auth`: supply a pre-built `httpx.Auth` (such as `ClientCredentialsOAuthProvider`) to use + instead of constructing the default `OAuthClientProvider`; in that case `storage`, + `client_metadata`, `client_metadata_url`, and `headless` are unused (the yielded + `HeadlessOAuth` is never invoked and its `authorize_url` stays None). + """ + settings = settings if settings is not None else auth_settings() + storage = storage if storage is not None else InMemoryTokenStorage() + client_metadata = client_metadata if client_metadata is not None else oauth_client_metadata() + headless = headless if headless is not None else HeadlessOAuth() + + oauth = ( + auth + if auth is not None + else OAuthClientProvider( + server_url=f"{BASE_URL}/mcp", + client_metadata=client_metadata, + storage=storage, + redirect_handler=headless.redirect_handler, + callback_handler=headless.callback_handler, + client_metadata_url=client_metadata_url, + ) + ) + + app: ASGIApp = server.streamable_http_app( + auth=settings, + token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None, + auth_server_provider=provider, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + if app_shim is not None: + app = app_shim(app) + + event_hooks: dict[str, list[Callable[..., Any]]] | None = None + if on_request is not None: + record = on_request + + async def hook(request: httpx.Request) -> None: + record(request) + + event_hooks = {"request": [hook]} + + async with AsyncExitStack() as stack: + await stack.enter_async_context(server.session_manager.run()) + http_client = await stack.enter_async_context( + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + ) + ) + headless.bind(http_client) + client = await stack.enter_async_context( + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + ) + yield client, headless diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py new file mode 100644 index 0000000000..5c88995a30 --- /dev/null +++ b/tests/interaction/auth/_provider.py @@ -0,0 +1,186 @@ +"""An in-memory implementation of the SDK's OAuth authorization-server provider protocol. + +The provider holds clients, authorization codes, refresh tokens and access tokens in plain +instance dicts so tests can inspect them; tokens are minted from `secrets.token_hex` so the +values are unique without being predictable. The behaviour mirrors what the SDK's authorization +handlers expect: `authorize` immediately mints a code and returns the redirect, `exchange_*` +issue and rotate tokens, and `load_*` are simple lookups. Only the parts the auth interaction +suite drives are implemented; methods the suite does not exercise raise `NotImplementedError`. +""" + +import secrets +import time + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + TokenError, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +_TOKEN_LIFETIME_SECONDS = 3600 + + +class InMemoryAuthorizationServerProvider( + OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken] +): + """An OAuth authorization-server provider backed by in-memory dicts. + + Holds registered clients, issued codes, refresh tokens and access tokens as instance state + so tests can both drive the SDK's authorization handlers and inspect what was issued. + + Knobs: + `default_scopes`: scopes granted when an authorize request supplies none. + `deny_authorize`: every authorize request returns an `error=access_denied` redirect. + `issue_expired_first`: the first issued token's `expires_in` is in the past so the + client immediately considers it expired and refreshes; the server-side + `AccessToken.expires_at` stays in the future so the bearer middleware accepts it + on the retry that completes the connect. + `fail_next_refresh`: the next refresh-token exchange raises `invalid_grant` once. + `reject_all_tokens`: `load_access_token` returns None for every token, so the bearer + middleware 401s every authenticated request. + """ + + def __init__( + self, + *, + default_scopes: list[str] | None = None, + deny_authorize: bool = False, + issue_expired_first: bool = False, + fail_next_refresh: bool = False, + reject_all_tokens: bool = False, + ) -> None: + self._default_scopes = list(default_scopes) if default_scopes is not None else ["mcp"] + self._issuer = "http://127.0.0.1:8000" + self._deny_authorize = deny_authorize + self._issue_expired_first = issue_expired_first + self._fail_next_refresh = fail_next_refresh + self._reject_all_tokens = reject_all_tokens + self._tokens_issued = 0 + self.clients: dict[str, OAuthClientInformationFull] = {} + self.codes: dict[str, AuthorizationCode] = {} + self.refresh_tokens: dict[str, RefreshToken] = {} + self.access_tokens: dict[str, AccessToken] = {} + + def _next_expires_in(self) -> int: + self._tokens_issued += 1 + if self._issue_expired_first and self._tokens_issued == 1: + return -_TOKEN_LIFETIME_SECONDS + return _TOKEN_LIFETIME_SECONDS + + def mint_access_token(self, *, client_id: str, scopes: list[str], resource: str | None = None) -> str: + """Mint and store an access token, returning its value. + + Used by the auth-code and refresh exchanges and by the M2M `/token` shim. The + server-side `expires_at` is always in the future regardless of `issue_expired_first`, + which only affects what the client is told. + """ + access = f"access_{secrets.token_hex(16)}" + self.access_tokens[access] = AccessToken( + token=access, + client_id=client_id, + scopes=scopes, + expires_at=int(time.time()) + _TOKEN_LIFETIME_SECONDS, + resource=resource, + ) + return access + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + assert client_info.client_id is not None + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Mint an authorization code immediately and return the redirect carrying it. + + A real provider would interpose user consent here; the test provider grants + unconditionally so the headless redirect handler can complete the flow in-process. + When `deny_authorize` is set, returns an `error=access_denied` redirect instead. + """ + assert client.client_id is not None + if self._deny_authorize: + return construct_redirect_uri( + str(params.redirect_uri), error="access_denied", error_description="user denied", state=params.state + ) + code = AuthorizationCode( + code=f"code_{secrets.token_hex(16)}", + client_id=client.client_id, + scopes=params.scopes or self._default_scopes, + expires_at=time.time() + 300, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + resource=params.resource, + ) + self.codes[code.code] = code + # `iss` is RFC 9207's authorization-response issuer identifier — an extra parameter many + # real authorization servers send. Including it on every success redirect proves the + # client tolerates unrecognized callback parameters (RFC 6749 §4.1.2 MUST) by virtue of + # every flow test passing unchanged. + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state, iss=self._issuer) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self.codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Mint an access token and a refresh token for a valid authorization code, then consume the code.""" + assert client.client_id is not None + access = self.mint_access_token( + client_id=client.client_id, scopes=authorization_code.scopes, resource=authorization_code.resource + ) + refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[refresh] = RefreshToken( + token=refresh, + client_id=client.client_id, + scopes=authorization_code.scopes, + ) + del self.codes[authorization_code.code] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(authorization_code.scopes), + refresh_token=refresh, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + if self._reject_all_tokens: + return None + return self.access_tokens.get(token) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + return self.refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str] + ) -> OAuthToken: + """Mint a new access token and rotate the refresh token, consuming the old one.""" + assert client.client_id is not None + if self._fail_next_refresh: + self._fail_next_refresh = False + raise TokenError(error="invalid_grant", error_description="refresh denied by harness") + access = self.mint_access_token(client_id=client.client_id, scopes=scopes) + new_refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[new_refresh] = RefreshToken(token=new_refresh, client_id=client.client_id, scopes=scopes) + del self.refresh_tokens[refresh_token.token] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(scopes), + refresh_token=new_refresh, + ) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Not exercised by this suite; revocation is out of scope for the interaction tests.""" + raise NotImplementedError diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py new file mode 100644 index 0000000000..5cb4e92d86 --- /dev/null +++ b/tests/interaction/auth/test_as_handlers.py @@ -0,0 +1,300 @@ +"""Error-plane behaviour of the SDK's bundled OAuth authorization-server handlers. + +The end-to-end OAuth tests prove the handlers' happy paths; these tests drive the same +mounted authorization server directly with raw httpx so the assertions are the HTTP +semantics (status, redirect target, error body, headers) the OAuth RFCs mandate. Almost +every behaviour here is enforced by the SDK's own handlers; where the pinned output +deviates from the RFC, the manifest entry carries the divergence. +""" + +import base64 +import hashlib +import secrets +from collections.abc import AsyncIterator +from urllib.parse import parse_qs, urlsplit + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import ProviderTokenVerifier +from mcp.shared.auth import OAuthClientInformationFull +from tests.interaction._connect import mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import REDIRECT_URI, auth_settings, oauth_client_metadata +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +async def as_app() -> AsyncIterator[tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider]]: + """Co-host the SDK's authorization-server routes and yield a raw httpx client against them.""" + provider = InMemoryAuthorizationServerProvider() + settings = auth_settings() + async with mounted_app( + Server("guarded"), + auth=settings, + token_verifier=ProviderTokenVerifier(provider), + auth_server_provider=provider, + ) as (http, _): + yield http, provider + + +def _pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair the same way the SDK client does.""" + verifier = secrets.token_urlsafe(48)[:64] + challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") + return verifier, challenge + + +async def _register_client(http: httpx.AsyncClient) -> OAuthClientInformationFull: + """Dynamically register a client and return its full credentials.""" + response = await http.post("/register", content=oauth_client_metadata().model_dump_json()) + assert response.status_code == 201 + return OAuthClientInformationFull.model_validate_json(response.content) + + +async def _mint_code(http: httpx.AsyncClient) -> tuple[OAuthClientInformationFull, str, str]: + """Register a client, complete a valid authorize step, and return (client_info, code, verifier).""" + client_info = await _register_client(http) + assert client_info.client_id is not None + verifier, challenge = _pkce_pair() + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "s", + }, + follow_redirects=False, + ) + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + code = parse_qs(redirect.query)["code"][0] + return client_info, code, verifier + + +def _token_form(client_info: OAuthClientInformationFull, **overrides: str) -> dict[str, str]: + """Build the form body for an authorization-code token request, with the defaults a real client would send.""" + assert client_info.client_id is not None + assert client_info.client_secret is not None + form = { + "grant_type": "authorization_code", + "client_id": client_info.client_id, + "client_secret": client_info.client_secret, + "redirect_uri": REDIRECT_URI, + } + form.update(overrides) + return form + + +@requirement("hosting:auth:as:authorize-requires-pkce") +async def test_authorize_without_a_code_challenge_is_rejected_with_invalid_request( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request omitting `code_challenge` is redirected back with `error=invalid_request`. + + PKCE is mandatory: the bundled authorize handler models `code_challenge` as a required field, so + a code without a stored challenge can never be issued. That makes the PKCE-downgrade attack (a + token request carrying a verifier for a code minted without a challenge) structurally impossible + through these handlers, so no separate downgrade-guard test is needed. + """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "state": "abc", + }, + follow_redirects=False, + ) + + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + params = parse_qs(redirect.query) + assert params["error"] == ["invalid_request"] + assert params["state"] == ["abc"] + assert "code_challenge" in params["error_description"][0] + + +@requirement("hosting:auth:as:verifier-mismatch") +async def test_a_mismatched_code_verifier_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `code_verifier` does not hash to the stored challenge is rejected.""" + http, _ = as_app + client_info, code, _ = await _mint_code(http) + + response = await http.post("/token", data=_token_form(client_info, code=code, code_verifier="0" * 64)) + + assert response.status_code == 400 + assert response.json() == snapshot({"error": "invalid_grant", "error_description": "incorrect code_verifier"}) + + +@requirement("hosting:auth:as:code-single-use") +async def test_reusing_an_authorization_code_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorization code can be exchanged exactly once; a second exchange is `invalid_grant`. + + The handler does not track used codes itself: it returns `invalid_grant` whenever the provider's + `load_authorization_code` returns None, and the in-memory provider deletes the code on first + exchange. The test proves the combination enforces single-use; a provider that did not consume + codes would not get this guarantee from the handler. + """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + first = await http.post("/token", data=form) + assert first.status_code == 200 + assert first.json()["token_type"] == "Bearer" + + second = await http.post("/token", data=form) + assert second.status_code == 400 + assert second.json() == snapshot( + {"error": "invalid_grant", "error_description": "authorization code does not exist"} + ) + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_endpoint( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected. + + This is the security-critical half of redirect-URI binding: a code intercepted via redirect + substitution cannot be redeemed because the attacker cannot reproduce the original authorize + redirect URI at the token endpoint. RFC 6749 §5.2 specifies `invalid_grant` for this case; + the SDK returns `invalid_request` (see the divergence on the requirement). The rejection + itself is the security property and is correct. + """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + + response = await http.post( + "/token", + data=_token_form(client_info, code=code, code_verifier=verifier, redirect_uri=f"{REDIRECT_URI}/different"), + ) + + assert response.status_code == 400 + assert response.json() == snapshot( + { + "error": "invalid_request", + "error_description": "redirect_uri did not match the one used when creating auth code", + } + ) + + +@requirement("hosting:auth:as:token-cache-headers") +async def test_token_responses_carry_cache_control_no_store( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Every token-endpoint response (success and error) carries `Cache-Control: no-store`.""" + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + success = await http.post("/token", data=form) + assert success.status_code == 200 + assert success.headers["cache-control"] == "no-store" + assert success.headers["pragma"] == "no-cache" + + failure = await http.post("/token", data=form) + assert failure.status_code == 400 + assert failure.headers["cache-control"] == "no-store" + assert failure.headers["pragma"] == "no-cache" + + +@requirement("hosting:auth:as:register-error-response") +async def test_registration_with_invalid_metadata_is_rejected_with_400( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Invalid client metadata at the registration endpoint returns 400 with an RFC 7591 error body.""" + http, _ = as_app + + malformed = await http.post("/register", json={"redirect_uris": ["not-a-url"]}) + assert malformed.status_code == 400 + assert malformed.json()["error"] == "invalid_client_metadata" + + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + + no_auth_code = await http.post("/register", json=body | {"grant_types": ["refresh_token"]}) + assert no_auth_code.status_code == 400 + assert no_auth_code.json() == snapshot( + {"error": "invalid_client_metadata", "error_description": "grant_types must include 'authorization_code'"} + ) + + bad_scope = await http.post("/register", json=body | {"scope": "forbidden"}) + assert bad_scope.status_code == 400 + body = bad_scope.json() + assert body["error"] == "invalid_client_metadata" + # The description embeds a set difference whose ordering is not stable, so assert the prefix. + assert body["error_description"].startswith("Requested scopes are not valid: ") + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request naming an unregistered `redirect_uri` returns 400 without redirecting to it. + + The security property is that the authorization server never redirects to an unvalidated URI: + the response is a direct JSON error to the user agent, not a 302 to the attacker's host. + """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + _, challenge = _pkce_pair() + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": "http://127.0.0.1:8000/evil", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + follow_redirects=False, + ) + + assert response.status_code == 400 + assert "location" not in response.headers + body = response.json() + assert body["error"] == "invalid_request" + assert "not registered" in body["error_description"] + + +@requirement("hosting:auth:as:redirect-uri-scheme") +async def test_a_non_loopback_http_redirect_uri_is_accepted_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A registration carrying a non-HTTPS, non-loopback redirect URI is accepted. + + The spec requires every redirect URI to be either HTTPS or a loopback host; the bundled + registration handler does not enforce this and registers `http://evil.example/callback` + successfully. See the divergence on the requirement. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = ["http://evil.example/callback"] + + response = await http.post("/register", json=body) + + assert response.status_code == 201 + info = OAuthClientInformationFull.model_validate_json(response.content) + assert [str(u) for u in (info.redirect_uris or [])] == ["http://evil.example/callback"] + assert info.client_id in provider.clients diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py new file mode 100644 index 0000000000..cb8524c097 --- /dev/null +++ b/tests/interaction/auth/test_authorize_token.py @@ -0,0 +1,399 @@ +"""Authorization-request, token-request, and PKCE wire-level invariants of the SDK's OAuth client. + +Every test connects a real `Client` end to end via `connect_with_oauth`; the assertions are on +the parsed authorize URL and the recorded `/token` form body, because those wire shapes are what +the spec mandates and `Client` cannot observe them. The recording uses `record_requests`, which +snapshots each request at send time so the auth flow's in-place header mutation on retry never +affects what was captured for the first attempt. + +Tests #1/#2/#4/#5 share one `recorded_oauth_flow` fixture (one connect, several disjoint +assertions on its recording); the others connect fresh because each needs a different harness +configuration. +""" + +import base64 +import hashlib +import json +import re +from collections.abc import AsyncIterator +from dataclasses import dataclass +from urllib.parse import parse_qsl, quote, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + HeadlessOAuth, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + first_challenge_shim, + record_requests, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict (one value per key).""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + """Filter recorded requests by method and exact path.""" + return [r for r in recorded if r.method == method and r.path == path] + + +@dataclass +class RecordedFlow: + """One completed OAuth connect: every recorded request, plus the parsed authorize URL params.""" + + requests: list[RecordedRequest] + authorize_url: str + + @property + def authorize(self) -> dict[str, str]: + return authorize_params(self.authorize_url) + + @property + def token_request(self) -> RecordedRequest: + token_posts = find(self.requests, "POST", "/token") + assert len(token_posts) == 1 + return token_posts[0] + + +@pytest.fixture +async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: + """Run one full OAuth connect with default configuration and yield its recorded wire traffic. + + `valid_scopes` includes `offline_access` so the AS metadata advertises it and the SDK's + SEP-2207 auto-append (and the resulting `prompt=consent`) is exercised; `required_scopes` + stays at `["mcp"]` so the issued token still passes the bearer middleware. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "offline_access"]) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, settings=settings, on_request=on_request) as ( + client, + headless, + ): + await client.list_tools() + + assert headless.authorize_url is not None + yield RecordedFlow(requests=recorded, authorize_url=headless.authorize_url) + + +@requirement("client-auth:pkce:s256") +@requirement("client-auth:resource-parameter") +@requirement("client-auth:authorize:offline-access-consent") +async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every spec-mandated parameter appears on the authorize URL with the right value. + + The full key set is snapshotted so a parameter added or dropped fails the test. The + `code_challenge` length bound is the RFC 7636 §4.2 grammar; an S256 challenge is in + practice always 43 characters, so the upper bound is never approached. + """ + params = recorded_oauth_flow.authorize + + assert sorted(params) == snapshot( + [ + "client_id", + "code_challenge", + "code_challenge_method", + "prompt", + "redirect_uri", + "resource", + "response_type", + "scope", + "state", + ] + ) + assert params["response_type"] == "code" + assert params["code_challenge_method"] == "S256" + assert 43 <= len(params["code_challenge"]) <= 128 + # The exact resource value depends on canonical-URI normalisation (a spec ambiguity); pin + # the stable prefix so the test does not lock in a trailing-slash decision. + assert params["resource"].startswith(BASE_URL) + assert params["state"] != "" + + assert params["scope"].split(" ") == snapshot(["mcp", "offline_access"]) + assert params["prompt"] == "consent" + + +@requirement("client-auth:pkce:s256") +async def test_the_code_verifier_on_the_token_request_hashes_to_the_code_challenge( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The PKCE verifier sent on /token is the S256 pre-image of the challenge sent on /authorize. + + The verifier is also checked against RFC 7636 §4.1's length and `unreserved` charset. + """ + challenge = recorded_oauth_flow.authorize["code_challenge"] + verifier = form_body(recorded_oauth_flow.token_request)["code_verifier"] + + assert re.fullmatch(r"[A-Za-z0-9._~-]{43,128}", verifier) + assert base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") == challenge + + +@requirement("client-auth:state:verify") +async def test_a_mismatched_state_on_the_callback_aborts_the_flow() -> None: + """A callback whose state does not match the value sent on /authorize raises and stops the flow. + + The auth flow runs inside the streamable-HTTP client's task group, so the `OAuthFlowError` + reaches the test wrapped in nested single-element exception groups; `pytest.RaisesGroup` + asserts the leaf type and the SDK-authored message prefix (the full message embeds two + random tokens). + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth(state_override="wrong-state") + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^State parameter mismatch:"), flatten_subgroups=True + ): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, headless=headless).__aenter__() + + +@requirement("client-auth:resource-parameter") +async def test_the_authorization_code_token_request_carries_grant_type_code_redirect_and_resource( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The /token form body has exactly the auth-code grant fields, with redirect_uri and resource matching /authorize. + + `client_secret` is present because the SDK's dynamic-registration handler issues a secret + and the client defaults to `client_secret_post`. + """ + token_req = recorded_oauth_flow.token_request + body = form_body(token_req) + + assert sorted(body) == snapshot( + ["client_id", "client_secret", "code", "code_verifier", "grant_type", "redirect_uri", "resource"] + ) + assert body["grant_type"] == "authorization_code" + assert body["code"] != "" + assert body["redirect_uri"] == recorded_oauth_flow.authorize["redirect_uri"] + assert body["resource"] == recorded_oauth_flow.authorize["resource"] + assert token_req.headers["content-type"] == "application/x-www-form-urlencoded" + + +@requirement("client-auth:bearer-header:every-request") +async def test_every_mcp_request_after_auth_carries_the_bearer_header_and_never_a_query_token( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every MCP request after the flow has `Authorization: Bearer ...` and never `?access_token=`. + + The first /mcp POST is the unauthenticated trigger and is asserted to carry no Authorization + header; that assertion is only meaningful because the recording snapshots requests at send + time (the SDK mutates the same request object in place for the retry). + """ + mcp_posts = find(recorded_oauth_flow.requests, "POST", "/mcp") + assert len(mcp_posts) >= 3 + + assert "authorization" not in mcp_posts[0].headers + for r in mcp_posts[1:]: + assert r.headers["authorization"].startswith("Bearer ") + assert r.headers["authorization"] != "Bearer " + assert "access_token" not in dict(r.url.params) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_a_client_with_a_secret_authenticates_the_token_request_with_http_basic() -> None: + """A `client_secret_basic` client sends URL-encoded credentials in HTTP Basic, not the body. + + Credentials are URL-encoded before base64 per RFC 6749 §2.3.1; the secret contains `/` so + the encoding is observable. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="cid", + client_secret="s/cret", + token_endpoint_auth_method="client_secret_basic", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage = InMemoryTokenStorage(client_info=client_info) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + [token_req] = find(recorded, "POST", "/token") + + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == f"{quote('cid', safe='')}:{quote('s/cret', safe='')}" + assert "client_secret" not in form_body(token_req) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_advertised_methods() -> None: + """The token-endpoint auth method comes from the registered client info, not from AS metadata. + + The shim serves AS metadata advertising only `client_secret_basic`; the client dynamically + registers and the SDK's registration handler issues `client_secret_post`. The client uses + `client_secret_post` (secret in the body, no Basic header) because the SDK reads the + registered `token_endpoint_auth_method`, not `token_endpoint_auth_methods_supported`. Other + SDKs (TypeScript, Go) do consult the AS metadata; this test pins where the python SDK's + selection point lives. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve), on_request=on_request + ) as (client, _): + await client.list_tools() + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content).get("token_endpoint_auth_method") is None + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert "client_secret" in body + assert body["client_secret"] != "" + assert "authorization" not in token_req.headers + + +@requirement("client-auth:scope-selection:priority") +async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_metadata() -> None: + """When the 401 challenge carries `scope=`, that value is requested instead of the PRM scopes. + + The SDK's bearer middleware never emits `scope=` in WWW-Authenticate (see the divergence + on `hosting:auth:scope-403`), so the test supplies the first 401 itself via + `first_challenge_shim` and disables token verification so the post-auth retry succeeds + regardless of the granted scope. PRM advertises `["from-prm"]` (it mirrors + `required_scopes`); the challenge says `from-header`; the authorize URL must carry + `from-header`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["from-prm"], valid_scopes=["from-header", "from-prm"]) + challenge = f'Bearer scope="from-header", resource_metadata="{BASE_URL}{PRM_PATH}"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + verify_tokens=False, + app_shim=first_challenge_shim(challenge), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "from-header" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "from-header" + + +@requirement("client-auth:pkce:refuse-if-unsupported") +async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: + """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE. + + The spec says the client MUST refuse to proceed in this case; the SDK proceeds and the flow + completes. See the divergence on the requirement. + """ + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + ) + assert override.code_challenge_methods_supported is None + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve) + ) as (client, headless): + result = await client.list_tools() + + assert headless.authorize_url is not None + params = authorize_params(headless.authorize_url) + assert params["code_challenge_method"] == "S256" + assert params["code_challenge"] != "" + assert result.tools[0].name == "echo" + + +@requirement("client-auth:authorize:error-surfaces") +async def test_an_authorize_error_on_the_callback_aborts_the_flow_before_the_token_request() -> None: + """An `error=` redirect from /authorize aborts the flow with no /token request issued. + + The SDK's callback contract is `() -> (code, state)` with no error form, so the failure is + observed as an empty code reaching the SDK and `OAuthFlowError("No authorization code + received")` being raised. The actual `error` value from the redirect is not surfaced to the + caller; that gap is noted in the manifest. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(deny_authorize=True) + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth() + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^No authorization code received$"), flatten_subgroups=True + ): + await connect_with_oauth(server, provider=provider, headless=headless, on_request=on_request).__aenter__() + + assert headless.error == "access_denied" + assert find(recorded, "POST", "/token") == [] diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py new file mode 100644 index 0000000000..341a8e0db9 --- /dev/null +++ b/tests/interaction/auth/test_bearer.py @@ -0,0 +1,189 @@ +"""Resource-server bearer-token gate: status codes and `WWW-Authenticate` for each token shape. + +These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` +seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since +every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, +the `WWW-Authenticate` header structure, and that a wrong-audience token reaches the MCP +endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. +""" + +import time +from collections.abc import AsyncIterator + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import AccessToken +from mcp.types import JSONRPCResponse +from tests.interaction._connect import base_headers, initialize_body, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import StaticTokenVerifier, auth_settings + +pytestmark = pytest.mark.anyio + +REQUIRED_SCOPE = "mcp:read" +RESOURCE_METADATA_URL = "http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" + +_FUTURE = int(time.time()) + 3600 +_PAST = int(time.time()) - 3600 + +TOKENS = { + "tok-valid": AccessToken(token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), + "tok-expired": AccessToken(token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST), + "tok-noscope": AccessToken(token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE), + "tok-wrong-aud": AccessToken( + token="tok-wrong-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="https://other.example/mcp", + ), +} + + +@pytest.fixture +async def protected() -> AsyncIterator[httpx.AsyncClient]: + """A bearer-gated streamable-HTTP app (resource server only) on the in-process bridge.""" + server = Server("rs") + settings = auth_settings(required_scopes=[REQUIRED_SCOPE]) + async with mounted_app(server, auth=settings, token_verifier=StaticTokenVerifier(TOKENS)) as (http, _): + yield http + + +async def post_mcp( + http: httpx.AsyncClient, *, bearer: str | None = None, query: dict[str, str] | None = None +) -> httpx.Response: + """POST an initialize body to `/mcp`, optionally with a bearer token and/or a query string.""" + headers = base_headers() + if bearer is not None: + headers["authorization"] = f"Bearer {bearer}" + return await http.post("/mcp", headers=headers, params=query, json=initialize_body()) + + +def parse_www_authenticate(value: str) -> dict[str, str]: + """Parse a `Bearer k="v", k="v"` challenge into a dict. + + The SDK emits each parameter exactly once, comma-space separated, with double-quoted + values that contain no quotes themselves; this helper relies on that and would fail + visibly if the format changed. + """ + scheme, _, params = value.partition(" ") + assert scheme == "Bearer" + return {key: quoted.strip('"') for key, _, quoted in (pair.partition("=") for pair in params.split(", "))} + + +@requirement("hosting:auth:missing-401") +async def test_a_request_with_no_authorization_header_is_challenged_with_resource_metadata( + protected: httpx.AsyncClient, +) -> None: + """No `Authorization` header → 401 with a `WWW-Authenticate` carrying `resource_metadata`. + + The snapshot pins current behaviour: the SDK collapses the no-header, unknown-token, and + expired-token cases into one challenge (`error="invalid_token"`, no `scope` parameter). The + spec says the discovery-time challenge SHOULD include `scope` and RFC 6750 says the + no-credentials case SHOULD NOT carry an error code; both gaps are recorded as the divergence + on this requirement. Asserting the dict equals an exact key set also pins that no parameter + appears twice. + """ + response = await post_mcp(protected) + + assert response.status_code == 401 + assert response.headers["www-authenticate"] == snapshot( + 'Bearer error="invalid_token", error_description="Authentication required", ' + 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + ) + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert response.json() == snapshot({"error": "invalid_token", "error_description": "Authentication required"}) + + +@requirement("hosting:auth:invalid-401") +async def test_an_unrecognized_bearer_token_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: + """A token the verifier does not recognize is answered 401 `invalid_token`. + + The challenge is identical to the no-header case (the backend returns `None` for both); the + missing `scope` parameter is the recorded divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-unknown") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:expired-401") +async def test_an_expired_token_is_answered_401(protected: httpx.AsyncClient) -> None: + """A token whose `expires_at` is in the past is answered 401 `invalid_token`. + + The expiry check is the bearer backend's, against the wall clock; the test seeds a concrete + past timestamp so no time mocking is involved. The missing `scope` parameter is the recorded + divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-expired") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + + +@requirement("hosting:auth:scope-403") +async def test_a_token_missing_a_required_scope_is_answered_403_insufficient_scope_without_a_scope_param( + protected: httpx.AsyncClient, +) -> None: + """A token lacking the required scope is answered 403 `insufficient_scope`, with no `scope` parameter. + + The spec's runtime-insufficient-scope guidance says the challenge SHOULD include `scope` + naming the required scope; the SDK never emits it, recorded as the divergence on this + requirement. The SDK client reads `scope` from this header to drive step-up, so the gap is + a resource-server/client asymmetry. + """ + response = await post_mcp(protected, bearer="tok-noscope") + + assert response.status_code == 403 + parsed = parse_www_authenticate(response.headers["www-authenticate"]) + assert parsed == { + "error": "insufficient_scope", + "error_description": f"Required scope: {REQUIRED_SCOPE}", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert "scope" not in parsed + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_with_a_mismatched_audience_is_accepted(protected: httpx.AsyncClient) -> None: + """A token whose `resource` does not match the server's resource identifier is accepted. + + The spec mandates the resource server validate the token's audience; the bearer backend + never inspects `AccessToken.resource`, so the request passes the gate and the MCP endpoint + serves it. This pins current behaviour with the divergence recorded on the requirement. + """ + response = await post_mcp(protected, bearer="tok-wrong-aud") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + # The body is finite SSE: a result event followed by stream close. Pull the JSON-RPC response + # out of the buffered text to prove the MCP endpoint actually answered the initialize request. + [data] = [line.removeprefix("data: ") for line in response.text.splitlines() if line.startswith("data: ")] + assert "protocolVersion" in JSONRPCResponse.model_validate_json(data).result + + +@requirement("hosting:auth:query-token-ignored") +async def test_an_access_token_in_the_query_string_is_not_accepted(protected: httpx.AsyncClient) -> None: + """A valid token presented in the URI query string is treated as no authentication. + + The bearer backend reads only the `Authorization` header, so `?access_token=...` is never + consulted; the request is treated as unauthenticated and answered 401. This satisfies, by + absence, the security best-practice that resource servers must not accept query-string + tokens. + """ + response = await post_mcp(protected, query={"access_token": "tok-valid"}) + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py new file mode 100644 index 0000000000..68c33c8a2d --- /dev/null +++ b/tests/interaction/auth/test_discovery.py @@ -0,0 +1,333 @@ +"""Protected-resource and authorization-server metadata discovery, end to end. + +Every client-side test connects a real `Client` via `connect_with_oauth` and asserts on the +recorded request paths the discovery probes produced; the discovery URL ordering is a wire +detail `Client` cannot observe directly but the recording can. Tests that need a metadata +endpoint to 404 or return alternate content wrap the SDK's app in `shimmed_app` while leaving +the real authorize and token endpoints behind it, so the rest of the flow runs unaltered. + +The two server-side tests (#5, #6) drive raw httpx against `mounted_app` because their +assertions are the metadata response bodies and headers, which `Client` does not surface. +""" + +import json + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError, OAuthRegistrationError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + RecordedRequest, + auth_settings, + connect_with_oauth, + metadata_body, + record_requests, + shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH_SUFFIXED = "/.well-known/oauth-protected-resource/mcp" +PRM_ROOT = "/.well-known/oauth-protected-resource" +ASM_ROOT = "/.well-known/oauth-authorization-server" +OIDC_ROOT = "/.well-known/openid-configuration" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="probe", input_schema={"type": "object"})]) + + +def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: + """Return the well-known GET paths in recorded order, ignoring everything else.""" + return [r.path for r in recorded if r.method == "GET" and "/.well-known/" in r.path] + + +def real_asm() -> OAuthMetadata: + """Build an authorization-server metadata document pointing at the real co-hosted endpoints.""" + return OAuthMetadata( + issuer=AnyHttpUrl(BASE_URL), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_uses_the_resource_metadata_url_from_www_authenticate() -> None: + """The first protected-resource probe is the URL the 401's `WWW-Authenticate` header supplied. + + With co-hosted defaults the header carries the path-suffixed well-known URL; the client + fetches that one first and, because it succeeds, never falls back. The single-probe + sequence proves priority 1. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, on_request=on_request) as (client, _): + await client.list_tools() + + assert discovery_gets(recorded) == snapshot([PRM_PATH_SUFFIXED, ASM_ROOT]) + assert (recorded[0].method, recorded[0].path) == ("POST", "/mcp") + assert (recorded[1].method, recorded[1].path) == ("GET", PRM_PATH_SUFFIXED) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_falls_back_from_path_well_known_to_root_on_404() -> None: + """When the path-suffixed PRM well-known 404s, the client falls back to the root well-known. + + The exact GET count is not asserted: the WWW-Authenticate URL equals the path well-known + here, so the SDK probes it twice (once as priority 1, once as priority 2) before reaching + root. Asserting "path before root, root reached, then the flow proceeds" pins the spec + invariant; the duplicate probe is an implementation detail. The served PRM body carries an + unrecognized field to prove the client's parser ignores unknown members (RFC 9728 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim( + not_found=frozenset({PRM_PATH_SUFFIXED}), + serve={PRM_ROOT: metadata_body(prm, x_unknown_extension="ignored")}, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known.index(PRM_PATH_SUFFIXED) < well_known.index(PRM_ROOT) + assert any(r.path == "/authorize" for r in recorded) + + +@requirement("client-auth:prm-discovery:no-prm-fallback") +async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_the_server_origin() -> None: + """When every protected-resource metadata probe 404s, the client falls back to the legacy path. + + The legacy 2025-03-26 behaviour: with no PRM document available, treat the MCP server's + origin as the authorization server and fetch its `/.well-known/oauth-authorization-server` + directly. The real co-hosted ASM endpoint is at exactly that location, so the flow completes. + The recorded sequence shows both PRM well-known paths probed (and failed) before ASM_ROOT. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + result = await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known[-1] == ASM_ROOT + assert all(well_known.index(prm) < well_known.index(ASM_ROOT) for prm in (PRM_PATH_SUFFIXED, PRM_ROOT)) + assert result.tools[0].name == "probe" + + +@requirement("client-auth:dcr:registration-error-surfaces") +async def test_a_400_from_the_registration_endpoint_surfaces_as_a_registration_error() -> None: + """A 400 from `/register` surfaces as `OAuthRegistrationError` carrying the server's body. + + The shim makes `/register` return RFC 7591's `invalid_client_metadata`; the SDK reads the + body and raises with the status and text in the message, before any authorize or token + request is made. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + error_body = json.dumps({"error": "invalid_client_metadata", "error_description": "no"}).encode() + app_shim = shim(serve={"/register": (400, error_body)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthRegistrationError, match=r"^Registration failed: 400 .*invalid_client_metadata"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:prm-resource-mismatch") +async def test_prm_with_a_mismatched_resource_aborts_the_flow_before_authorize() -> None: + """A PRM document whose `resource` does not cover the server URL aborts the flow. + + The shim serves PRM at the URL the WWW-Authenticate header supplies, but with a `resource` + on a different path; `check_resource_allowed` rejects it and `OAuthFlowError` is raised + before any authorize or token request is made. The error reaches the test wrapped in nested + single-element exception groups by the streamable-HTTP client's task group. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/other"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim(serve={PRM_PATH_SUFFIXED: metadata_body(prm)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^Protected resource .* does not match expected"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:as-metadata-discovery:priority-order") +@pytest.mark.parametrize( + ("authorization_server", "not_found", "serve_at", "expected_order"), + [ + pytest.param( + f"{BASE_URL}/", + frozenset({ASM_ROOT}), + OIDC_ROOT, + [ASM_ROOT, OIDC_ROOT], + id="root-issuer", + ), + pytest.param( + f"{BASE_URL}/tenant", + frozenset({f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant"}), + "/tenant/.well-known/openid-configuration", + [f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant", "/tenant/.well-known/openid-configuration"], + id="path-issuer", + ), + ], +) +async def test_as_metadata_discovery_falls_back_through_the_spec_endpoint_order( + authorization_server: str, not_found: frozenset[str], serve_at: str, expected_order: list[str] +) -> None: + """Authorization-server metadata is fetched at the spec's endpoints in the spec's order. + + The shim 404s every endpoint before the last so the recording proves each probe and its + position. For an issuer URL with no path the order is OAuth root then OIDC root; for an + issuer URL with a path component it is OAuth path-inserted, OIDC path-inserted, then OIDC + path-appended (the spec's three-endpoint MUST). The path-issuer case is driven by serving + a PRM whose `authorization_servers` carries the path; the SDK's own AS routes stay at root + (the served body points at the real `/authorize` and `/token`). The served bodies carry an + unrecognized field to prove the client's parser ignores unknown members (RFC 8414 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(authorization_server)] + ) + app_shim = shim( + not_found=not_found, + serve={ + PRM_PATH_SUFFIXED: metadata_body(prm), + serve_at: metadata_body(real_asm(), x_unknown_extension="ignored"), + }, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + assert discovery_gets(recorded) == [PRM_PATH_SUFFIXED, *expected_order] + + +@requirement("hosting:auth:metadata-endpoints") +@requirement("hosting:auth:prm:authorization-servers-field") +async def test_the_prm_endpoint_serves_the_resource_url_and_at_least_one_authorization_server() -> None: + """The protected-resource metadata document the SDK serves identifies the resource and an authorization server. + + Also asserts the response is `application/json` (RFC 9728 §3.2) and that fields the SDK has + no value for are absent rather than null (`PydanticJSONResponse` serializes with + `exclude_none=True`, satisfying RFC 9728 §3.2's omit-zero-value rule). + """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(PRM_PATH_SUFFIXED) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + document = json.loads(response.content) + assert "resource_documentation" not in document + assert "scopes_supported" in document + + metadata = ProtectedResourceMetadata.model_validate(document) + assert str(metadata.resource).rstrip("/") == f"{BASE_URL}/mcp" + assert len(metadata.authorization_servers) >= 1 + assert metadata.bearer_methods_supported == ["header"] + + +@requirement("hosting:auth:as-router") +async def test_as_metadata_advertises_authorize_token_registration_and_s256() -> None: + """The authorization-server metadata document the SDK serves names the required endpoints and S256.""" + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(ASM_ROOT) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + metadata = OAuthMetadata.model_validate_json(response.content) + assert str(metadata.issuer).rstrip("/") == BASE_URL + assert str(metadata.authorization_endpoint) == f"{BASE_URL}/authorize" + assert str(metadata.token_endpoint) == f"{BASE_URL}/token" + assert str(metadata.registration_endpoint) == f"{BASE_URL}/register" + assert metadata.response_types_supported == ["code"] + assert metadata.code_challenge_methods_supported is not None + assert "S256" in metadata.code_challenge_methods_supported + + +@requirement("client-auth:as-metadata-discovery:issuer-validation") +async def test_as_metadata_with_a_mismatched_issuer_is_accepted_and_the_flow_proceeds() -> None: + """Authorization-server metadata whose `issuer` does not match the discovery URL is accepted. + + RFC 8414 §3.3 requires the client to reject the document; the SDK parses and uses it + without comparing `issuer` to the URL it was fetched from. See the divergence on the + requirement. The served body carries an unrecognized field as a fold-in proof of + unknown-field tolerance. + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + metadata = real_asm() + metadata.issuer = AnyHttpUrl(f"{BASE_URL}/wrong-issuer") + app_shim = shim(serve={ASM_ROOT: metadata_body(metadata, x_unknown_extension="ignored")}) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "probe" diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py new file mode 100644 index 0000000000..968fc5f980 --- /dev/null +++ b/tests/interaction/auth/test_flow.py @@ -0,0 +1,239 @@ +"""End-to-end OAuth authorization-code flow against the SDK's own server, fully in process. + +Auth is HTTP-only so these tests are not transport-parametrized; each connects via +`connect_with_oauth`, which co-hosts the SDK's authorization server, protected-resource +metadata, and bearer-gated MCP endpoint on one bridge-backed Starlette app and drives the +whole flow through one `httpx.AsyncClient` carrying the SDK's `OAuthClientProvider`. The +authorize redirect completes headlessly through the same bridge, so every request the flow +makes is observable via `on_request`. +""" + +import json +from collections import Counter +from urllib.parse import parse_qs, urlsplit + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from pydantic import AnyUrl + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.shared.auth import OAuthClientInformationFull +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + auth_settings, + connect_with_oauth, + oauth_client_metadata, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})]) + + +@requirement("flow:oauth:authorization-code-roundtrip") +@requirement("client-auth:401-triggers-flow") +@requirement("hosting:auth:missing-401") +async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow_connects() -> None: + """Connecting to a bearer-gated server walks the full authorization-code flow and succeeds. + + Three requirements are proven by one connect: the flow runs end to end (authorization-code + roundtrip), it was triggered by a 401 on the first MCP request (401-triggers-flow), and + that 401 carried `resource_metadata` in `WWW-Authenticate` for discovery (missing-401). + The flagship test pins the recorded request sequence so the discovery → registration → + authorize → token → retry order is asserted explicitly. + + Steps the SDK is expected to perform: + 1. POST /mcp without a token → 401 with `WWW-Authenticate: Bearer resource_metadata=...`. + 2. GET the protected-resource metadata. + 3. GET the authorization-server metadata. + 4. POST /register (dynamic client registration). + 5. GET /authorize → 302 with code+state (completed by the headless redirect). + 6. POST /token (authorization-code exchange). + 7. Retry POST /mcp with `Authorization: Bearer ` → succeeds. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + headless, + ): + result = await client.list_tools() + + assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})])) + assert headless.authorize_url is not None + + paths = [(r.method, r.url.path) for r in requests] + assert Counter(paths) == snapshot( + Counter( + { + ("POST", "/mcp"): 4, + ("GET", "/.well-known/oauth-protected-resource/mcp"): 1, + ("GET", "/.well-known/oauth-authorization-server"): 1, + ("POST", "/register"): 1, + ("GET", "/authorize"): 1, + ("POST", "/token"): 1, + ("GET", "/mcp"): 1, + ("DELETE", "/mcp"): 1, + } + ) + ) + + assert (requests[0].method, requests[0].url.path) == ("POST", "/mcp") + # The recorded Request objects are live references: the auth flow mutates the original + # request's headers in place when it adds the bearer token for the retry, so the first + # entry's headers cannot be used to assert "no Authorization on the first attempt". The + # path multiset above proving discovery happened is the evidence the first attempt was 401. + + # The first PRM discovery GET carries the protocol-version header (an SDK behaviour, not a + # spec requirement on discovery requests). + prm_get = next(r for r in requests if r.url.path == "/.well-known/oauth-protected-resource/mcp") + assert prm_get.headers.get("mcp-protocol-version") == snapshot("2025-11-25") + + authorize = parse_qs(urlsplit(headless.authorize_url).query) + assert authorize["response_type"] == ["code"] + assert authorize["code_challenge_method"] == ["S256"] + assert authorize["client_id"][0] in provider.clients + + assert storage.tokens is not None + bearer = f"Bearer {storage.tokens.access_token}" + authed_mcp = [r for r in requests if r.url.path == "/mcp" and r.headers.get("authorization") == bearer] + assert len(authed_mcp) > 0 + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("hosting:auth:authinfo-propagates") +async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: + """A tool handler reads the request's access token through `get_access_token()`.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + token = get_access_token() + assert token is not None + return CallToolResult(content=[TextContent(text=" ".join(token.scopes))]) + + server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) + provider = InMemoryAuthorizationServerProvider() + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider) as (client, _): + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mcp")])) + + +@requirement("client-auth:pre-registration") +async def test_a_preregistered_client_skips_registration() -> None: + """A client whose storage already holds client info uses it instead of registering. + + The provider holds the same registration server-side so the authorize and token steps + accept it; the recorded requests prove no `/register` call was made. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="preregistered", + client_secret="s3cret", + token_endpoint_auth_method="client_secret_post", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage.client_info = client_info + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + _, + ): + await client.list_tools() + + assert [r.url.path for r in requests].count("/register") == 0 + assert list(provider.clients) == ["preregistered"] + + +@requirement("client-auth:dcr") +async def test_the_dcr_request_carries_the_client_metadata() -> None: + """Dynamic registration sends the client's metadata and persists what the server issued. + + The body of the recorded `/register` POST carries the metadata the test supplied (with the + scope filled in from server discovery), and the server's issued client_id and secret are + persisted to storage and held by the provider. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_metadata = oauth_client_metadata() + client_metadata.software_id = "interaction-test-suite" + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, storage=storage, client_metadata=client_metadata, on_request=requests.append + ) as (client, _): + await client.list_tools() + + register = next(r for r in requests if r.url.path == "/register") + assert register.headers["content-type"] == "application/json" + body = json.loads(register.content) + assert body == snapshot( + { + "redirect_uris": ["http://127.0.0.1:8000/oauth/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "mcp", + "client_name": "interaction-suite", + "software_id": "interaction-test-suite", + } + ) + + assert storage.client_info is not None + assert storage.client_info.client_id is not None + assert storage.client_info.client_secret is not None + assert list(provider.clients) == [storage.client_info.client_id] + + +async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app() -> None: + """Harness self-test: `shimmed_app` serves canned bodies, 404s, and forwards everything else. + + Wraps a real auth-hosting Starlette app so the forward path is exercised against the SDK's + own routing; provided here so the discovery tests can rely on the shim without each adding + their own contract test. + """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider) + app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'}) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + served = await http.get("/override") + assert served.status_code == 200 + assert served.headers["content-type"] == "application/json" + assert served.json() == {"shimmed": True} + + assert (await http.get("/missing")).status_code == 404 + + forwarded = await http.get("/.well-known/oauth-authorization-server") + assert forwarded.status_code == 200 + assert forwarded.json()["issuer"] == "http://127.0.0.1:8000/" diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py new file mode 100644 index 0000000000..aa552ae8a6 --- /dev/null +++ b/tests/interaction/auth/test_lifecycle.py @@ -0,0 +1,445 @@ +"""Token lifecycle, step-up, and registration-variant flows of the SDK's OAuth client. + +Every test connects end to end via `connect_with_oauth`; the assertions are recording-first +(the recorded request sequence is asserted before, or independently of, the call result), so a +surprise in the refresh or step-up paths produces a readable diff of what fired rather than an +opaque failure. The provider knobs that drive each scenario are documented per test. +""" + +import base64 +from collections import Counter +from urllib.parse import parse_qsl, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import MCPError, types +from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import INTERNAL_ERROR, ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + m2m_token_shim, + metadata_body, + record_requests, + shim, + step_up_shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" +CIMD_URL = "https://client.example/.well-known/mcp-client" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict.""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + return [r for r in recorded if r.method == method and r.path == path] + + +def path_counts(recorded: list[RecordedRequest]) -> Counter[tuple[str, str]]: + return Counter((r.method, r.path) for r in recorded) + + +def cimd_supported_metadata() -> bytes: + """AS metadata advertising `client_id_metadata_document_supported: true` (the SDK server never sets it).""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + client_id_metadata_document_supported=True, + ) + return metadata_body(metadata) + + +def seeded_client(provider: InMemoryAuthorizationServerProvider, **kwargs: object) -> OAuthClientInformationFull: + """Register a client with the provider and return its info, for pre-registration and CIMD scenarios.""" + base: dict[str, object] = { + "client_id": "preregistered", + "token_endpoint_auth_method": "none", + "redirect_uris": [AnyUrl(REDIRECT_URI)], + "grant_types": ["authorization_code", "refresh_token"], + "scope": "mcp", + } + base.update(kwargs) + info = OAuthClientInformationFull.model_validate(base) + assert info.client_id is not None + provider.clients[info.client_id] = info + return info + + +@requirement("client-auth:refresh:transparent") +async def test_an_expired_access_token_is_transparently_refreshed_before_the_next_request() -> None: + """An access token the client considers expired is refreshed and the new bearer is used. + + The provider tells the client `expires_in=-3600` for the first token while keeping the + server-side `expires_at` in the future, so the connect's retry succeeds and the next + request finds the token expired and refreshes. The recorded requests prove exactly one + `grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used + after the refresh is the second access token, which is the one persisted to storage. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + bodies = [form_body(r) for r in token_posts] + assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"]) + + refresh_body = bodies[1] + assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"]) + assert refresh_body["refresh_token"].startswith("refresh_") + assert refresh_body["resource"].startswith(BASE_URL) + + bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers} + assert len(bearers) == 2 + assert storage.tokens is not None + assert f"Bearer {storage.tokens.access_token}" in bearers + assert storage.tokens.expires_in == 3600 + + +@requirement("client-auth:403-scope-upgrade") +async def test_a_403_insufficient_scope_triggers_one_reauthorize_with_the_challenged_scope() -> None: + """A 403 `insufficient_scope` challenge is answered by one re-authorize with the challenge's scope. + + The shim 403s the second authenticated `/mcp` POST (the `notifications/initialized` request, + which reaches the auth flow's step-up handler; the first authenticated POST is the post-401 + retry, after which the generator ends without inspecting the response). The challenge names a + wider scope; step-up reuses cached metadata and the existing client registration, + re-authorizes with the new scope, and the connect completes. The client is pre-registered + with both scopes so the server's authorize handler accepts the wider second request. One + re-authorize, one retry; the spec's SHOULD-retry-limit ("a few") is not enforced. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage(client_info=seeded_client(provider, scope="mcp write")) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "write"]) + challenge = 'Bearer error="insufficient_scope", scope="mcp write"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + settings=settings, + app_shim=step_up_shim(challenge), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + assert len(headless.authorize_urls) == 2 + assert authorize_params(headless.authorize_urls[0])["scope"] == "mcp" + assert authorize_params(headless.authorize_urls[1])["scope"] == "mcp write" + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 0 + assert counts[("GET", "/authorize")] == 2 + assert counts[("POST", "/token")] == 2 + + +@requirement("client-auth:401-after-auth-throws") +async def test_a_second_401_after_a_completed_oauth_flow_surfaces_without_looping() -> None: + """A 401 on the post-auth retry surfaces as an error rather than re-entering discovery. + + The provider rejects every token at verification, so the full flow runs once and the retry + is 401'd. The auth-flow generator ends after that retry, so the 401 propagates and the + transport converts it to an INTERNAL_ERROR result, raising during connect. Discovery, + registration, authorize, and token each ran exactly once: no loop. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(reject_all_tokens=True) + server = Server("guarded", on_list_tools=list_tools) + + def is_internal_error(error: MCPError) -> bool: + return error.error.code == INTERNAL_ERROR + + with anyio.fail_after(5): + with pytest.RaisesGroup(pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, on_request=on_request).__aenter__() + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 1 + assert counts[("POST", "/token")] == 1 + assert counts[("POST", "/mcp")] == 2 + + +@requirement("client-auth:cimd") +async def test_cimd_is_selected_when_the_as_advertises_support_and_a_metadata_url_is_supplied() -> None: + """A client-ID metadata-document URL is used as `client_id` instead of registering. + + AS metadata is shimmed to advertise `client_id_metadata_document_supported: true`; the + provider is pre-seeded so the server's authorize and token handlers accept the URL as a + client_id (the SDK server has no CIMD-aware client lookup of its own). The recorded + requests prove no `/register` call, the authorize URL's `client_id` is the CIMD URL, the + token request uses `token_endpoint_auth_method=none`, and storage persists the URL as + `client_id`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + seeded_client(provider, client_id=CIMD_URL) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=shim(serve={ASM_PATH: cimd_supported_metadata()}), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["client_id"] == CIMD_URL + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body["client_id"] == CIMD_URL + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + assert storage.client_info is not None + assert storage.client_info.client_id == CIMD_URL + assert storage.client_info.token_endpoint_auth_method == "none" + + +@requirement("client-auth:invalid-grant-clears-tokens") +async def test_a_failed_refresh_clears_stored_tokens_and_restarts_the_full_flow() -> None: + """A non-200 refresh response clears the in-memory tokens and the flow re-runs from discovery. + + The first token is reported expired so the next request refreshes; the provider denies the + refresh once with `invalid_grant`, the auth flow clears its tokens, the unauthenticated + request 401s, and discovery, authorize, and token run again. The original registration is + preserved (`client_info` is not cleared). The SDK clears tokens on any non-200 refresh + response, not specifically `error=invalid_grant`; `source="sdk"` so this is a precision + note rather than a divergence. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True, fail_next_refresh=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + assert [form_body(r)["grant_type"] for r in token_posts] == snapshot( + ["authorization_code", "refresh_token", "authorization_code"] + ) + + counts = path_counts(recorded) + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 2 + assert counts[("GET", PRM_PATH)] == 2 + assert counts[("GET", ASM_PATH)] == 2 + + assert storage.client_info is not None + assert storage.tokens is not None + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("client-auth:client-credentials") +async def test_client_credentials_provider_obtains_a_token_without_an_authorize_step() -> None: + """The client-credentials provider connects with no authorize step and a `client_credentials` grant. + + The SDK server's `TokenHandler` does not route `client_credentials`, so the harness shim + handles it (the shim is harness; the SDK-under-test is the client provider). The recorded + `/token` body proves the grant type, scope, resource indicator, and HTTP-Basic client + authentication; no `/authorize` or `/register` request was made. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + auth = ClientCredentialsOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-client", + client_secret="m2m-secret", + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert headless.authorize_url is None + assert find(recorded, "GET", "/authorize") == [] + assert find(recorded, "POST", "/register") == [] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + {"grant_type": "client_credentials", "resource": "http://127.0.0.1:8000/mcp", "scope": "mcp"} + ) + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == "m2m-client:m2m-secret" + + +@requirement("client-auth:private-key-jwt") +async def test_private_key_jwt_provider_authenticates_the_token_request_with_an_assertion() -> None: + """The private-key-JWT provider sends a `client_assertion` on the token request, with the issuer as audience. + + The assertion provider is a closure that records the audience it was called with and returns + a fixed opaque value (the JWT contents are not the SDK's concern here); the test asserts the + `client_assertion`/`client_assertion_type` form fields and that the audience matches the AS + metadata's issuer. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + audiences: list[str] = [] + + async def assertion_provider(audience: str) -> str: + audiences.append(audience) + return "header.payload.sig" + + auth = PrivateKeyJWTOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-jwt-client", + assertion_provider=assertion_provider, + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert audiences == [f"{BASE_URL}/"] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + { + "grant_type": "client_credentials", + "client_assertion": "header.payload.sig", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": "http://127.0.0.1:8000/mcp", + "scope": "mcp", + } + ) + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + +@pytest.mark.parametrize( + ("case", "preseed_storage", "advertise_cimd"), + [("cimd_unsupported_falls_through_to_dcr", False, False), ("preregistered_beats_cimd", True, True)], + ids=["cimd_unsupported_falls_through_to_dcr", "preregistered_beats_cimd"], +) +@requirement("client-auth:cimd") +async def test_registration_priority_prefers_preregistered_then_cimd_then_dcr( + case: str, preseed_storage: bool, advertise_cimd: bool +) -> None: + """The client picks pre-registration over CIMD over DCR, falling through when each is unavailable. + + Two priority edges are exercised: with a CIMD URL configured but no AS support, DCR runs and + the registered `client_id` is used; with a CIMD URL configured and AS support but a + pre-registered client in storage, the stored `client_id` is used and neither CIMD nor DCR + runs. (The positive CIMD case and pre-registration over DCR are covered by their own tests.) + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + storage = InMemoryTokenStorage() + + expected_client_id: str + if preseed_storage: + info = seeded_client(provider) + storage.client_info = info + assert info.client_id is not None + expected_client_id = info.client_id + else: + expected_client_id = "" + + app_shim = shim(serve={ASM_PATH: cimd_supported_metadata()}) if advertise_cimd else None + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=app_shim, + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + chosen_client_id = authorize_params(headless.authorize_url)["client_id"] + assert chosen_client_id != CIMD_URL + + if case == "cimd_unsupported_falls_through_to_dcr": + assert len(find(recorded, "POST", "/register")) == 1 + assert chosen_client_id in provider.clients + else: + assert find(recorded, "POST", "/register") == [] + assert chosen_client_id == expected_client_id diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py new file mode 100644 index 0000000000..c2ace45077 --- /dev/null +++ b/tests/interaction/conftest.py @@ -0,0 +1,23 @@ +"""Shared fixtures for the interaction suite.""" + +import pytest + +from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http + +_FACTORIES: dict[str, Connect] = { + "in-memory": connect_in_memory, + "streamable-http": connect_over_streamable_http, + "sse": connect_over_sse, +} + + +@pytest.fixture(params=sorted(_FACTORIES)) +def connect(request: pytest.FixtureRequest) -> Connect: + """The transport-parametrized connection factory: a test using it runs once per transport. + + Tests that are tied to one transport (the wire-recording tests, the bare-ClientSession tests, + the transport-specific tests under transports/) do not use this fixture and connect directly. + """ + transport_name = request.param + assert isinstance(transport_name, str) + return _FACTORIES[transport_name] diff --git a/tests/interaction/lowlevel/__init__.py b/tests/interaction/lowlevel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py new file mode 100644 index 0000000000..6f1454e58a --- /dev/null +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -0,0 +1,234 @@ +"""Cancellation interactions against the low-level Server, driven through the public Client API. + +There is no client-side cancellation API: cancelling means sending a CancelledNotification +carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so +these tests capture the id from inside the blocked handler before cancelling. The handler blocks +on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + EmptyResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + PingRequest, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:cancel:in-flight") +@requirement("protocol:cancel:handler-abort-propagates") +async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: + """Cancelling an in-flight request interrupts its handler and fails the pending call. + + The server answers the cancelled request with an error response (the spec says it should + not respond at all; see the divergence note on the requirement), so the caller's pending + request raises rather than hanging. + """ + started = anyio.Event() + handler_cancelled = anyio.Event() + request_ids: list[types.RequestId] = [] + errors: list[ErrorData] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + try: + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + errors.append(exc_info.value.error) + + task_group.start_soon(call_and_capture_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification( + params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") + ) + ) + + await handler_cancelled.wait() + + assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) + + +@requirement("protocol:cancel:server-survives") +async def test_session_serves_requests_after_cancellation(connect: Connect) -> None: + """A request cancelled mid-flight does not poison the session: the next request succeeds.""" + started = anyio.Event() + request_ids: list[types.RequestId] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + await anyio.Event().wait() # blocks until cancelled + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_swallow_cancellation_error() -> None: + with pytest.raises(MCPError): + await client.call_tool("block", {}) + + task_group.start_soon(call_and_swallow_cancellation_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) + ) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:cancel:unknown-id-ignored") +async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None: + """A cancellation referencing a request id that is not in flight is ignored without error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="unbothered")]) + + server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) + ) + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) + + +@requirement("protocol:cancel:late-response-ignored") +async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: + """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. + + The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; + that is the same client-side code path as any response with an unknown id, and that form is + deterministic to test without depending on the cancellation API the SDK does not yet provide. + See the divergence note on the requirement. + + A real Server cannot be made to answer with a fabricated id, so the test plays the server's + side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The + other tests in this file run over the transport matrix; this one is in-memory only because the + scripted-peer mechanism is the in-memory stream pair, not because the behaviour is + transport-specific. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: + return SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + assert init.message.method == "initialize" + await server_write.send( + respond( + init.message.id, + InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="scripted", version="0.0.1"), + ), + ) + ) + + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + + ping = await server_read.receive() + assert isinstance(ping, SessionMessage) + assert isinstance(ping.message, JSONRPCRequest) + assert ping.message.method == "ping" + # First answer with a fabricated id that matches nothing in flight, then the real id. + await server_write.send(respond(9999, EmptyResult())) + await server_write.send(respond(ping.message.id, EmptyResult())) + + incoming: list[IncomingMessage] = [] + + async def message_handler(message: IncomingMessage) -> None: + incoming.append(message) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + ClientSession(client_read, client_write, message_handler=message_handler) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + pong = await session.send_request(PingRequest(), EmptyResult) + + assert pong == snapshot(EmptyResult()) + assert len(incoming) == 1 + assert isinstance(incoming[0], RuntimeError) + # The full message embeds the response object's repr; only the prefix is stable. + assert str(incoming[0]).startswith("Received response with an unknown request ID:") diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py new file mode 100644 index 0000000000..6a35404df3 --- /dev/null +++ b/tests/interaction/lowlevel/test_completion.py @@ -0,0 +1,131 @@ +"""Completion interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + METHOD_NOT_FOUND, + CompleteResult, + Completion, + ErrorData, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("completion:prompt-arg") +@requirement("completion:result-shape") +async def test_complete_prompt_argument(connect: Connect) -> None: + """Completing a prompt argument delivers the ref, argument name, and current value to the handler. + + The returned values are filtered by the argument's value, proving the value reached the handler. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + assert params.ref.name == "code_review" + assert params.argument.name == "language" + candidates = ["python", "pytorch", "ruby"] + matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] + return CompleteResult(completion=Completion(values=matches, total=len(matches), has_more=False)) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + PromptReference(name="code_review"), argument={"name": "language", "value": "py"} + ) + + assert result == snapshot( + CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, has_more=False)) + ) + + +@requirement("completion:resource-template-arg") +async def test_complete_resource_template_variable(connect: Connect) -> None: + """Completing a URI template variable delivers the template URI and variable name to the handler.""" + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "github://repos/{owner}/{repo}" + assert params.argument.name == "owner" + return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "owner", "value": "model"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"]))) + + +@requirement("completion:context-arguments") +async def test_complete_receives_context_arguments(connect: Connect) -> None: + """Previously-resolved arguments passed as completion context reach the handler. + + The returned value is derived from the context, proving it arrived. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert params.argument.name == "repo" + assert params.context is not None + assert params.context.arguments is not None + return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol/python-sdk"]))) + + +@requirement("completion:error:invalid-ref") +async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params(connect: Connect) -> None: + """completion/complete with a ref naming an unknown prompt is answered with -32602 Invalid params. + + The lowlevel server does not validate refs itself (it has no prompt/template registry to check + against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended + way to do it. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}") + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="ghost"), argument={"name": "x", "value": ""}) + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("completion:complete:not-supported") +@requirement("protocol:error:method-not-found") +async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None: + """A server with no completion handler advertises no completions capability and rejects the request.""" + server = Server("incomplete") + + async with connect(server) as client: + assert client.initialize_result.capabilities.completions is None + + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py new file mode 100644 index 0000000000..b8edf601d0 --- /dev/null +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -0,0 +1,662 @@ +"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API. + +The final test plays the server's side of the wire by hand to issue an elicitation request with no +mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + ElicitCompleteNotification, + ElicitCompleteNotificationParams, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +REQUESTED_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], +} + + +@requirement("elicitation:form:action:accept") +@requirement("elicitation:form:basic") +@requirement("tools:call:elicitation-roundtrip") +async def test_elicit_form_accepted_content_returns_to_handler(connect: Connect) -> None: + """An accepted form elicitation returns the user's content to the requesting handler. + + The tool reports the action as text and the received content as structured content, proving + the client's answer made it back into the tool's own result. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Choose a username.", + requested_schema={ + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept")], + structured_content={"username": "ada", "newsletter": True}, + ) + ) + + +@requirement("elicitation:form:action:decline") +async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: + """A declined form elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:form:action:cancel") +async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled form elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:form:not-supported") +@requirement("elicitation:capability:server-respects-mode") +async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: + """Eliciting from a client that configured no elicitation callback fails with an error. + + The client's default callback answers with an Invalid request error, which the server-side + elicit call raises as an MCPError; the tool reports the code and message it caught. The spec + requires -32602 for an undeclared mode (see the divergence note on the requirement). The + request reaching the client also shows the server does not check the client's declared + elicitation capability before sending (see the divergence on `server-respects-mode`). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask" + try: + await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # elicit_form cannot succeed without a client callback + + server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + + +@requirement("elicitation:url:action:accept-no-content") +@requirement("elicitation:url:basic") +async def test_elicit_url_delivers_url_and_returns_accept_without_content(connect: Connect) -> None: + """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it + returns the action with no content. + + Accept means the user agreed to visit the URL, not that the out-of-band interaction finished, + so there is never form content to return. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert received == snapshot( + [ + ElicitRequestURLParams( + _meta={}, + message="Authorize access to your calendar.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + + +@requirement("elicitation:url:decline") +async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: + """A declined URL elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:url:cancel") +async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled URL elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:url:complete-notification") +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client(connect: Connect) -> None: + """After a URL elicitation finishes, the server announces it with a notification carrying the same id. + + The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user + agrees to visit the URL, the out-of-band interaction finishes, and the server emits + elicitation/complete so the client can correlate the completion with the elicitation it + accepted earlier. The completion notification carries ``related_request_id`` so over + streamable HTTP it rides the tool call's own stream and reaches the client before the call + returns; the same ordering already holds on in-memory and SSE transports. + """ + elicitation_id = "auth-001" + elicited_ids: list[str] = [] + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "link_account" + answer = await ctx.session.elicit_url( + "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id + ) + assert answer.action == "accept" + await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="linked")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestURLParams) + elicited_ids.append(params.elicitation_id) + return ElicitResult(action="accept") + + async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: + await client.call_tool("link_account", {}) + + # The completion notification refers to the same elicitation the client accepted. + assert elicited_ids == [elicitation_id] + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + ) + + +@requirement("elicitation:url:required-error") +async def test_url_elicitation_required_error_carries_pending_elicitations(connect: Connect) -> None: + """A request that cannot proceed until a URL interaction completes is rejected with error -32042. + + This is the non-interactive alternative to elicit_url: instead of asking and waiting, the + handler rejects the whole request and lists the required URL elicitations in the error data. + The client is expected to present those URLs, wait for the matching elicitation/complete + notifications, and retry the original request. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + server = Server("authorizer", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("elicitation:form:schema:primitives") +@requirement("elicitation:form:schema:enum-variants") +async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the_callback_as_sent( + connect: Connect, +) -> None: + """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. + + One schema with one property per kind: a formatted string, an integer with bounds, a number, + a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The + callback observing the same schema as the handler sent proves both the primitive coverage and + the enum-variant coverage in one snapshot. + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "score": {"type": "number"}, + "subscribe": {"type": "boolean", "default": False}, + "tier": {"type": "string", "enum": ["free", "pro", "team"]}, + "region": { + "oneOf": [ + {"const": "eu", "title": "Europe"}, + {"const": "na", "title": "North America"}, + ], + }, + "channels": {"type": "array", "items": {"type": "string", "enum": ["email", "sms", "push"]}}, + }, + "required": ["email"], + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"email": "ada@example.com"}) + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("onboard", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:schema:restricted-subset") +async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: + """A requested schema with nested-object and array-of-object properties passes through unchanged. + + The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; + this test pins that the SDK does not enforce that restriction on either side (see the + divergence on the requirement). + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "profile" + answer = await ctx.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("profile", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:response-validation") +async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the_handler_unchanged( + connect: Connect, +) -> None: + """Accepted form content that contradicts the requested schema is delivered to the handler unchanged. + + The schema requires a string `name`; the callback answers with a wrong-type value and an extra + field. Nothing on either side validates the response against the schema (see the divergence on + the requirement), so the handler observes exactly what the callback sent. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form( + "Choose a name.", + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + ) + + +@requirement("elicitation:url:complete-unknown-ignored") +async def test_elicitation_complete_for_an_unknown_id_is_received_without_error(connect: Connect) -> None: + """An elicitation/complete for an id the client never elicited is delivered and does not fail anything. + + No URL elicitation precedes the notification; the client neither tracks elicitation ids nor + rejects unknown ones, so the call completes normally and the message handler observes the + notification as-is. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "noop" + await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(server, message_handler=collect) as client: + result = await client.call_tool("noop", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + ) + + +@requirement("elicitation:form:mode-omitted-default") +async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None: + """An elicitation/create request with no mode field reaches the client callback as form-mode. + + The typed server API always serializes a mode (`elicit_form` writes 'form', `elicit_url` writes + 'url'), so this test plays the server's side of the wire by hand to send a request body without + one. Reserve this pattern for behaviour the typed server API cannot produce. + """ + received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={}) + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. + await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write, elicitation_callback=answer_form) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", "properties": {}}, + ) + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py new file mode 100644 index 0000000000..8d96582341 --- /dev/null +++ b/tests/interaction/lowlevel/test_flows.py @@ -0,0 +1,203 @@ +"""Composed multi-feature flows against the low-level Server, driven through the public Client API. + +Each test reads as the scenario it proves: the steps run top to bottom in the order a real client +would perform them, composing two or more feature areas (a tool call followed by a resource read; +a chain of elicitations inside one tool call; the full URL-elicitation-required retry loop). The +individual features are pinned by their own tests; these prove they compose. +""" + +from collections.abc import Awaitable, Callable + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitCompleteNotification, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + EmptyResult, + ListToolsResult, + ReadResourceResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ListToolsHandler = Callable[ + [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] +] + + +def _list_tools(*names: str) -> ListToolsHandler: + """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + + return list_tools + + +@requirement("flow:tool-result:resource-link-follow") +async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(connect: Connect) -> None: + """A tool returns a resource_link; reading that link's URI returns the referenced contents. + + Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the + link's URI, (4) the read result carries the linked contents. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "generate" + return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + assert str(params.uri) == "file:///report.txt" + return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + + server = Server( + "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource + ) + + async with connect(server) as client: + called = await client.call_tool("generate", {}) + link = called.content[0] + assert isinstance(link, ResourceLink) + read = await client.read_resource(link.uri) + + assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert read == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ) + + +@requirement("flow:elicitation:multi-step-form") +async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forward(connect: Connect) -> None: + """Sequential form elicitations inside one tool call: each accepted answer feeds the next step. + + Steps: (1) call the tool, (2) the handler issues a step-one form elicitation that the client + accepts with content, (3) the handler issues a step-two elicitation whose message references + the step-one answer, (4) the client accepts step two, (5) the tool result summarises both + answers. The callback is invoked exactly twice with the expected messages and schemas. The + short-circuit on decline is the application's choice (proven separately by the per-action + elicitation tests); what this flow pins is that the chain itself works end to end. + """ + received: list[ElicitRequestFormParams] = [] + answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + first = await ctx.session.elicit_form( + "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + ) + assert first.action == "accept" and first.content is not None + second = await ctx.session.elicit_form( + f"Step 2: confirm age for {first.content['name']}.", + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ) + assert second.action == "accept" and second.content is not None + return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + + server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) + + async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + received.append(params) + return ElicitResult(action="accept", content=answers[len(received) - 1]) + + async with connect(server, elicitation_callback=answer) as client: + result = await client.call_tool("onboard", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert [(p.message, p.requested_schema) for p in received] == snapshot( + [ + ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), + ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), + ] + ) + + +@requirement("flow:elicitation:url-required-then-retry") +async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_after_completion( + connect: Connect, +) -> None: + """The full URL-elicitation-required retry loop: -32042, completion announced, retry succeeds. + + Steps: (1) the first call is rejected with -32042 carrying the required URL elicitation in + its error data, (2) the client extracts the elicitation id from the error, (3) the server + announces completion via the elicitation/complete notification (driven via the captured + session, the same way a real out-of-band callback would reach a held session reference), + (4) the client observes the matching completion notification and retries, (5) the retry + succeeds. The handler distinguishes the two calls by a closure flag the test flips between + them; the test waits on the completion notification with an event so the retry only happens + after the announcement has arrived. + """ + elicitation_id = "auth-001" + authorised: list[bool] = [False] + captured: list[ServerSession] = [] + completed = anyio.Event() + notifications: list[ElicitCompleteNotification] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + captured.append(ctx.session) + if not authorised[0]: + # The log line gives the message handler a non-completion notification, so the test's + # filtering branch is exercised in both directions and the wait remains specific. + await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorize file access.", + url="https://example.com/oauth/authorize", + elicitation_id=elicitation_id, + ) + ] + ) + return CallToolResult(content=[TextContent(text="contents")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server( + "gatekeeper", + on_list_tools=_list_tools("read_files"), + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + ) + + async def collect(message: IncomingMessage) -> None: + if isinstance(message, ElicitCompleteNotification): + notifications.append(message) + completed.set() + + async with connect(server, message_handler=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + required = UrlElicitationRequiredError.from_error(exc_info.value.error) + assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + + # The out-of-band interaction completes; the server announces it on the same session. + await captured[0].send_elicit_complete(elicitation_id) + with anyio.fail_after(5): + await completed.wait() + assert notifications[0].params.elicitation_id == elicitation_id + + authorised[0] = True + result = await client.call_tool("read_files", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py new file mode 100644 index 0000000000..91adbf5611 --- /dev/null +++ b/tests/interaction/lowlevel/test_initialize.py @@ -0,0 +1,384 @@ +"""Initialization handshake against the low-level Server, driven through the public Client API. + +The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always +performs the full handshake with the latest protocol version, so skipping initialization or +requesting a different version can only be expressed one level down. The final test goes one step +further and plays the server's side of the wire by hand, because no real Server can be made to +answer initialize with an unsupported protocol version. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + CallToolResult, + ClientCapabilities, + CompletionsCapability, + EmptyResult, + ErrorData, + Icon, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + LoggingCapability, + PromptsCapability, + ResourcesCapability, + ServerCapabilities, + TextContent, + ToolsCapability, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:initialize:basic") +@requirement("lifecycle:initialize:server-info") +async def test_initialize_returns_server_info(connect: Connect) -> None: + """Every identity field the server declares is returned to the client in server_info.""" + server = Server( + "greeter", + version="1.2.3", + title="Greeter", + description="Greets people.", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + + async with connect(server) as client: + server_info = client.initialize_result.server_info + + assert server_info == snapshot( + Implementation( + name="greeter", + title="Greeter", + description="Greets people.", + version="1.2.3", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + ) + + +@requirement("lifecycle:initialize:instructions") +async def test_initialize_returns_instructions(connect: Connect) -> None: + """Instructions are returned when the server declares them and omitted when it does not.""" + async with connect(Server("guided", instructions="Call the add tool.")) as client: + assert client.initialize_result.instructions == snapshot("Call the add tool.") + + async with connect(Server("unguided")) as client: + assert client.initialize_result.instructions is None + + +@requirement("lifecycle:initialize:capabilities:from-handlers") +@requirement("tools:capability:declared") +@requirement("resources:capability:declared") +@requirement("prompts:capability:declared") +@requirement("completion:capability:declared") +async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: + """Each feature area with a registered handler is advertised as a capability. + + The in-memory transport connects with default initialization options, so the + list_changed flags are always False regardless of the server's notification behaviour. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: + """Registered only so the subscribe sub-capability is advertised; never called.""" + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered only so the prompts capability is advertised; never called.""" + raise NotImplementedError + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> types.EmptyResult: + """Registered only so the logging capability is advertised; never called.""" + raise NotImplementedError + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + server = Server( + "full", + on_list_tools=list_tools, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + on_list_prompts=list_prompts, + on_set_logging_level=set_logging_level, + on_completion=completion, + ) + + async with connect(server) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot( + ServerCapabilities( + experimental={}, + logging=LoggingCapability(), + prompts=PromptsCapability(list_changed=False), + resources=ResourcesCapability(subscribe=True, list_changed=False), + tools=ToolsCapability(list_changed=False), + completions=CompletionsCapability(), + ) + ) + + +@requirement("lifecycle:initialize:capabilities:minimal") +async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: + """A server with no feature handlers advertises no feature capabilities.""" + async with connect(Server("bare")) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot(ServerCapabilities(experimental={})) + + +@requirement("lifecycle:initialize:client-info") +async def test_initialize_server_sees_client_info(connect: Connect) -> None: + """The client identity supplied to Client is visible to server handlers after initialization.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + assert ctx.session.client_params is not None + client_info = ctx.session.client_params.client_info + return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + + +@requirement("lifecycle:initialize:client-capabilities") +async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: + """The client capabilities visible to the server reflect which callbacks the client configured.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "abilities" + assert ctx.session.client_params is not None + capabilities = ctx.session.client_params.capabilities + declared = [ + name + for name, value in ( + ("sampling", capabilities.sampling), + ("elicitation", capabilities.elicitation), + ) + if value is not None + ] + if capabilities.roots is not None: + declared.append(f"roots(list_changed={capabilities.roots.list_changed})") + return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + """Registered only so the client declares the roots capability; never called.""" + raise NotImplementedError + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) + + +@requirement("lifecycle:requests-before-initialized") +async def test_request_before_initialization_is_rejected() -> None: + """A feature request sent before the handshake completes is rejected; ping is exempt. + + Client always initializes on entry, so this drives a bare ClientSession that never sends + initialize. The server's stated reason for the rejection never reaches the client: the error + is reported as a generic invalid-params failure. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + """Registered so the request is routed to a real handler; never reached.""" + raise NotImplementedError + + server = Server("strict", on_list_tools=list_tools) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) + + # Ping is explicitly permitted before initialization completes. + pong = await session.send_ping() + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + assert pong == snapshot(EmptyResult()) + + +@requirement("lifecycle:version:match") +@requirement("lifecycle:version:server-fallback-latest") +async def test_initialize_negotiates_protocol_version() -> None: + """The server echoes a supported requested version and answers an unsupported one with its latest. + + Client always requests the latest version, so each half hand-builds an InitializeRequest on a + bare ClientSession to control the requested version. + """ + server = Server("negotiator") + + def initialize_request(protocol_version: str) -> InitializeRequest: + return InitializeRequest( + params=InitializeRequestParams( + protocol_version=protocol_version, + capabilities=ClientCapabilities(), + client_info=Implementation(name="time-traveller", version="0.0.1"), + ) + ) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + assert result.protocol_version == snapshot("2025-03-26") + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + assert result.protocol_version == snapshot("2025-11-25") + + +@requirement("lifecycle:version:reject-unsupported") +async def test_unsupported_server_protocol_version_fails_initialization() -> None: + """An initialize response carrying a protocol version the client does not support fails initialization. + + A real Server only ever answers with a version it supports, so this test alone plays the + server's side of the wire by hand: it reads the initialize request off the raw stream and + answers it with a hand-built result. Reserve this pattern for behaviour no real server can + be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() + + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + + +@requirement("lifecycle:version:downgrade") +async def test_an_older_supported_protocol_version_from_the_server_is_accepted() -> None: + """An initialize response carrying an older supported protocol version completes the handshake at that version. + + A real Server answers with the version the client requested (or its own latest), so this test + plays the server's side of the wire by hand to return a fixed older version regardless of what + was requested. Reserve this pattern for behaviour no real server can be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + initialize_result = await session.initialize() + + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py new file mode 100644 index 0000000000..a2f85eeacf --- /dev/null +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -0,0 +1,136 @@ +"""List-changed notifications from the low-level Server, driven through the public Client API. + +``send_*_list_changed`` does not take a ``related_request_id``, so over streamable HTTP the +notification routes to the standalone GET stream and is not guaranteed to arrive before the tool +result on its POST stream. Tests therefore wait on an event the collector sets, the same pattern +as ``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. +The collector still records every message it receives, so the snapshot also proves nothing else +was delivered. + +The servers register the parent capability (resources/prompts) so that part of the spec's +precondition holds, but the ``listChanged`` sub-capability stays ``False``: ``NotificationOptions`` +is not threaded through any of the suite's connection paths. The tests therefore rely on the +recorded ``lifecycle:capability:server-not-advertised`` divergence and will need updating +alongside the fix that introduces capability gating. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolResult, + PromptListChangedNotification, + ResourceListChangedNotification, + TextContent, + ToolListChangedNotification, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list-changed") +async def test_tool_list_changed_notification(connect: Connect) -> None: + """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "install" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="installed")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("install", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ToolListChangedNotification()]) + + +@requirement("resources:list-changed") +async def test_resource_list_changed_notification(connect: Connect) -> None: + """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "mount" + await ctx.session.send_resource_list_changed() + return CallToolResult(content=[TextContent(text="mounted")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_resources=list_resources) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("mount", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ResourceListChangedNotification()]) + + +@requirement("prompts:list-changed") +async def test_prompt_list_changed_notification(connect: Connect) -> None: + """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "learn" + await ctx.session.send_prompt_list_changed() + return CallToolResult(content=[TextContent(text="learned")]) + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered so the prompts capability is advertised; the client never lists prompts.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_prompts=list_prompts) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("learn", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py new file mode 100644 index 0000000000..fba632ef4d --- /dev/null +++ b/tests/interaction/lowlevel/test_logging.py @@ -0,0 +1,127 @@ +"""Logging interactions against the low-level Server, driven through the public Client API. + +Notification ordering: the in-memory transport delivers every server-to-client message on one +ordered stream, and the client's receive loop dispatches each incoming message to completion +before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds +only for messages that carry a ``related_request_id`` (they ride the originating request's POST +stream); without it the message routes to the standalone GET stream and may arrive after the +response. These tests pass ``related_request_id`` so they can collect into a plain list and +assert after the request completes on every transport leg -- no events, no waiting. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ALL_LEVELS: tuple[types.LoggingLevel, ...] = ( + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency", +) + + +@requirement("logging:set-level") +async def test_set_logging_level_reaches_handler(connect: Connect) -> None: + """The level requested by the client is delivered to the server's handler verbatim.""" + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + assert params.level == "warning" + return EmptyResult() + + server = Server("logger", on_set_logging_level=set_logging_level) + + async with connect(server) as client: + result = await client.set_logging_level("warning") + + assert result == snapshot(EmptyResult()) + + +@requirement("logging:message:fields") +@requirement("tools:call:logging-mid-execution") +async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> None: + """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. + + The two messages pin the full notification shape: severity, optional logger name, and both + string and structured data payloads. + """ + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "chatty" + await ctx.session.send_log_message( + level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id + ) + await ctx.session.send_log_message( + level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + result = await client.call_tool("chatty", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), + LoggingMessageNotificationParams(level="error", data={"code": 502, "retryable": True}), + ] + ) + + +@requirement("logging:message:all-levels") +async def test_log_messages_at_every_severity_level(connect: Connect) -> None: + """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "siren" + for level in ALL_LEVELS: + await ctx.session.send_log_message( + level=level, data=f"a {level} message", related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="logged")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + await client.call_tool("siren", {}) + + assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py new file mode 100644 index 0000000000..a9e4f994d8 --- /dev/null +++ b/tests/interaction/lowlevel/test_meta.py @@ -0,0 +1,63 @@ +"""Request and result _meta round trips against the low-level Server, through the public Client API. + +Meta is opaque pass-through data, so these tests assert identity against the value that was sent +rather than snapshotting a literal: the expected value and the sent value are the same variable, +which also proves the SDK injected nothing alongside it. +""" + +import pytest + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("meta:request-to-handler") +async def test_request_meta_reaches_handler(connect: Connect) -> None: + """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" + request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "traced" + assert ctx.meta is not None + observed_metas.append(dict(ctx.meta)) + return CallToolResult(content=[TextContent(text="traced")]) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("traced", {}, meta=request_meta) + + assert observed_metas == [dict(request_meta)] + + +@requirement("meta:result-to-client") +async def test_result_meta_reaches_client(connect: Connect) -> None: + """The _meta object a handler attaches to its result is delivered to the client unchanged.""" + result_meta = {"example.com/cost": 3} + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "metered" + return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("metered", {}) + + assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py new file mode 100644 index 0000000000..77db90401e --- /dev/null +++ b/tests/interaction/lowlevel/test_pagination.py @@ -0,0 +1,242 @@ +"""Cursor pagination of the list operations against the low-level Server. + +The cursor is an opaque string chosen by the server: the suite only asserts that whatever the +handler returns as next_cursor comes back verbatim on the client's next call, not any particular +pagination scheme. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + Prompt, + Resource, + ResourceTemplate, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list:pagination") +async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: + """The next_cursor a list handler returns reaches the client, and the cursor the client sends + back on the following call reaches the handler verbatim. + """ + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None # the client always sends params, even without a cursor + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListToolsResult( + tools=[Tool(name="alpha", input_schema={"type": "object"})], + next_cursor=cursor, + ) + return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + first_page = await client.list_tools() + second_page = await client.list_tools(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [tool.name for tool in first_page.tools] == ["alpha"] + assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) + + +@requirement("pagination:exhaustion") +@requirement("tools:list:pagination") +async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: Connect) -> None: + """Following next_cursor until it is absent visits every page exactly once, in order.""" + pages: dict[str | None, tuple[str, str | None]] = { + None: ("alpha", "page-2"), + "page-2": ("beta", "page-3"), + "page-3": ("gamma", None), + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + tool_name, next_cursor = pages[params.cursor] + return ListToolsResult(tools=[Tool(name=tool_name, input_schema={"type": "object"})], next_cursor=next_cursor) + + server = Server("paginated", on_list_tools=list_tools) + + collected: list[str] = [] + cursor: str | None = None + requests_made = 0 + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + requests_made += 1 + assert requests_made <= len(pages), "the server kept returning next_cursor past the last page" + collected.extend(tool.name for tool in result.tools) + if result.next_cursor is None: + break + cursor = result.next_cursor + + assert collected == snapshot(["alpha", "beta", "gamma"]) + assert requests_made == len(pages) + + +@requirement("pagination:client:cursor-handling") +async def test_the_client_follows_opaque_cursors_through_pages_of_varying_sizes(connect: Connect) -> None: + """The client passes a server-issued cursor back byte-for-byte and follows pages of varying sizes. + + The cursors are deliberately base64-looking strings (with padding and URL-unsafe characters) to + show the client treats them as opaque tokens; the page sizes [3, 1, 2] show the loop relies only + on next_cursor, not on a fixed page size. + """ + cursor_to_page_2 = "YWxwaGE+YnJhdm8/Y2hhcmxpZQ==" + cursor_to_page_3 = "ZGVsdGE=" + pages: dict[str | None, tuple[list[str], str | None]] = { + None: (["alpha", "beta", "gamma"], cursor_to_page_2), + cursor_to_page_2: (["delta"], cursor_to_page_3), + cursor_to_page_3: (["epsilon", "zeta"], None), + } + received_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + received_cursors.append(params.cursor) + names, next_cursor = pages[params.cursor] + return ListToolsResult( + tools=[Tool(name=name, input_schema={"type": "object"}) for name in names], next_cursor=next_cursor + ) + + server = Server("paginated", on_list_tools=list_tools) + + page_sizes: list[int] = [] + cursor: str | None = None + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + page_sizes.append(len(result.tools)) + if result.next_cursor is None: + break + cursor = result.next_cursor + + # Identity, not a snapshot: what arrived at the handler is exactly what the handler issued. + assert received_cursors == [None, cursor_to_page_2, cursor_to_page_3] + assert page_sizes == [3, 1, 2] + + +@requirement("pagination:invalid-cursor") +async def test_an_unrecognized_pagination_cursor_is_rejected_with_invalid_params(connect: Connect) -> None: + """A list request with a cursor the server did not issue is answered with -32602 Invalid params. + + The lowlevel server does not validate cursors itself (they are opaque to it); rejecting an + unrecognized cursor is the handler's job, and this test pins the spec-recommended way to do it. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + assert params.cursor == "never-issued" + raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {params.cursor!r}") + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="never-issued") + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("resources:list:pagination") +async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor=cursor) + return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) + + server = Server("paginated", on_list_resources=list_resources) + + async with connect(server) as client: + first_page = await client.list_resources() + second_page = await client.list_resources(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [resource.name for resource in first_page.resources] == ["first"] + assert [resource.name for resource in second_page.resources] == ["second"] + assert second_page.next_cursor is None + + +@requirement("resources:templates:pagination") +async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/templates/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], + next_cursor=cursor, + ) + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] + ) + + server = Server("paginated", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + first_page = await client.list_resource_templates() + second_page = await client.list_resource_templates(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [template.name for template in first_page.resource_templates] == ["first"] + assert [template.name for template in second_page.resource_templates] == ["second"] + assert second_page.next_cursor is None + + +@requirement("prompts:list:pagination") +async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None: + """prompts/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor=cursor) + return ListPromptsResult(prompts=[Prompt(name="second")]) + + server = Server("paginated", on_list_prompts=list_prompts) + + async with connect(server) as client: + first_page = await client.list_prompts() + second_page = await client.list_prompts(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [prompt.name for prompt in first_page.prompts] == ["first"] + assert [prompt.name for prompt in second_page.prompts] == ["second"] + assert second_page.next_cursor is None diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py new file mode 100644 index 0000000000..797e20dc35 --- /dev/null +++ b/tests/interaction/lowlevel/test_ping.py @@ -0,0 +1,53 @@ +"""Ping interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:ping") +@requirement("ping:client-to-server") +async def test_client_ping_returns_empty_result(connect: Connect) -> None: + """A client ping is answered with an empty result, even by a server with no handlers.""" + server = Server("silent") + + async with connect(server) as client: + result = await client.send_ping() + + assert result == snapshot(EmptyResult()) + + +@requirement("lifecycle:ping") +@requirement("ping:server-to-client") +async def test_server_ping_returns_empty_result(connect: Connect) -> None: + """A server-initiated ping sent while a request is in flight is answered by the client. + + The tool returns the type of the ping response, proving the round trip completed inside + the handler before the tool result was produced. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ping_back" + pong = await ctx.session.send_ping() + return CallToolResult(content=[TextContent(text=type(pong).__name__)]) + + server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ping_back", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py new file mode 100644 index 0000000000..6350c33a33 --- /dev/null +++ b/tests/interaction/lowlevel/test_progress.py @@ -0,0 +1,301 @@ +"""Progress interactions against the low-level Server, driven through the public Client API. + +Server-to-client progress emitted during a request follows the same ordering guarantee as +logging notifications (see test_logging.py) -- on the in-memory transport unconditionally, and +over streamable HTTP only when sent with ``related_request_id`` so the notification rides the +originating request's POST stream rather than the standalone GET stream. These tests pass +``related_request_id`` so no synchronisation is needed. The client-to-server direction is a +standalone notification with no response to await, so that test waits on an event set by the +server's handler. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT +from mcp.types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:progress:callback") +@requirement("tools:call:progress") +async def test_progress_during_tool_call_reaches_callback_in_order(connect: Connect) -> None: + """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" + received: list[tuple[float, float | None, str | None]] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="download", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "download" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification( + token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 2.0, total=3.0, message="second chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) + ) + return CallToolResult(content=[TextContent(text="downloaded")]) + + server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("download", {}, progress_callback=collect) + + assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) + assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) + + +@requirement("protocol:progress:token-injected") +async def test_progress_token_visible_to_handler(connect: Connect) -> None: + """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def ignore(progress: float, total: float | None, message: str | None) -> None: + """A progress callback that is never invoked; the tool only inspects the token.""" + raise NotImplementedError + + async with connect(server) as client: + result = await client.call_tool("inspect", {}, progress_callback=ignore) + + # The token is the request id of the tools/call request itself (initialize is request 0). + assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + + +@requirement("protocol:progress:no-token") +async def test_no_progress_callback_means_no_token(connect: Connect) -> None: + """Without a progress callback the request carries no progress token. + + The low-level API has no way to report request-scoped progress without a token, so a handler + that sees no token has nothing to send progress against. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("inspect", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) + + +@requirement("protocol:progress:client-to-server") +async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None: + """A progress notification sent by the client is delivered to the server's progress handler.""" + received: list[ProgressNotificationParams] = [] + delivered = anyio.Event() + + async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + received.append(params) + delivered.set() + + server = Server("observer", on_progress=on_progress) + + async with connect(server) as client: + await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot( + [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] + ) + + +@requirement("protocol:progress:token-unique") +async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Connect) -> None: + """Two concurrent requests carry distinct progress tokens, and each callback sees only its own progress. + + Without the barrier the first call could run to completion before the second starts, so only one + token would be live at a time and the demultiplexing would never be exercised. The handlers each + block until both have started and then hand control back and forth so the four progress + notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different + progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each + callback receiving exactly its own values proves notifications are routed by token, not by arrival + order or by chance. + """ + progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} + tokens: dict[str, ProgressToken] = {} + entered = {"a": anyio.Event(), "b": anyio.Event()} + # turns[n] is set to release the nth emission; each emission releases the next. + turns = [anyio.Event() for _ in range(4)] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert params.arguments is not None + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + label = params.arguments["label"] + tokens[label] = token + entered[label].set() + # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. + first, second = (0, 2) if label == "a" else (1, 3) + await turns[first].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][0], related_request_id=str(ctx.request_id) + ) + turns[first + 1].set() + await turns[second].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][1], related_request_id=str(ctx.request_id) + ) + if second + 1 < len(turns): + turns[second + 1].set() + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received_a: list[float] = [] + received_b: list[float] = [] + + async def collect_a(progress: float, total: float | None, message: str | None) -> None: + received_a.append(progress) + + async def collect_b(progress: float, total: float | None, message: str | None) -> None: + received_b.append(progress) + + async with connect(server) as client: + + async def call(label: str, collect: ProgressFnT) -> None: + await client.call_tool("report", {"label": label}, progress_callback=collect) + + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call, "a", collect_a) + task_group.start_soon(call, "b", collect_b) + await entered["a"].wait() + await entered["b"].wait() + turns[0].set() + + assert tokens["a"] != tokens["b"] + assert received_a == [1.0, 2.0] + assert received_b == [10.0, 20.0] + + +@requirement("protocol:progress:stops-after-completion") +@requirement("protocol:progress:late-dropped-by-client") +async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None: + """A progress notification sent after the response is emitted, and the client drops it from the callback. + + This single body proves both halves: the server's `send_progress_notification` happily sends for + a token whose request has already completed (the spec MUST that progress stops is not enforced; + see the divergence on `stops-after-completion`), and the client, having removed the callback when + the call returned, does not deliver the late notification to it. The message handler observes the + late notification arriving so the test knows when to assert without polling. + """ + captured: list[tuple[ServerSession, ProgressToken]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + captured.append((ctx.session, token)) + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[float] = [] + late_progress_arrived = anyio.Event() + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def message_handler(message: IncomingMessage) -> None: + if isinstance(message, ProgressNotification) and message.params.progress == 1.0: + late_progress_arrived.set() + + async with connect(server, message_handler=message_handler) as client: + with anyio.fail_after(5): + await client.call_tool("report", {}, progress_callback=collect) + assert received == [0.5] + + server_session, token = captured[0] + await server_session.send_progress_notification(token, 1.0) + await late_progress_arrived.wait() + + assert received == [0.5] + + +@requirement("protocol:progress:monotonic") +async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: Connect) -> None: + """A handler that emits non-increasing progress values has them forwarded to the callback unchanged. + + The spec says progress MUST increase with each notification; the SDK does not enforce that on + either side. See the divergence note on the requirement. + """ + received: list[float] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="zigzag", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "zigzag" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("zigzag", {}, progress_callback=collect) + + assert received == snapshot([0.5, 0.3, 0.9]) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py new file mode 100644 index 0000000000..868b82692c --- /dev/null +++ b/tests/interaction/lowlevel/test_prompts.py @@ -0,0 +1,209 @@ +"""Prompt interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + EmbeddedResource, + ErrorData, + GetPromptResult, + Icon, + ImageContent, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("prompts:list:basic") +async def test_list_prompts_returns_registered_prompts(connect: Connect) -> None: + """The prompts returned by the handler reach the client with their argument declarations intact.""" + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + return ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + + server = Server("prompter", on_list_prompts=list_prompts) + + async with connect(server) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + ) + + +@requirement("prompts:get:with-args") +async def test_get_prompt_substitutes_arguments(connect: Connect) -> None: + """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "greet" + assert params.arguments is not None + return GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text=f"Hello, {params.arguments['name']}!"))], + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Hello, Ada!"))], + ) + ) + + +@requirement("prompts:get:multi-message") +async def test_get_prompt_multiple_messages_preserve_roles_and_order(connect: Connect) -> None: + """A prompt returning a user/assistant conversation reaches the client with roles and order intact.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "geography_quiz" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("geography_quiz") + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + ) + + +@requirement("prompts:get:no-args") +async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None: + """A prompt fetched with no arguments delivers None as the handler's arguments and returns its messages.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "static" + assert params.arguments is None + return GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("static") + + assert result == snapshot( + GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + ) + + +@requirement("prompts:get:content:image") +@requirement("prompts:get:content:audio") +@requirement("prompts:get:content:embedded-resource") +async def test_get_prompt_with_non_text_content_round_trips(connect: Connect) -> None: + """Prompt messages can carry image, audio, and embedded-resource content; all reach the client. + + A single full-result snapshot proves all three content types round-trip: each block in the result + is one of the three behaviours under test. Tiny fixed base64 payloads ("aW1n" is b"img", "YXVk" + is b"aud") so the snapshot pins the exact bytes. + """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "media" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("media", {}) + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + ) + + +@requirement("prompts:get:unknown-name") +async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. + + The error's code and message chosen by the handler reach the client verbatim. + """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.name}") + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope")) diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py new file mode 100644 index 0000000000..4e369d3645 --- /dev/null +++ b/tests/interaction/lowlevel/test_resources.py @@ -0,0 +1,309 @@ +"""Resource interactions against the low-level Server, driven through the public Client API.""" + +import base64 + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + METHOD_NOT_FOUND, + Annotations, + BlobResourceContents, + CallToolResult, + EmptyResult, + ErrorData, + Icon, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("resources:list:basic") +@requirement("resources:annotations") +async def test_list_resources_returns_registered_resources(connect: Connect) -> None: + """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. + + The fully-populated entry includes annotations, so the snapshot also proves they round-trip. + The SDK's Annotations model omits the schema's lastModified field (see the divergence on + resources:annotations); the input is built via model_validate with lastModified set so the + snapshot pins the drop and will fail once the SDK adds the field. + """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations.model_validate( + {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} + ), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + result = await client.list_resources() + + assert result == snapshot( + ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations(audience=["user", "assistant"], priority=0.8), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:read:text") +async def test_read_resource_text(connect: Connect) -> None: + """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")] + ) + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///greeting.txt") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")] + ) + ) + + +@requirement("resources:read:blob") +async def test_read_resource_binary(connect: Connect) -> None: + """Reading a binary resource returns its contents base64-encoded in the blob field.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=params.uri, + mime_type="image/png", + blob=base64.b64encode(b"\x89PNG").decode(), + ) + ] + ) + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///pixel.png") + + assert result == snapshot( + ReadResourceResult( + contents=[BlobResourceContents(uri="file:///pixel.png", mime_type="image/png", blob="iVBORw==")] + ) + ) + + +@requirement("resources:read:unknown-uri") +async def test_read_resource_unknown_uri_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches + the client verbatim. + """ + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("file:///missing.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) + + +@requirement("resources:templates:list") +async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: + """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + result = await client.list_resource_templates() + + assert result == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:subscribe") +async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_subscribe_resource=subscribe_resource) + + async with connect(server) as client: + result = await client.subscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:subscribe:capability-required") +async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: + """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. + + The rejection comes from no handler being registered, not from any capability check; see the + divergence on lifecycle:capability:server-not-advertised. + """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("file:///watched.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + + +@requirement("resources:unsubscribe") +async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + + async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + + async with connect(server) as client: + result = await client.unsubscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:updated-notification") +async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: + """A resources/updated notification sent during a tool call reaches the client with the resource URI. + + ``send_resource_updated`` does not take a ``related_request_id``, so over streamable HTTP the + notification routes to the standalone GET stream and is not guaranteed to arrive before the + tool result; the test waits on an event the collector sets. The collector records every + message the handler receives, so the assertion also proves nothing else was delivered. + """ + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="touched")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + server = Server( + "library", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py new file mode 100644 index 0000000000..8149e0befb --- /dev/null +++ b/tests/interaction/lowlevel/test_roots.py @@ -0,0 +1,166 @@ +"""Roots interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import FileUrl + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("roots:list:basic") +async def test_list_roots_round_trip(connect: Connect) -> None: + """A roots/list request from a tool handler is answered by the client's roots callback. + + The tool reports the URIs and names it received, proving the client's roots reached the server. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + result = await ctx.session.list_roots() + lines = [f"{root.uri} name={root.name}" for root in result.roots] + return CallToolResult(content=[TextContent(text="\n".join(lines))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult( + roots=[ + Root(uri=FileUrl("file:///home/alice/project"), name="project"), + Root(uri=FileUrl("file:///home/alice/scratch")), + ] + ) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + ) + ) + + +@requirement("roots:list:empty") +async def test_list_roots_empty(connect: Connect) -> None: + """A client with no roots to offer answers roots/list with an empty list, not an error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "count_roots" + result = await ctx.session.list_roots() + return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[]) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("count_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + + +@requirement("roots:list:not-supported") +async def test_list_roots_without_callback_is_error(connect: Connect) -> None: + """A roots/list request to a client with no roots callback fails with an error the handler can observe. + + The client's default callback answers with INVALID_REQUEST rather than leaving the server + hanging; the spec names -32601 for this case (see the divergence note on the requirement). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # list_roots cannot succeed without a client callback + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + + +@requirement("roots:list:client-error") +async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: + """A roots callback that answers with an error fails the roots/list request with that exact error. + + The callback's code and message reach the requesting handler verbatim as an MCPError. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ErrorData: + return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. + """ + delivered = anyio.Event() + received: list[types.NotificationParams | None] = [] + + async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: + received.append(params) + delivered.set() + + server = Server("rooted", on_roots_list_changed=roots_list_changed) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + async with connect(server, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py new file mode 100644 index 0000000000..260e564192 --- /dev/null +++ b/tests/interaction/lowlevel/test_sampling.py @@ -0,0 +1,687 @@ +"""Sampling interactions against the low-level Server, driven through the public Client API. + +Each test nests a sampling/createMessage request inside a tool call: the tool handler calls +ctx.session.create_message(), the client's sampling callback answers it, and the handler +round-trips what it received back to the test through its tool result. +""" + +import pydantic +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + AudioContent, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + ImageContent, + ModelHint, + ModelPreferences, + SamplingCapability, + SamplingMessage, + TextContent, + ToolResultContent, + ToolUseContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("sampling:create:basic") +@requirement("tools:call:sampling-roundtrip") +async def test_create_message_round_trip(connect: Connect) -> None: + """A handler's sampling request is answered by the client callback, and the callback's result + (role, content, model, stop reason) is returned to the handler. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult( + role="assistant", + content=TextContent(text="Hello to you too."), + model="mock-llm-1", + stop_reason="endTurn", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create:include-context") +@requirement("sampling:create:model-preferences") +@requirement("sampling:create:system-prompt") +@requirement("sampling:context:server-gated-by-capability") +async def test_create_message_params_reach_callback(connect: Connect) -> None: + """Every sampling parameter the handler supplies arrives at the client callback unchanged. + + The client has not declared the sampling.context capability (Client cannot declare it), yet + include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not + enforced. See the divergence note on `sampling:context:server-gated-by-capability`. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + max_tokens=50, + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + stop_sequences=["\n\n", "END"], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + max_tokens=50, + stop_sequences=["\n\n", "END"], + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_request_with_image_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying image content arrives at the client callback intact. + + This is the server-to-client direction: the server includes an image in the conversation it + asks the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "describe_image" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + image = params.messages[0].content + assert isinstance(image, ImageContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"described {image.mime_type} ({image.data})"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("describe_image", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_result_with_image_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is an image is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is an image rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "draw" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + max_tokens=100, + ) + image = result.content + assert isinstance(image, ImageContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=ImageContent(data="Y2F0", mime_type="image/png"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("draw", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + + +@requirement("sampling:error:user-rejected") +async def test_create_message_callback_error(connect: Connect) -> None: + """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. + + The error here is the spec's own example for a user rejecting a sampling request (code -1); + the callback's code and message reach the handler verbatim, whatever they are. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + return ErrorData(code=-1, message="User rejected sampling request") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + + +@requirement("sampling:create-message:not-supported") +async def test_create_message_without_callback_is_error(connect: Connect) -> None: + """A sampling request to a client with no sampling callback fails with the SDK's default error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # create_message cannot succeed without a client callback + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + + +@requirement("sampling:tools:server-gated-by-capability") +async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: + """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. + + The client supports plain sampling but cannot declare the tools sub-capability (Client does not + expose it), so the server-side validator rejects the request before anything reaches the wire. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the validator rejects every tool-enabled request + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the plain sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + ) + + +@requirement("sampling:tool-result:no-mixed-content") +async def test_create_message_with_mixed_tool_result_content_is_rejected(connect: Connect) -> None: + """A sampling request whose user message mixes tool_result with other content never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, even + when no tools are passed, so the client callback is never invoked and the handler observes the + ValueError directly. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "summarise_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), + TextContent(text="Also, a comment alongside the result."), + ], + ) + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("summarise_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + ] + ) + ) + + +@requirement("sampling:capability:declare") +async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: + """A client connecting with a sampling callback advertises the sampling capability to the server. + + Client cannot declare any sub-capabilities (it does not expose ClientSession's + sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. + """ + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability()]) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying audio content arrives at the client callback intact. + + This is the server-to-client direction: the server includes audio in the conversation it asks + the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "transcribe" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + audio = params.messages[0].content + assert isinstance(audio, AudioContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("transcribe", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_result_with_audio_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is audio is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is audio rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "speak" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + max_tokens=100, + ) + audio = result.content + assert isinstance(audio, AudioContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("speak", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + + +@requirement("sampling:message:content-cardinality") +async def test_create_message_with_list_valued_message_content_reaches_callback(connect: Connect) -> None: + """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "caption" + result = await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + content = params.messages[0].content + assert isinstance(content, list) + return CreateMessageResult( + role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("caption", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:tool-use:server-preflight") +async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejected(connect: Connect) -> None: + """A sampling request whose tool_result ids do not match the preceding tool_use ids never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, so the + client callback is never invoked and the handler observes the ValueError directly. The spec's + client-side -32602 check is tracked separately at sampling:tool-use:result-balance. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "continue_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="assistant", + content=[ToolUseContent(id="call-1", name="weather", input={})], + ), + SamplingMessage( + role="user", + content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + ), + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("continue_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + ) + ] + ) + ) + + +@requirement("sampling:result:no-tools-single-content") +async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_validation_error(connect: Connect) -> None: + """An array-content sampling result for a tool-free request is accepted by the client and fails server-side. + + Only the exception type is asserted: the message is pydantic's, which changes across releases. + See the divergence note on the requirement: the intended behaviour is that the client rejects + the result; instead the client accepts it and the server's response parsing raises. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + max_tokens=100, + ) + except pydantic.ValidationError as exc: + return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + raise NotImplementedError # the array-content result fails server-side parsing every time + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + return CreateMessageResultWithTools( + role="assistant", + content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + model="mock-llm-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py new file mode 100644 index 0000000000..a9c83d641d --- /dev/null +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -0,0 +1,114 @@ +"""Request timeouts against the low-level Server, driven through the public Client API. + +The handler blocks on an event that is never set, so the awaited response can never arrive and +any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore +set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_request_timeout_fails_the_pending_call() -> None: + """A request whose response does not arrive within its read timeout fails with a timeout error. + + No cancellation is sent to the server (see the divergence note on the requirement): the handler + starts and is still running after the caller has already given up. The test waits for the + handler to have started only after the timeout has fired, so the timeout itself races nothing. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + # The request was already on the wire: the handler still runs even though the caller gave up. + with anyio.fail_after(5): + await handler_started.wait() + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + ) + ) + + +@requirement("protocol:timeout:session-survives") +async def test_session_serves_requests_after_timeout() -> None: + """A timed-out request does not poison the session: the next request succeeds.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError): + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:timeout:session-default") +async def test_session_level_timeout_applies_to_every_request() -> None: + """A read timeout configured on the client applies to requests that do not set their own.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the + # per-request timeouts: a session-level timeout also governs the initialize handshake, so the + # value must be long enough for the in-process handshake to complete before the blocked tool + # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual + # latency; lowering it only erodes the margin against CI scheduler jitter without saving + # anything perceptible. + async with Client(server, read_timeout_seconds=0.05) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + ) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py new file mode 100644 index 0000000000..e8053fbaa7 --- /dev/null +++ b/tests/interaction/lowlevel/test_tools.py @@ -0,0 +1,512 @@ +"""Tool interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + CallToolResult, + EmbeddedResource, + ErrorData, + Icon, + ImageContent, + ListToolsResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, + ToolAnnotations, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + + +@requirement("tools:call:is-error") +async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: + """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + + Tool execution errors are part of the result so the caller (typically a model) can see + them; only protocol-level failures become JSON-RPC errors. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "flux" + return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + ) + + +@requirement("tools:call:unknown-name") +async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + + The error's code, message, and data chosen by the handler reach the client verbatim. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("nope", {}) + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + ) + + +@requirement("protocol:error:internal-error") +async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: + """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + + The low-level server reports it with code 0 and the exception text as the message; see the + divergence note on the requirement. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "explode" + raise ValueError("boom") + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + + +@requirement("tools:list:basic") +async def test_list_tools_returns_registered_tools(connect: Connect) -> None: + """The tools advertised by the server's list handler arrive at the client unchanged.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + + server = Server("calculator", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + ) + + +@requirement("tools:input-schema:json-schema-2020-12") +@requirement("tools:input-schema:preserve-additional-properties") +@requirement("tools:input-schema:preserve-defs") +@requirement("tools:input-schema:preserve-schema-dialect") +async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Connect) -> None: + """A rich JSON Schema 2020-12 inputSchema reaches the client unchanged and the tool is callable. + + The single identity assertion below proves all four pass-through behaviours at once: the same + dict literal that was registered is the dict that arrives, so $schema, $defs, the nested object + property, and additionalProperties are each preserved by virtue of the whole schema being + preserved. The follow-up call proves the rich-schema tool is callable end to end. + """ + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"positive": {"type": "integer", "exclusiveMinimum": 0}}, + "properties": { + "count": {"$ref": "#/$defs/positive"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + "additionalProperties": False, + }, + }, + "required": ["count"], + "additionalProperties": False, + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "typed" + assert params.arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + listed = await client.list_tools() + called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) + + assert listed.tools[0].input_schema == schema + assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + + +@requirement("tools:list:metadata") +async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: + """Every optional Tool field the server supplies reaches the client unchanged.""" + + tool = Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool]) + + server = Server("annotated", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + ] + ) + ) + + +@requirement("tools:call:content:mixed") +@requirement("tools:call:content:image") +@requirement("tools:call:content:audio") +@requirement("tools:call:content:resource-link") +@requirement("tools:call:content:embedded-resource") +async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: + """A tool result can mix every content block type; all of them arrive in order. + + The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the + snapshot pins the exact bytes the client receives. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "render" + return CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + + server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("render", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + ) + + +@requirement("tools:call:structured-content") +async def test_call_tool_structured_content(connect: Connect) -> None: + """A tool result carrying structured content alongside content delivers both to the client.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "sum" + return CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5}) + + server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("sum", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + + +@requirement("tools:call:concurrent") +async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> None: + """Two tool calls in flight at once run concurrently and each caller gets its own answer. + + Both handlers are held on a shared event after signalling that they have started, and the test + only releases them once both signals have arrived -- a server that processed requests + sequentially would never start the second handler and the test would time out instead. + """ + started: list[str] = [] + started_events = {"first": anyio.Event(), "second": anyio.Event()} + release = anyio.Event() + results: dict[str, CallToolResult] = {} + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + tag = params.arguments["tag"] + assert isinstance(tag, str) + started.append(tag) + started_events[tag].set() + await release.wait() + return CallToolResult(content=[TextContent(text=tag)]) + + server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("echo", {"tag": tag}) + + task_group.start_soon(call_and_record, "first") + task_group.start_soon(call_and_record, "second") + + # Both handlers are running at the same time before either is allowed to finish. + await started_events["first"].wait() + await started_events["second"].wait() + release.set() + + assert sorted(started) == ["first", "second"] + assert results == snapshot( + { + "first": CallToolResult(content=[TextContent(text="first")]), + "second": CallToolResult(content=[TextContent(text="second")]), + } + ) + + +@requirement("client:output-schema:validate") +async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: + """A result whose structured content does not conform to the tool's declared output schema never + reaches the caller: the client validates it against the schema cached from tools/list and raises. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. + assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") + + +@requirement("client:output-schema:skip-on-error") +async def test_is_error_result_bypasses_client_output_schema_validation(connect: Connect) -> None: + """A tool result with isError true is returned as-is even when its structured content violates the schema. + + The schema is cached up front so the client could validate, proving the bypass is specifically the + isError flag and not an empty cache. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult( + content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + ) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + result = await client.call_tool("forecast", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + ) + + +@requirement("client:output-schema:missing-structured") +async def test_declared_output_schema_with_no_structured_content_is_rejected_by_the_client(connect: Connect) -> None: + """A tool that declared an output schema but returned no structuredContent fails the client-side check. + + The error is the SDK's own message, so the full text is snapshotted. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")]) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + assert str(exc_info.value) == snapshot("Tool forecast has an output schema but did not return structured content") + + +@requirement("client:output-schema:auto-list") +async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: + """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + + The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the + second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of + the surprising behaviour where a server with only on_call_tool sees a successful call answered + with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. + """ + list_calls: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + list_calls.append("called") + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + first = await client.call_tool("forecast", {}) + assert list_calls == ["called"] + second = await client.call_tool("forecast", {}) + + assert list_calls == ["called"] + assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert second == first diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py new file mode 100644 index 0000000000..0f9c58aa7a --- /dev/null +++ b/tests/interaction/lowlevel/test_wire.py @@ -0,0 +1,309 @@ +"""Wire-level invariants observed at the client's transport boundary. + +These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames. +The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing +the transport seam into a list without touching the session, so the assertions hold for whatever +the session implementation sends rather than for what its API returns. + +The later tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and the last two send deliberately +malformed JSON-RPC requests that the typed client API cannot produce. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListRootsResult, + TextContent, +) +from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _echo_server() -> Server: + """A server with one echo tool, used by every test in this module.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="ok")]) + + return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("protocol:request-id:unique") +async def test_request_ids_are_unique_and_never_null() -> None: + """Every request the client sends carries a distinct, non-null id. + + The id sequence is pinned: sequential integers from zero, in send order. + """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + await client.call_tool("echo", {}) + await client.call_tool("echo", {}) + await client.send_ping() + + sent = [message.message for message in recording.sent] + request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + assert all(request_id is not None for request_id in request_ids) + assert len(request_ids) == len(set(request_ids)) + # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a + # schema-cache refresh here because the explicit tools/list already populated the cache. + assert request_ids == snapshot([0, 1, 2, 3, 4]) + + +@requirement("protocol:notifications:no-response") +async def test_notifications_are_never_answered() -> None: + """A notification produces no response: everything the server sends back answers a request. + + The client sends two notifications (initialized and roots/list_changed) and several requests; + the messages received from the server must be exactly one response per request, each carrying + the id of the request it answers, and nothing else. + """ + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + await client.send_ping() + + sent = [message.message for message in recording.sent] + sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] + received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] + + assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed + assert len(received_responses) == len(received) # nothing the server sent was anything but a response + assert [message.id for message in received_responses] == sent_request_ids + + +async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None: + """The recording wrapper preserves the end-of-stream behaviour of the stream it wraps. + + This exercises the helper itself rather than an interaction-model behaviour: a transport whose + far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or + mistranslate that. + """ + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + log: list[SessionMessage | Exception] = [] + async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped: + await send_stream.aclose() + items = [item async for item in wrapped] + assert items == [] + assert log == [] + + +@requirement("lifecycle:initialized-notification") +async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None: + """The client sends initialized exactly once, between the initialize response and its first request. + + The full method sequence the client puts on the wire is pinned in send order. + """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + + sent_methods = [ + message.message.method + for message in recording.sent + if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + ] + assert sent_methods.count("notifications/initialized") == 1 + assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) + + +@requirement("protocol:error:connection-closed") +async def test_closing_the_transport_fails_in_flight_requests_with_connection_closed() -> None: + """When the server-to-client stream closes, every in-flight client request fails with CONNECTION_CLOSED. + + Driven over a bare ClientSession against a real Server so the test holds the transport stream + pair directly: once the request is in flight (the server handler signals it has started) the + test closes the server's write stream, which ends the client's receive loop and triggers the + teardown that fails the pending request. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + errors: list[ErrorData] = [] + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + await session.initialize() + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await session.send_request( + CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ) + errors.append(exc_info.value.error) + + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call_and_capture_error) + await handler_started.wait() + await server_write.aclose() + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=CONNECTION_CLOSED, message="Connection closed")]) + + +@requirement("protocol:error:invalid-params") +async def test_malformed_request_params_are_answered_with_invalid_params() -> None: + """A request whose params fail validation is answered with -32602 Invalid params. + + The typed client API cannot construct a request with the wrong parameter types, so the test + plays the client's side of the wire by hand against a real Server: it completes the + initialization handshake at the JSON-RPC layer and then sends a tools/call whose `name` is an + integer. Reserve this pattern for behaviour the typed API cannot produce. + """ + server = Server("strict") + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) + + +@requirement("logging:set-level:invalid-level") +async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_params() -> None: + """logging/setLevel with a value outside the spec's level enum is answered with -32602 Invalid params. + + The typed client API cannot construct a setLevel request with an unrecognized level (pyright and + the client-side model both reject it), so the test plays the client's side of the wire by hand + against a real Server. Reserve this pattern for behaviour the typed API cannot produce. + """ + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; never called -- params validation fails first.""" + raise NotImplementedError + + server = Server("logger", on_set_logging_level=set_logging_level) + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert len(errors) == 1 + assert errors[0].code == INVALID_PARAMS diff --git a/tests/interaction/mcpserver/__init__.py b/tests/interaction/mcpserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py new file mode 100644 index 0000000000..7761066e94 --- /dev/null +++ b/tests/interaction/mcpserver/test_completion.py @@ -0,0 +1,38 @@ +"""Completion behaviour against MCPServer, driven through the public Client API.""" + +import pytest + +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + CompletionsCapability, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:completion:capability-auto") +async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: + """An MCPServer with a registered completion handler advertises the completions capability; one without does not.""" + with_handler = MCPServer("completer") + + @with_handler.completion() + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + async with connect(with_handler) as client: + assert client.initialize_result.capabilities.completions == CompletionsCapability() + + async with connect(MCPServer("plain")) as client: + assert client.initialize_result.capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py new file mode 100644 index 0000000000..26556fea7a --- /dev/null +++ b/tests/interaction/mcpserver/test_context.py @@ -0,0 +1,271 @@ +"""The Context convenience methods MCPServer injects into tool functions, observed from the client.""" + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp import MCPError +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + METHOD_NOT_FOUND, + CallToolResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + ErrorData, + Implementation, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:logging") +@requirement("logging:capability:declared") +async def test_context_logging_helpers_send_log_notifications(connect: Connect) -> None: + """Each Context logging helper sends a log message notification at the matching severity. + + All four notifications reach the client's logging callback before the tool call returns; none + of them carry a logger name unless one is passed explicitly. The server emits these without + advertising the logging capability (see the divergence note on logging:capability). + """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("chatty") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + await ctx.debug("d") + await ctx.info("i") + await ctx.warning("w") + await ctx.error("e") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + result = await client.call_tool("narrate", {}) + advertised_logging = client.initialize_result.capabilities.logging + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="d"), + LoggingMessageNotificationParams(level="info", data="i"), + LoggingMessageNotificationParams(level="warning", data="w"), + LoggingMessageNotificationParams(level="error", data="e"), + ] + ) + # The spec requires servers that emit log notifications to declare the logging capability. + assert advertised_logging is None + + +@requirement("mcpserver:context:progress") +async def test_context_report_progress_sends_progress_notifications(connect: Connect) -> None: + """Context.report_progress sends progress notifications correlated to the calling request. + + The caller's progress callback receives each report, in order, before the tool call returns. + """ + received: list[tuple[float, float | None, str | None]] = [] + mcp = MCPServer("worker") + + @mcp.tool() + async def crunch(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.report_progress(2, 3, "halfway there") + return "crunched" + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with connect(mcp) as client: + result = await client.call_tool("crunch", {}, progress_callback=on_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + ) + assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) + + +@requirement("mcpserver:tool:extra") +async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Connect) -> None: + """A tool can read the per-request id and the connecting client's identity through Context. + + The request id is non-empty (its concrete value depends on transport-level sequencing, so the + test asserts the value the tool saw is the one returned, rather than pinning the literal); the + client info reflects what the caller passed to `Client`. + """ + mcp = MCPServer("introspector") + + @mcp.tool() + async def whoami(ctx: Context) -> str: + client_params = ctx.session.client_params + assert client_params is not None + return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + + async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert text.startswith("request ") + assert text.endswith(" from acme-agent 9.9.9") + request_id = text.removeprefix("request ").removesuffix(" from acme-agent 9.9.9") + assert request_id + + +@requirement("protocol:progress:no-token") +async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: + """When the caller supplied no progress callback, Context.report_progress is a silent no-op. + + The tool also emits one log message as a sentinel: the message handler receives only that, + proving the notification pipeline works and no progress notification was sent for the + token-less request. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("quiet") + + @mcp.tool() + async def mill(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.info("milling done") + return "milled" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + result = await client.call_tool("mill", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + ) + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] + ) + + +@requirement("mcpserver:context:elicit") +@requirement("tools:call:elicitation-roundtrip") +async def test_context_elicit_returns_typed_result(connect: Connect) -> None: + """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. + + The client sees the JSON schema generated from the model; the accepted content is validated + back into the model and handed to the tool as result.data. + """ + received: list[ElicitRequestParams] = [] + mcp = MCPServer("travel") + + class TravelPreferences(BaseModel): + destination: str + window_seat: bool + + @mcp.tool() + async def book_flight(ctx: Context) -> str: + answer = await ctx.elicit("Where to?", TravelPreferences) + assert isinstance(answer, AcceptedElicitation) + return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" + + async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) + + async with connect(mcp, elicitation_callback=answer_form) as client: + result = await client.call_tool("book_flight", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Where to?", + requested_schema={ + "properties": { + "destination": {"title": "Destination", "type": "string"}, + "window_seat": {"title": "Window Seat", "type": "boolean"}, + }, + "required": ["destination", "window_seat"], + "title": "TravelPreferences", + "type": "object", + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept: Lisbon window=True")], + structured_content={"result": "accept: Lisbon window=True"}, + ) + ) + + +@requirement("mcpserver:context:read-resource") +async def test_context_read_resource_reads_registered_resource(connect: Connect) -> None: + """Context.read_resource lets a tool read a resource registered on the same server. + + The tool reports the MIME type and content it read, proving the resource function ran and its + return value came back through the context. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + @mcp.tool() + async def show_config(ctx: Context) -> str: + contents = list(await ctx.read_resource("config://app")) + return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) + + async with connect(mcp) as client: + result = await client.call_tool("show_config", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="text/plain: 'theme = dark'")], + structured_content={"result": "text/plain: 'theme = dark'"}, + ) + ) + + +@requirement("logging:message:filtered") +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: + """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + + The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + and every message a tool emits is delivered regardless of level. The spec says the server + should only send messages at or above the configured level; with no way to configure one, + everything is sent. + """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("unfilterable") + + @mcp.tool() + async def chatter(ctx: Context) -> str: + await ctx.debug("noise") + await ctx.error("signal") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.set_logging_level("error") + + await client.call_tool("chatter", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="noise"), + LoggingMessageNotificationParams(level="error", data="signal"), + ] + ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py new file mode 100644 index 0000000000..2095f086d4 --- /dev/null +++ b/tests/interaction/mcpserver/test_prompts.py @@ -0,0 +1,195 @@ +"""Prompt interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:prompt:decorated") +async def test_list_prompts_derives_arguments_from_signature(connect: Connect) -> None: + """A decorated prompt is listed with arguments derived from the function signature. + + Parameters without a default are required; the description comes from the docstring. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def code_review(code: str, style_guide: str = "pep8") -> str: + """Review a piece of code.""" + raise NotImplementedError # registered for listing only; never rendered + + async with connect(mcp) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", required=True), + PromptArgument(name="style_guide", required=False), + ], + ) + ] + ) + ) + + +@requirement("mcpserver:prompt:decorated") +async def test_get_prompt_renders_function_return(connect: Connect) -> None: + """The decorated function's string return value is rendered as a single user message.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A personalised greeting.""" + return f"Say hello to {name}." + + async with connect(mcp) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + ) + ) + + +@requirement("mcpserver:prompt:unknown-name") +async def test_get_unknown_prompt_is_error(connect: Connect) -> None: + """Getting a prompt name that was never registered fails with a JSON-RPC error. + + The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on + the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; the test requests a different name.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + + +@requirement("prompts:get:missing-required-args") +async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Connect) -> None: + """Getting a prompt without one of its required arguments fails with a JSON-RPC error. + + The missing argument is detected before the prompt function is called, but the spec's -32602 + Invalid params is reported as error code 0 with the bare exception text (see the divergence + note on the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("greet") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + + +@requirement("mcpserver:prompt:args-validation") +async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_function_runs(connect: Connect) -> None: + """An argument that fails the function signature's type validation is rejected before the function runs. + + The decorated function is wrapped in pydantic's validate_call, so a value that cannot be + coerced to the parameter's annotation fails before the body executes. The function body + raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable + rendering-error prefix; the body of the message is raw pydantic output and is not asserted. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def repeat(phrase: str, count: int) -> str: + """A registered prompt; type validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) + + assert exc_info.value.error.code == 0 + assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + + +@requirement("mcpserver:prompt:optional-args") +async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: + """A prompt rendered without one of its optional arguments uses that parameter's default value.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def review(code: str, style: str = "pep8") -> str: + """Review a snippet of code against a style guide.""" + return f"Review {code} per {style}." + + async with connect(mcp) as client: + result = await client.get_prompt("review", {"code": "x = 1"}) + + assert result == snapshot( + GetPromptResult( + description="Review a snippet of code against a style guide.", + messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + ) + ) + + +@requirement("mcpserver:prompt:duplicate-name") +async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second prompt with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via the decorator with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.prompt(name="greet") + def greet_second() -> str: + """Registered with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_prompts() + result = await client.get_prompt("greet") + + assert [prompt.name for prompt in listed.prompts] == ["greet"] + assert result == snapshot( + GetPromptResult( + description="The first registration; this is the one that wins.", + messages=[PromptMessage(role="user", content=TextContent(text="first"))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py new file mode 100644 index 0000000000..57b0fdc86d --- /dev/null +++ b/tests/interaction/mcpserver/test_resources.py @@ -0,0 +1,183 @@ +"""Resource interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:resource:static") +async def test_read_static_resource(connect: Connect) -> None: + """A function registered for a fixed URI is served at that URI with its return value as text.""" + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + async with connect(mcp) as client: + result = await client.read_resource("config://app") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + ) + ) + + +@requirement("mcpserver:resource:static") +async def test_list_static_and_templated_resources(connect: Connect) -> None: + """Statically-registered resources appear in resources/list; templated ones only in templates/list. + + The name and description are derived from the function name and docstring; the MIME type + defaults to text/plain. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + raise NotImplementedError # registered for listing only; never read + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + raise NotImplementedError # registered for listing only; never read + + async with connect(mcp) as client: + resources = await client.list_resources() + templates = await client.list_resource_templates() + + assert resources == snapshot( + ListResourcesResult( + resources=[ + Resource( + name="app_config", + uri="config://app", + description="The application configuration.", + mime_type="text/plain", + ) + ] + ) + ) + assert templates == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate( + name="user_profile", + uri_template="users://{user_id}/profile", + description="A user's profile.", + mime_type="text/plain", + ) + ] + ) + ) + + +@requirement("mcpserver:resource:template") +@requirement("resources:read:template-vars") +async def test_read_templated_resource(connect: Connect) -> None: + """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" + mcp = MCPServer("library") + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + return f"profile for {user_id}" + + async with connect(mcp) as client: + result = await client.read_resource("users://42/profile") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + ) + ) + + +@requirement("mcpserver:resource:unknown-uri") +async def test_read_unknown_uri_is_error(connect: Connect) -> None: + """Reading a URI that matches no registered resource fails with a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """A registered resource; the test reads a different URI.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("config://missing") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) + + +@requirement("mcpserver:resource:read-throws-surfaced") +async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: + """An exception raised by a resource function reaches the caller as a JSON-RPC error. + + MCPServer wraps the failure in a generic error that names only the URI, so the original + exception text is not leaked to the client. The wrapped exception becomes error code 0 the + same way every other unhandled server-side exception does. + """ + mcp = MCPServer("library") + + @mcp.resource("res://boom") + def boom() -> str: + raise RuntimeError("nope") + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("res://boom") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + + +@requirement("mcpserver:resource:duplicate-name") +async def test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second static resource at an already-used URI keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The two + registrations use different function names so the test does not redefine a name in this scope; + the resource decorator keys on the URI, not the function name. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def config_first() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.resource("config://app") + def config_second() -> str: + """Registered at a duplicate URI; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_resources() + result = await client.read_resource("config://app") + + assert [resource.uri for resource in listed.resources] == ["config://app"] + assert listed.resources[0].name == "config_first" + assert result == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py new file mode 100644 index 0000000000..05135c1286 --- /dev/null +++ b/tests/interaction/mcpserver/test_tools.py @@ -0,0 +1,432 @@ +"""Tool interactions against MCPServer, driven through the public Client API.""" + +import logging +from typing import Annotated, Literal + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, Field + +from mcp import MCPError +from mcp.server.mcpserver import Context, MCPServer +from mcp.server.mcpserver.exceptions import ToolError +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitRequestURLParams, + ErrorData, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool function; its return value comes back as text content. + + MCPServer also derives an output schema from the return annotation and attaches the + matching structuredContent to the result. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + return str(a + b) + + async with connect(mcp) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + + +@requirement("mcpserver:tool:schema-variants") +async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool_runs(connect: Connect) -> None: + """Literal, nested-model, and constrained parameters are validated and coerced from the wire arguments. + + The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function + body sees them, proving the generated input schema and validation pipeline cover non-trivial types. + """ + mcp = MCPServer("typed") + + class Point(BaseModel): + x: int + y: int + + @mcp.tool() + def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Field(ge=1, le=10)]) -> str: + assert isinstance(point, Point) + return f"{mode} at ({point.x}, {point.y}) x{count}" + + async with connect(mcp) as client: + result = await client.call_tool("place", {"mode": "fast", "point": {"x": "3", "y": 4}, "count": 5}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + ) + ) + + +@requirement("mcpserver:tool:handler-throws") +@requirement("mcpserver:output-schema:skip-on-error") +async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error. + + The function's `-> str` annotation gives the tool a derived output schema, but the error + result is built before any schema validation runs, so no validation failure is layered on + top of the original exception. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def explode() -> str: + raise ValueError("boom") + + async with connect(mcp) as client: + result = await client.call_tool("explode", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + ) + + +@requirement("mcpserver:tool:handler-throws") +async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: + """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def flux() -> str: + raise ToolError("flux capacitor offline") + + async with connect(mcp) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + ) + + +@requirement("mcpserver:tool:unknown-name") +async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> None: + """Calling a tool name that was never registered is reported as an is_error result. + + The spec classifies unknown tools as a protocol error; see the divergence note on the + requirement. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def add() -> None: + """A registered tool; the test calls a different name.""" + + async with connect(mcp) as client: + result = await client.call_tool("nope", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + + +@requirement("mcpserver:tool:output-schema:model") +@requirement("tools:call:structured-content:text-mirror") +async def test_call_tool_model_return_becomes_structured_content(connect: Connect) -> None: + """A tool returning a pydantic model advertises the model's schema as the tool's output schema + and returns the model's fields as structured content alongside a serialised text block. + """ + mcp = MCPServer("weather") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def get_weather() -> Weather: + return Weather(temperature=22.5, conditions="sunny") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("get_weather", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": { + "temperature": {"title": "Temperature", "type": "number"}, + "conditions": {"title": "Conditions", "type": "string"}, + }, + "required": ["temperature", "conditions"], + "title": "Weather", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="""\ +{ + "temperature": 22.5, + "conditions": "sunny" +}\ +""" + ) + ], + structured_content={"temperature": 22.5, "conditions": "sunny"}, + ) + ) + + +@requirement("mcpserver:tool:output-schema:wrapped") +async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) -> None: + """A tool returning a list wraps the value under a "result" key in both the generated output + schema and the structured content. + """ + mcp = MCPServer("primes") + + @mcp.tool() + def primes() -> list[int]: + return [2, 3, 5] + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("primes", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, + "required": ["result"], + "title": "primesOutput", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], + structured_content={"result": [2, 3, 5]}, + ) + ) + + +@requirement("mcpserver:tool:input-validation") +async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) -> None: + """Arguments that fail validation against the tool's signature are reported as an is_error + result describing the failure, not as a protocol error. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + """Validation rejects the arguments before the function is ever called.""" + raise NotImplementedError + + async with connect(mcp) as client: + result = await client.call_tool("add", {"b": 3}) + + # The description is raw pydantic output -- it embeds a pydantic-version-specific + # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable + # prefix is asserted; a full snapshot would break on every pydantic upgrade. + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text.startswith("Error executing tool add: 1 validation error") + + +@requirement("mcpserver:output-schema:server-validate") +@requirement("mcpserver:output-schema:missing-structured") +async def test_tool_with_output_schema_returning_mismatched_structured_content_is_an_error_result( + connect: Connect, +) -> None: + """Structured content that fails the tool's own output schema is rejected on the server side. + + A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while + declaring `Model` as its output schema; MCPServer validates the supplied structured_content + against that schema before returning. The two cases -- a content shape that does not match, + and no structured content at all -- both fail that validation and are reported as is_error + results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. + """ + mcp = MCPServer("forecaster") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def mismatched() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + + @mcp.tool() + def missing() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")]) + + async with connect(mcp) as client: + mismatched_result = await client.call_tool("mismatched", {}) + missing_result = await client.call_tool("missing", {}) + + # The body of each message is raw pydantic ValidationError output (model name, field paths, + # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable + # prefix is asserted. + assert mismatched_result.is_error is True + assert isinstance(mismatched_result.content[0], TextContent) + assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") + + assert missing_result.is_error is True + assert isinstance(missing_result.content[0], TextContent) + assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") + + +@requirement("mcpserver:tool:duplicate-name") +async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second tool with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via add_tool with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("duplicates") + + @mcp.tool() + def echo() -> str: + return "first" + + def echo_second() -> str: + """Passed to add_tool with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + mcp.add_tool(echo_second, name="echo") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("echo", {}) + + assert [tool.name for tool in listed.tools] == ["echo"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + + +@requirement("mcpserver:tool:naming-validation") +async def test_registering_a_tool_with_a_spec_invalid_name_warns_but_does_not_reject( + connect: Connect, caplog: pytest.LogCaptureFixture +) -> None: + """A tool name that violates the SEP-986 rules logs a warning at registration but is still registered. + + The intended behaviour is rejection at registration time; MCPServer instead logs the + naming-rule violation and proceeds (see the divergence note on the requirement). The warning + spans several SDK-authored log records, so only the stable prefix and inclusion of the + offending name are asserted. + """ + mcp = MCPServer("naming") + + with caplog.at_level(logging.WARNING, logger="mcp.shared.tool_name_validation"): + + @mcp.tool(name="bad name!") + def bad() -> str: + return "ok" + + assert any( + rec.levelno == logging.WARNING + and rec.message.startswith("Tool name validation warning") + and "bad name!" in rec.message + for rec in caplog.records + ) + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("bad name!", {}) + + assert [tool.name for tool in listed.tools] == ["bad name!"] + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) + + +@requirement("mcpserver:tool:url-elicitation-error") +async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: + """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. + + MCPServer wraps every other tool exception as an is_error result; this error is special-cased + so it propagates as the JSON-RPC error the client needs in order to present the listed URL + interactions and retry the call. + """ + mcp = MCPServer("authorizer") + + @mcp.tool() + def read_files() -> str: + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("mcpserver:register:post-connect") +async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: + """Mutating the tool set on a running server changes tools/list but sends no notification. + + add_tool and remove_tool only update the registry: a connected client that listed the tools + before the mutation has no way to learn it should list them again. The spec provides + notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + one log message as a sentinel so the test proves notifications do reach the collector -- the + log message arrives, a list_changed does not. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("mutable") + + def extra() -> str: + """A tool registered at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + def doomed() -> str: + """A tool removed at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + async def grow(ctx: Context) -> str: + mcp.add_tool(extra, name="extra") + mcp.remove_tool("doomed") + await ctx.info("tool set changed") + return "mutated" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + before = await client.list_tools() + await client.call_tool("grow", {}) + after = await client.list_tools() + + assert [tool.name for tool in before.tools] == ["doomed", "grow"] + assert [tool.name for tool in after.tools] == ["grow", "extra"] + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + ) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py new file mode 100644 index 0000000000..7821c1eed5 --- /dev/null +++ b/tests/interaction/test_coverage.py @@ -0,0 +1,105 @@ +"""Enforces the contract between the requirements manifest and the test suite. + +The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be +exercised by at least one test, and every test in the suite must carry at least one +`@requirement(...)` mark referencing a manifest entry. Deferral reasons that point at coverage +elsewhere in the repo must point at paths that exist. Test modules are imported directly +(rather than relying on pytest collection) so the check holds even when only this file is run. +""" + +import importlib +import re +from pathlib import Path +from types import ModuleType + +import pytest + +from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement + +_SUITE_ROOT = Path(__file__).parent +_REPO_ROOT = _SUITE_ROOT.parent.parent + +# Repo paths cited inside deferral reasons ("Covered by tests/... "). +_CITED_PATH = re.compile(r"(?:tests|src)/[\w./-]*\w") + +# Tests that exercise the suite's own helpers rather than an interaction-model behaviour. +# Anything listed here is exempt from the every-test-has-a-requirement check. +_HARNESS_SELF_TESTS = { + "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", + "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", + "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", + "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", + "tests.interaction.auth.test_flow.test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app", +} + + +def _import_all_test_modules() -> list[ModuleType]: + """Import every other test module in the suite so their `@requirement` decorators register.""" + modules: list[ModuleType] = [] + for path in sorted(_SUITE_ROOT.rglob("test_*.py")): + relative = path.relative_to(_SUITE_ROOT).with_suffix("") + name = f"{__package__}.{'.'.join(relative.parts)}" + if name != __name__: + modules.append(importlib.import_module(name)) + return modules + + +def test_every_requirement_is_exercised() -> None: + """Each non-deferred requirement is covered by at least one test (deferred ones by none).""" + _import_all_test_modules() + + uncovered = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is None and not covered_by(requirement_id) + ] + assert not uncovered, f"Requirements with no test and no deferred reason: {uncovered}" + + stale_deferrals = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is not None and covered_by(requirement_id) + ] + assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" + + +def test_every_test_exercises_a_requirement() -> None: + """Each test in the suite carries at least one `@requirement` mark (harness self-tests excepted).""" + all_tests = { + f"{module.__name__}.{name}" + for module in _import_all_test_modules() + for name in vars(module) + if name.startswith("test_") + } + linked_tests = {test_name for requirement_id in REQUIREMENTS for test_name in covered_by(requirement_id)} + + unlinked = sorted(all_tests - linked_tests - _HARNESS_SELF_TESTS) + assert not unlinked, f"Tests with no @requirement mark: {unlinked}" + + stale_exemptions = sorted(_HARNESS_SELF_TESTS - all_tests) + assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" + + +def test_deferral_reasons_cite_existing_paths() -> None: + """Every repo path named in a deferral reason exists, so coverage pointers cannot rot.""" + missing = sorted( + f"{requirement_id}: {cited}" + for requirement_id, spec in REQUIREMENTS.items() + if spec.deferred is not None + for cited in _CITED_PATH.findall(spec.deferred) + if not (_REPO_ROOT / cited).exists() + ) + assert not missing, f"Deferral reasons citing paths that do not exist: {missing}" + + +def test_unknown_requirement_id_is_rejected() -> None: + """Marking a test with an ID that is not in the manifest fails at decoration time.""" + with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): + requirement("tools:call:does-not-exist") + + +def test_invalid_requirement_source_is_rejected() -> None: + """A requirement whose source is not a spec URL, 'sdk', or an issue reference fails at construction.""" + with pytest.raises(ValueError, match="source must be a specification URL"): + Requirement(source="https://example.com/not-the-spec", behavior="Never constructed.") diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py new file mode 100644 index 0000000000..f78c6d14b5 --- /dev/null +++ b/tests/interaction/transports/_bridge.py @@ -0,0 +1,169 @@ +"""An in-process, full-duplex HTTP transport for driving ASGI applications from httpx. + +`httpx.ASGITransport` runs the application to completion and only then hands the buffered +response to the caller, so a server that streams its response — the streamable HTTP transport's +SSE responses — can never converse with the client mid-request: a server-initiated request +nested inside a still-open call deadlocks. `StreamingASGITransport` removes that limitation by +running the application as a background task and forwarding every `http.response.body` chunk to +the client the moment it is sent. Everything happens on the one event loop: no sockets, no +threads, no sleeps, no extra dependencies. + +The behavioural contract, pinned by `test_bridge.py`: + +- The request body is buffered before the application is invoked (MCP requests are small JSON + documents); the response streams chunk by chunk. +- Closing the response — or the whole client — delivers `http.disconnect` to the application, + exactly as a real server sees when its peer goes away. +- An exception the application raises before sending `http.response.start` fails the originating + request with that same exception. After the response has started, a failure is visible to the + client only through the response itself (status code, truncated body) — the same signal a real + server over a real socket would give. + +The transport owns an anyio task group for the application tasks; it is opened and closed by +`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite +always does). Closing the transport cancels every running application task by default; set +`cancel_on_close=False` to wait for the application's own disconnect handling instead. +""" + +import math +from collections.abc import AsyncIterator +from types import TracebackType + +import anyio +import anyio.abc +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream +from starlette.types import ASGIApp, Message, Scope + + +class _StreamingResponseBody(httpx.AsyncByteStream): + """A response body that yields chunks as the application produces them. + + Closing it tells the application the client has gone away (`http.disconnect`), mirroring a + peer that drops the connection mid-response. + """ + + def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected: anyio.Event) -> None: + self._chunks = chunks + self._client_disconnected = client_disconnected + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._chunks: + yield chunk + + async def aclose(self) -> None: + self._client_disconnected.set() + await self._chunks.aclose() + + +class StreamingASGITransport(httpx.AsyncBaseTransport): + """Drive an ASGI application in-process, streaming each response as it is produced. + + With `cancel_on_close` (the default), closing the transport cancels every application task + still running so harness teardown can never hang. Setting it to False makes the transport wait + for the application's own disconnect handling to complete instead, which is the path the legacy + SSE server transport relies on for resource cleanup. + """ + + _task_group: anyio.abc.TaskGroup + + def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None: + self._app = app + self._cancel_on_close = cancel_on_close + + async def __aenter__(self) -> "StreamingASGITransport": + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + # httpx closes every streamed response before closing the transport, so by now each + # application task has been delivered `http.disconnect`. Either cancel immediately, or wait + # for the application's own disconnect handling to unwind. + if self._cancel_on_close: + self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + request_body = b"".join([chunk async for chunk in request.stream]) + + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?", maxsplit=1)[0], + "query_string": request.url.query, + "root_path": "", + "headers": [(name.lower(), value) for name, value in request.headers.raw], + "server": (request.url.host, request.url.port), + "client": ("127.0.0.1", 1234), + } + + request_delivered = False + client_disconnected = anyio.Event() + response_started = anyio.Event() + response_status = 0 + response_headers: list[tuple[bytes, bytes]] = [] + application_error: Exception | None = None + chunk_writer, chunk_reader = anyio.create_memory_object_stream[bytes](math.inf) + + async def receive_request() -> Message: + nonlocal request_delivered + if not request_delivered: + request_delivered = True + return {"type": "http.request", "body": request_body, "more_body": False} + await client_disconnected.wait() + return {"type": "http.disconnect"} + + async def send_response(message: Message) -> None: + nonlocal response_status, response_headers + if message["type"] == "http.response.start": + response_status = message["status"] + response_headers = list(message.get("headers", [])) + response_started.set() + return + assert message["type"] == "http.response.body" + body: bytes = message.get("body", b"") + if body: + await chunk_writer.send(body) + if not message.get("more_body", False): + await chunk_writer.aclose() + + async def run_application() -> None: + nonlocal application_error + try: + await self._app(scope, receive_request, send_response) + except Exception as exc: # The bridge is the application's outermost boundary: a crash + # must fail the originating request (or show up in the already-started response), + # never tear down the task group shared with every other in-flight request. + application_error = exc + finally: + response_started.set() + await chunk_writer.aclose() + + self._task_group.start_soon(run_application) + try: + await response_started.wait() + if application_error is not None: + raise application_error + except BaseException: + # No response will be built, so close the reader the response body would have owned + # and tell the application its peer has gone away. + client_disconnected.set() + await chunk_reader.aclose() + raise + return httpx.Response( + status_code=response_status, + headers=response_headers, + stream=_StreamingResponseBody(chunk_reader, client_disconnected), + request=request, + ) diff --git a/tests/interaction/transports/_event_store.py b/tests/interaction/transports/_event_store.py new file mode 100644 index 0000000000..84d1a2646a --- /dev/null +++ b/tests/interaction/transports/_event_store.py @@ -0,0 +1,55 @@ +"""A predictable event store for resumability tests. + +The SDK's `EventStore` interface lets a streamable-HTTP server stamp every SSE event with an ID +and replay missed events when a client reconnects with `Last-Event-ID`. This implementation +issues sequential integer IDs starting at "1" so tests can assert exact IDs (the example store +uses uuid4, which cannot be snapshotted) and is small enough that every line is exercised by the +resumability tests themselves. +""" + +import anyio + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + + +class SequencedEventStore(EventStore): + """Stores every event in order and replays the same-stream tail after a given ID.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, JSONRPCMessage | None]] = [] + self._milestones: dict[int, anyio.Event] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self._events.append((stream_id, message)) + count = len(self._events) + milestone = self._milestones.pop(count, None) + if milestone is not None: + milestone.set() + return str(count) + + async def wait_until_stored(self, count: int) -> None: + """Block until at least `count` events have been stored. + + Tests use this to wait for the server's message router (which runs in another task) to + finish storing a known set of events before issuing a replay, so the replay's content is + deterministic rather than depending on task scheduling order. + """ + if len(self._events) >= count: + return + milestone = self._milestones.setdefault(count, anyio.Event()) + await milestone.wait() + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + try: + cursor = int(last_event_id) + except ValueError: + return None + if not 0 < cursor <= len(self._events): + return None + stream_id, _ = self._events[cursor - 1] + for index in range(cursor, len(self._events)): + event_stream_id, message = self._events[index] + if event_stream_id == stream_id and message is not None: + await send_callback(EventMessage(message, str(index + 1))) + return stream_id diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py new file mode 100644 index 0000000000..5977cc3e99 --- /dev/null +++ b/tests/interaction/transports/_stdio_server.py @@ -0,0 +1,63 @@ +"""A real low-level Server over the stdio transport, for the suite's one subprocess test. + +Runnable as `python -m tests.interaction.transports._stdio_server` from the repo root; the test +launches it that way via `stdio_client`. Kept separate from the test module so the server lives in +its own importable file (subprocess coverage applies) while the test file follows the suite's +test-only-functions convention. +""" + +import sys + +import anyio + +from mcp.server import Server, ServerRequestContext +from mcp.server.stdio import stdio_server +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + EmptyResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + TextContent, + Tool, +) + + +async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] + ) + + +async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + text = params.arguments["text"] + await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return CallToolResult(content=[TextContent(text=text)]) + + +async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + +async def main() -> None: + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + # Reached only when the run loop exits because stdin closed; if the process were terminated + # the test's stderr capture would not see this line. + print("stdio-echo: clean exit", file=sys.stderr, flush=True) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py new file mode 100644 index 0000000000..7420b9d902 --- /dev/null +++ b/tests/interaction/transports/test_bridge.py @@ -0,0 +1,94 @@ +"""Contract tests for the suite's streaming ASGI bridge. + +These pin what `StreamingASGITransport` itself guarantees — chunk-by-chunk delivery, disconnect +propagation, and failure handling — against minimal hand-written ASGI applications, so the MCP +transport tests built on top of it never have to wonder what the harness provides. They are +harness self-tests, not interaction-model tests, and are exempted from the requirement-coverage +contract in `test_coverage.py`. +""" + +import anyio +import httpx +import pytest +from starlette.types import Message, Receive, Scope, Send + +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def test_response_chunks_arrive_as_the_application_sends_them() -> None: + """Each body chunk is delivered as sent, empty chunks are skipped, and the stream ends with the application.""" + + async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"first", "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": True}) + await send({"type": "http.response.body", "body": b"second", "more_body": False}) + + async with ( + httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http, + http.stream("GET", "/chunks") as response, + ): + with anyio.fail_after(5): + chunks = [chunk async for chunk in response.aiter_raw()] + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain" + assert chunks == [b"first", b"second"] + + +async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: + """A client that closes the response early is seen by the application as an http.disconnect.""" + seen_after_request: list[Message] = [] + disconnect_seen = anyio.Event() + + async def waiting_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + seen_after_request.append(await receive()) + disconnect_seen.set() + + async with httpx.AsyncClient(transport=StreamingASGITransport(waiting_app), base_url="http://bridge") as http: + async with http.stream("GET", "/wait") as response: + assert response.status_code == 200 + # Leaving the stream block closes the response while the application is still mid-response. + with anyio.fail_after(5): + await disconnect_seen.wait() + + assert seen_after_request == [{"type": "http.disconnect"}] + + +async def test_an_application_failure_before_the_response_starts_fails_the_request() -> None: + """An exception raised before http.response.start reaches the caller as that same exception.""" + + async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: + raise RuntimeError("the demo application is broken") + + async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: + with pytest.raises(RuntimeError, match="the demo application is broken"): + await http.get("/broken") + + +async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None: + """With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect + completes that cleanup before the transport finishes closing.""" + cleanup_ran = anyio.Event() + + async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await receive() + await send({"type": "http.response.start", "status": 200, "headers": []}) + assert (await receive())["type"] == "http.disconnect" + cleanup_ran.set() + + transport = StreamingASGITransport(lingering_app, cancel_on_close=False) + with anyio.fail_after(5): + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + async with http.stream("GET", "/linger") as response: + assert response.status_code == 200 + assert not cleanup_ran.is_set() + assert cleanup_ran.is_set() diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py new file mode 100644 index 0000000000..65ed03f1e4 --- /dev/null +++ b/tests/interaction/transports/test_client_transport_http.py @@ -0,0 +1,247 @@ +"""Behaviour of the streamable-HTTP client transport itself, observed at the wire. + +These tests connect a real `Client` to a real server over the in-process bridge, recording every +HTTP request the SDK client issues, so the assertions are about what the transport sends (headers, +methods, ordering) rather than what the protocol layer on top of it returns. The recording is the +wire-level instrument; the SDK client never exposes these details. +""" + +from collections.abc import AsyncIterator + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _tooled_server() -> Server: + """A low-level server with one echo tool, used by every test in this file.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + + return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + +@pytest.fixture +async def recorded() -> AsyncIterator[list[httpx.Request]]: + """Connect a `Client` over a recording HTTP client, list tools, exit, and yield every request sent. + + The HTTP client carries one caller-supplied header (`x-trace`) so its propagation can be + asserted; the recording captures the closing DELETE because it is read after the `Client` has + fully exited. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record, headers={"x-trace": "abc"}) as (http, _): + async with client_via_http(http) as client: + result = await client.list_tools() + assert [tool.name for tool in result.tools] == ["echo"] + + yield requests + + +def _after_initialize(recorded: list[httpx.Request]) -> list[httpx.Request]: + """Every recorded request after the initialize POST (which carries no session yet).""" + assert recorded[0].method == "POST" + assert "mcp-session-id" not in recorded[0].headers + return recorded[1:] + + +@requirement("client-transport:http:custom-client") +@requirement("client-transport:http:custom-headers") +async def test_the_client_uses_the_supplied_http_client_and_propagates_its_headers( + recorded: list[httpx.Request], +) -> None: + """A caller-supplied `httpx.AsyncClient` is used for every request and carries its own headers. + + The recording itself proves the supplied client is the one in use; the propagated header + proves the SDK transport does not replace the caller's client configuration. + """ + # Exact ordering past the first request is not guaranteed (the standalone GET stream is + # scheduled concurrently with later POSTs), so methods are asserted as a multiset. + assert sorted(request.method for request in recorded) == snapshot(["DELETE", "GET", "POST", "POST", "POST"]) + assert all(request.headers["x-trace"] == "abc" for request in recorded) + + +@requirement("client-transport:http:session-stored") +async def test_every_request_after_initialize_carries_the_issued_session_id(recorded: list[httpx.Request]) -> None: + """The session id from the initialize response is sent on every subsequent request.""" + session_ids = {request.headers["mcp-session-id"] for request in _after_initialize(recorded)} + assert len(session_ids) == 1 + (session_id,) = session_ids + assert session_id + + +@requirement("client-transport:http:protocol-version-stored") +@requirement("client-transport:http:protocol-version-header") +async def test_every_request_after_initialize_carries_the_negotiated_protocol_version( + recorded: list[httpx.Request], +) -> None: + """The negotiated protocol version is sent on every subsequent request (and not on initialize).""" + assert "mcp-protocol-version" not in recorded[0].headers + versions = {request.headers["mcp-protocol-version"] for request in _after_initialize(recorded)} + assert versions == snapshot({"2025-11-25"}) + + +@requirement("client-transport:http:accept-header-post") +@requirement("client-transport:http:accept-header-get") +async def test_accept_headers_cover_the_response_representations_the_transport_handles( + recorded: list[httpx.Request], +) -> None: + """POSTs accept both JSON and SSE; the standalone GET stream accepts SSE.""" + for request in recorded: + if request.method == "POST": + assert "application/json" in request.headers["accept"] + assert "text/event-stream" in request.headers["accept"] + if request.method == "GET": + assert "text/event-stream" in request.headers["accept"] + + +@requirement("client-transport:http:no-reconnect-after-close") +async def test_closing_the_client_sends_delete_and_does_not_reconnect(recorded: list[httpx.Request]) -> None: + """Client teardown sends DELETE and issues no further requests (no resumption GET).""" + assert recorded[-1].method == "DELETE" + assert all("last-event-id" not in request.headers for request in recorded) + + +@requirement("client-transport:http:concurrent-streams") +async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_own_response() -> None: + """Three tool calls issued at once each open their own POST stream and get the right answer.""" + requests: list[httpx.Request] = [] + results: dict[int, CallToolResult] = {} + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record) as (http, _), client_via_http(http) as client: + + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) + + with anyio.fail_after(5): # pragma: no branch + async with anyio.create_task_group() as tg: # pragma: no branch + for n in (1, 2, 3): + tg.start_soon(call, n) + + assert results == snapshot( + { + 1: CallToolResult(content=[TextContent(text="1")]), + 2: CallToolResult(content=[TextContent(text="2")]), + 3: CallToolResult(content=[TextContent(text="3")]), + } + ) + tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] + assert len(tools_call_posts) == 3 + + +@requirement("client-transport:http:sse-405-tolerated") +@requirement("client-transport:http:terminate-405-ok") +async def test_client_tolerates_405_on_get_and_delete() -> None: + """A 405 on the standalone GET stream or the closing DELETE does not fail the connection. + + The GET-stream task swallows the failure and schedules a reconnect that the closing cancel + interrupts before it ever sleeps the full default delay; the DELETE 405 is logged and ignored. + Neither surfaces to the caller. + """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + + async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): + await send({"type": "http.response.start", "status": 405, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + + +@requirement("client-transport:http:no-reconnect-after-response") +async def test_a_completed_post_stream_is_not_reconnected() -> None: + """A POST stream that delivered its response closes without a resumption GET. + + With an event store the server stamps every SSE event with an ID, so the client transport has a + Last-Event-ID it could resume from -- the test proves it does not, because the response arrived + and the stream completed normally. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + server = _tooled_server() + async with ( + mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _), + client_via_http(http) as client, + ): + with anyio.fail_after(5): + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] + assert resumption_gets == [] + + +@requirement("client-transport:http:404-surfaces") +async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> None: + """A 404 in response to a request after initialization is reported to the caller as an MCP error. + + The spec says the client MUST start a new session in this situation; the SDK instead surfaces a + `Session terminated` error to the caller. The spec's MUST is tracked at + client-transport:http:session-404-reinitialize; this test pins the SDK's current behaviour. + """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + initialize_seen = anyio.Event() + + async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] == "POST" and initialize_seen.is_set(): + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + if scope["type"] == "http" and scope["method"] == "POST": + initialize_seen.set() + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch + await client.list_tools() + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py new file mode 100644 index 0000000000..c428fe2d68 --- /dev/null +++ b/tests/interaction/transports/test_flows.py @@ -0,0 +1,129 @@ +"""Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. + +These scenarios are about how the transport layer holds together across more than one connection +or more than one transport, so they connect real `Client`s against one mounted server rather than +running over the matrix. +""" + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.session import LoggingFnT +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("flow:multi-client:stateful-isolation") +async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_notifications() -> None: + """Two clients on one stateful manager each receive only the notifications their own request produced. + + Complements `test_terminating_one_session_leaves_others_working` (which proves session + independence under termination) with the notification-isolation dimension: a notification + emitted by one session's handler does not leak to another session's client. + """ + mcp = MCPServer("multi") + + @mcp.tool() + async def announce(label: str, ctx: Context) -> str: + """Emit one info-level log carrying the caller's label, then return it.""" + await ctx.info(label) + return label + + received_a: list[object] = [] + received_b: list[object] = [] + + async def collect_a(params: LoggingMessageNotificationParams) -> None: + received_a.append(params.data) + + async def collect_b(params: LoggingMessageNotificationParams) -> None: + received_b.append(params.data) + + async with mounted_app(mcp) as (http, _): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call(label: str, collect: LoggingFnT) -> None: + async with client_via_http(http, logging_callback=collect) as client: + await client.call_tool("announce", {"label": label}) + + tg.start_soon(call, "a", collect_a) + tg.start_soon(call, "b", collect_b) + + assert received_a == ["a"] + assert received_b == ["b"] + + +@requirement("flow:session:terminate-then-reconnect") +async def test_a_fresh_connection_after_termination_obtains_a_new_session_and_operates() -> None: + """After a client terminates, a fresh connection to the same manager gets a distinct session. + + Steps: (1) connect a client and call list_tools, (2) the client exits (its DELETE fires), + (3) connect a second client to the same mounted app, (4) the second client's call_tool + succeeds and the recorded session ids show two distinct sessions were issued. + """ + mcp = MCPServer("reconnectable") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + session_ids: list[str] = [] + + async def record(request: httpx.Request) -> None: + session_id = request.headers.get("mcp-session-id") + if session_id is not None: + session_ids.append(session_id) + + async with mounted_app(mcp, on_request=record) as (http, _): + async with client_via_http(http) as first: + first_result = await first.list_tools() + async with client_via_http(http) as second: + second_result = await second.call_tool("echo", {"text": "again"}) + + assert {tool.name for tool in first_result.tools} == {"echo"} + assert second_result == snapshot( + CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + ) + distinct = set(session_ids) + assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" + + +@requirement("flow:compat:dual-transport-server") +async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: + """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + + The two transports have independent connection management (the streamable-HTTP session manager + versus a per-connection SSE handler), but both dispatch into the same server's request + handlers. The test connects one client over each transport against the same instance and + proves both reach the same tool. Uses MCPServer because the low-level Server has no SSE + convenience; the entry is about hosting composition, not the low-level API. + """ + mcp = MCPServer("dual") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + async with ( + mounted_app(mcp) as (http, _), + connect_over_sse(mcp) as sse_client, + client_via_http(http) as shttp_client, + ): + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + + assert shttp_result == snapshot( + CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + ) + assert sse_result == snapshot( + CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + ) diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py new file mode 100644 index 0000000000..85e64ded42 --- /dev/null +++ b/tests/interaction/transports/test_hosting_http.py @@ -0,0 +1,344 @@ +"""Streamable HTTP semantics: status codes, header validation, message routing, and security. + +These tests speak HTTP directly to the server's mounted ASGI app via the in-process bridge, +asserting the wire contract -- which status code answers which condition, which stream a message +travels on -- that the SDK client never exposes. Transport-agnostic behaviour is covered by the +`connect`-fixture matrix. +""" + +import anyio +import pytest +from anyio.lowlevel import checkpoint +from httpx_sse import ServerSentEvent, aconnect_sse +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + INVALID_PARAMS, + PARSE_ERROR, + CallToolRequestParams, + CallToolResult, + EmptyResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListResourcesResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + SubscribeRequestParams, + TextContent, +) +from tests.interaction._connect import ( + base_headers, + initialize_body, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A low-level server with one tool that emits a related and an unrelated notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "narrate" + await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + async def list_resources(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + return Server( + "hosted", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + +@requirement("hosting:http:method-405") +async def test_unsupported_http_methods_return_405() -> None: + """PUT and PATCH on the MCP endpoint return 405 with an Allow header naming the supported methods.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + put = await http.put("/mcp", json={}, headers=base_headers(session_id=session_id)) + patch = await http.patch("/mcp", json={}, headers=base_headers(session_id=session_id)) + + assert (put.status_code, put.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + assert (patch.status_code, patch.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + + +@requirement("hosting:http:accept-406") +async def test_missing_accept_media_types_return_406() -> None: + """A POST whose Accept header lacks both required types, or a GET lacking text/event-stream, returns 406.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", json=initialize_body(), headers={"accept": "text/plain", "mcp-protocol-version": "2025-11-25"} + ) + session_id = await initialize_via_http(http) + get = await http.get( + "/mcp", + headers={"accept": "application/json", "mcp-protocol-version": "2025-11-25", "mcp-session-id": session_id}, + ) + + assert (post.status_code, post.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept both application/json and text/event-stream") + ) + assert (get.status_code, get.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept text/event-stream") + ) + + +@requirement("hosting:http:content-type-415") +async def test_non_json_content_type_is_rejected() -> None: + """A POST with a non-JSON Content-Type is rejected before reaching the transport. + + See the divergence on the requirement: the security middleware rejects with 400, so the + transport's own 415 path is unreachable through any public entry point. + """ + async with mounted_app(_server()) as (http, _): + response = await http.post( + "/mcp", content=b"", headers=base_headers() | {"content-type": "text/plain"} + ) + + assert (response.status_code, response.text) == snapshot((400, "Invalid Content-Type header")) + + +@requirement("hosting:http:parse-error-400") +@requirement("hosting:http:batch") +async def test_malformed_and_batched_bodies_return_400() -> None: + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + not_json = await http.post( + "/mcp", + content=b"this is not json", + headers=base_headers(session_id=session_id) | {"content-type": "application/json"}, + ) + batched = await http.post( + "/mcp", + json=[ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + ], + headers=base_headers(session_id=session_id), + ) + + assert not_json.status_code == 400 + assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR + assert batched.status_code == 400 + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + + +@requirement("hosting:http:protocol-version-400") +@requirement("hosting:http:protocol-version-default") +async def test_protocol_version_header_is_validated() -> None: + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + bad = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + ) + # Only Accept and the session ID -- no MCP-Protocol-Version header at all. + defaulted = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers={"accept": "application/json, text/event-stream", "mcp-session-id": session_id}, + ) + + assert bad.status_code == 400 + assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( + "Bad Request: Unsupported protocol version: 1991-01-01." + ) + # 202 proves the request was accepted under the assumed default version (2025-03-26). + assert defaulted.status_code == 202 + + +@requirement("hosting:http:json-response-mode") +async def test_json_response_mode_answers_with_application_json_not_sse() -> None: + """With JSON response mode enabled, request POSTs are answered with a single application/json body. + + Asserted at the wire level because the SDK client parses either representation, so a + Client-driven round trip cannot distinguish a JSON response from an SSE one. + """ + async with mounted_app(_server(), json_response=True) as (http, _): + initialized = await http.post("/mcp", json=initialize_body(), headers=base_headers()) + session_id = initialized.headers["mcp-session-id"] + ping = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "ping"}, + headers=base_headers(session_id=session_id), + ) + + assert initialized.status_code == 200 + assert initialized.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(initialized.json()).id == 1 + assert ping.status_code == 200 + assert ping.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(ping.json()).id == 2 + + +@requirement("hosting:http:notifications-202") +async def test_notification_post_returns_202_with_no_body() -> None: + """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + response = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers=base_headers(session_id=session_id), + ) + + assert (response.status_code, response.content) == snapshot((202, b"")) + + +@requirement("hosting:http:second-sse-rejected") +async def test_a_second_standalone_get_stream_on_the_same_session_returns_409() -> None: + """Opening a second standalone GET SSE stream while one is already established returns 409 Conflict.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as first: + assert first.response.status_code == 200 + # The standalone-stream writer registers its key as its first action, then parks + # awaiting messages; one yield to the loop lets that registration complete before the + # second GET is dispatched. + await checkpoint() + second = await http.get("/mcp", headers=base_headers(session_id=session_id)) + + assert (second.status_code, second.json()["error"]["message"]) == snapshot( + (409, "Conflict: Only one SSE stream is allowed per session") + ) + + +@requirement("hosting:http:standalone-sse") +@requirement("hosting:http:standalone-sse-no-response") +@requirement("hosting:http:response-same-connection") +@requirement("hosting:http:sse-close-after-response") +@requirement("hosting:http:no-broadcast") +async def test_messages_are_routed_to_exactly_one_stream() -> None: + """Each server message travels on exactly one SSE stream and is never broadcast. + + A streamable-HTTP session has two kinds of server-to-client SSE stream: one short-lived stream + per POST request, carrying that request's response and any notifications related to it, and one + long-lived standalone stream (opened by GET) for notifications not tied to any request. The + spec's routing rule is that the POST stream delivers the response (and its related + notifications) and then closes, the standalone stream carries only unrelated notifications and + never a JSON-RPC response, and no message appears on both. The test opens both streams, calls a + tool whose handler emits one related and one unrelated notification, and asserts each message's + routing. + """ + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + post_events: list[ServerSentEvent] = [] + get_events: list[ServerSentEvent] = [] + + async def read_standalone_stream() -> None: + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as get: + assert get.response.status_code == 200 + standalone_ready.set() + async for event in get.aiter_sse(): + get_events.append(event) + seen_on_standalone.set() + + standalone_ready = anyio.Event() + seen_on_standalone = anyio.Event() + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(read_standalone_stream) + await standalone_ready.wait() + + params = CallToolRequestParams(name="narrate", arguments={}) + body = JSONRPCRequest(jsonrpc="2.0", id=5, method="tools/call", params=params.model_dump()) + async with aconnect_sse( + http, + "POST", + "/mcp", + json=body.model_dump(by_alias=True, exclude_none=True), + headers=base_headers(session_id=session_id), + ) as post: + assert post.response.status_code == 200 + # The POST stream iterator ends when the server closes the stream after the response. + post_events = [event async for event in post.aiter_sse()] + + await seen_on_standalone.wait() + tg.cancel_scope.cancel() + + post_messages = parse_sse_messages(post_events) + get_messages = parse_sse_messages(get_events) + + # POST stream: the related log notification, then the response, then the iterator ends (close). + assert [type(m).__name__ for m in post_messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) + assert isinstance(post_messages[0], JSONRPCNotification) + assert (post_messages[0].method, post_messages[0].params) == snapshot( + ("notifications/message", {"level": "info", "data": "related"}) + ) + assert isinstance(post_messages[1], JSONRPCResponse) + assert post_messages[1].id == 5 + + # Standalone stream: only the unrelated resource-updated notification, never a response. + assert [type(m).__name__ for m in get_messages] == snapshot(["JSONRPCNotification"]) + assert isinstance(get_messages[0], JSONRPCNotification) + assert get_messages[0].method == snapshot("notifications/resources/updated") + + +@requirement("hosting:http:dns-rebinding") +@requirement("transport:streamable-http:origin-validation") +async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> None: + """A disallowed Origin returns 403 (and Host 421) with protection enabled; disabled lets both through. + + See the divergence on hosting:http:dns-rebinding: the spec's Origin validation is an + unconditional MUST, but the SDK enables it only when the host is localhost (or settings are + passed explicitly) and additionally checks the Host header (returning 421), which the spec + does not require. + """ + # transport_security=None triggers the localhost auto-enable behaviour. + async with mounted_app(Server("guarded"), transport_security=None) as (http, _): + bad_origin = await http.post( + "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) + bad_host = await http.post("/mcp", json=initialize_body(), headers=base_headers() | {"host": "evil.example"}) + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://127.0.0.1:8000"} + ) as ok: + assert ok.response.status_code == 200 + assert [event async for event in ok.aiter_sse()] + + assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) + assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + + async with mounted_app( + Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) as (http, _): + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) as unguarded: + status = unguarded.response.status_code + assert [event async for event in unguarded.aiter_sse()] + + assert status == 200 diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py new file mode 100644 index 0000000000..c7945d56c3 --- /dev/null +++ b/tests/interaction/transports/test_hosting_resume.py @@ -0,0 +1,372 @@ +"""Resumability over the streamable HTTP transport, exercised entirely in process. + +These tests configure the server with an event store, so every SSE event is stamped with an ID +and a client that loses its connection can resume by sending `Last-Event-ID`. The wire-level +tests (`mounted_app` + raw httpx) assert exactly what travels on the wire; the end-to-end test +drives the SDK client through a server-initiated stream close and proves the call still +completes. The bridge's `aclose()` delivers `http.disconnect` to the running application, so +closing a streaming response mid-read is a deterministic in-process disconnect -- no sockets, +no real time. Every server here uses `retry_interval=0` so reconnection waits are no-ops. +""" + +import json + +import anyio +import httpx +import pytest +from httpx_sse import EventSource, ServerSentEvent +from inline_snapshot import snapshot + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.message import ClientMessageMetadata +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, + jsonrpc_message_adapter, +) +from tests.interaction._connect import ( + BASE_URL, + base_headers, + connect_over_streamable_http, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _counting_server() -> MCPServer: + """A server with one tool that emits related notifications and one unrelated notification.""" + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context, n: int) -> str: + """Emit n log notifications related to this call, plus one unrelated resource update.""" + for i in range(1, n + 1): + await ctx.info(f"tick {i}") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return f"counted to {n}" + + return mcp + + +def _tools_call(request_id: int, name: str, arguments: dict[str, object]) -> str: + """A serialized tools/call JSON-RPC request body.""" + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="tools/call", params={"name": name, "arguments": arguments} + ).model_dump_json(by_alias=True, exclude_none=True) + + +async def _read_events(response: httpx.Response, count: int) -> list[ServerSentEvent]: + """Read exactly `count` SSE events from a streaming response without closing it.""" + source = EventSource(response).aiter_sse() + return [await anext(source) for _ in range(count)] + + +@requirement("hosting:resume:event-ids") +@requirement("hosting:resume:priming") +async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_event() -> None: + """A request's SSE stream opens with a priming event (id, empty data, retry) then stamps each message.""" + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( # pragma: no branch + "POST", "/mcp", content=_tools_call(1, "count", {"n": 2}), headers=base_headers(session_id=session_id) + ) as response: + assert response.status_code == 200 + events = await _read_events(response, 4) + + priming, first, second, result = events + # The priming event is the only event a client could have seen before any work happened, so it + # is the resumption anchor: it carries an ID and empty data. The SDK attaches the retry hint + # to this event (see the divergence on hosting:resume:priming). + assert (priming.id, priming.data, priming.retry) == snapshot(("3", "", 0)) + assert priming.event == snapshot("message") + # Every subsequent event carries an event-store ID; the related notifications and the response + # all ride this stream and close it after the response. + assert [event.id for event in (first, second, result)] == snapshot(["4", "5", "7"]) + assert [json.loads(event.data)["method"] for event in (first, second)] == snapshot( + ["notifications/message", "notifications/message"] + ) + assert jsonrpc_message_adapter.validate_json(result.data) == snapshot( + JSONRPCResponse( + jsonrpc="2.0", + id=1, + result={ + "content": [{"type": "text", "text": "counted to 2"}], + "structuredContent": {"result": "counted to 2"}, + "isError": False, + }, + ) + ) + + +@requirement("hosting:resume:replay") +@requirement("hosting:resume:stream-scoped") +@requirement("hosting:resume:buffered-replay") +async def test_get_with_last_event_id_replays_only_that_streams_missed_events() -> None: + """Reconnecting with Last-Event-ID returns the missed events from that one stream, in order. + + The handler also emits an unrelated notification (which the server stores under the + standalone-stream key); replay must not return it, proving replay is scoped to the stream + the given event ID belongs to. + + Steps: (1) initialize; (2) POST a tool call and read events until the first notification is + captured; (3) close the response mid-stream -- the bridge delivers `http.disconnect`, the + handler keeps running; (4) release the handler so it emits the remaining messages, which the + server buffers in the event store; (5) wait on the event store for the handler's response to + be stored, so the replay's content is independent of task scheduling; (6) GET with + `Last-Event-ID` and assert the replay is exactly the missed events from this request's stream. + """ + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context) -> str: + """Emit one related notification, wait for the test, then emit two more plus an unrelated one.""" + await ctx.info("tick 1") + await release.wait() + await ctx.info("tick 2") + await ctx.info("tick 3") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return "counted" + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id) + ) as response: + # Read the priming event and the first notification, then drop the connection. + priming, first = await _read_events(response, 2) + assert (priming.id, first.id) == snapshot(("3", "4")) + last_seen = first.id + release.set() + # The handler keeps running after the disconnect; its remaining messages are stored. + # The first wait returns immediately (the priming and first tick are already stored); + # the second blocks until the response itself is stored so the replay content is fixed. + await store.wait_until_stored(4) + await store.wait_until_stored(8) + replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen} + async with http.stream("GET", "/mcp", headers=replay_headers) as replay: # pragma: no branch + assert replay.status_code == 200 + missed = await _read_events(replay, 3) + + decoded = parse_sse_messages(missed) + # Exactly the two remaining related notifications and the response, with their original IDs. + assert [event.id for event in missed] == snapshot(["5", "6", "8"]) + assert [type(message).__name__ for message in decoded] == snapshot( + ["JSONRPCNotification", "JSONRPCNotification", "JSONRPCResponse"] + ) + assert isinstance(decoded[2], JSONRPCResponse) + assert decoded[2].id == 1 + # The unrelated resource-updated notification was stored under the standalone-stream key, not + # this request's stream, so it must not appear in the replay. + assert all( + not (isinstance(message, JSONRPCNotification) and message.method == "notifications/resources/updated") + for message in decoded + ) + + +@requirement("hosting:resume:bad-event-id") +async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None: + """A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error. + + See the divergence on hosting:resume:bad-event-id: this pins current behaviour. + """ + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + for unknown in ("no-such-event", "0"): + headers = base_headers(session_id=session_id) | {"last-event-id": unknown} + async with http.stream("GET", "/mcp", headers=headers) as replay: + assert replay.status_code == 200 + assert replay.headers["content-type"].startswith("text/event-stream") + events = [event async for event in EventSource(replay).aiter_sse()] + assert events == [] + + +@requirement("hosting:http:disconnect-not-cancel") +async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() -> None: + """Closing the request's SSE connection while the handler is running leaves the handler running. + + The handler signals when it has started and when it has finished; the test drops the + connection in between and then releases the handler. If the disconnect cancelled the handler, + `finished` would never be set and the test would time out. + """ + started = anyio.Event() + release = anyio.Event() + finished = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Signal start, wait for the test, signal completion.""" + started.set() + await release.wait() + await ctx.info("released") + finished.set() + return "held" + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "hold", {}), headers=base_headers(session_id=session_id) + ) as response: + await _read_events(response, 1) + await started.wait() + assert not finished.is_set() + release.set() + await finished.wait() + + +# This test intentionally carries every automatic-reconnection requirement: the +# close-then-resume scenario is indivisible, so splitting it would mean five near-identical bodies. +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +@requirement("flow:resume:tool-call-resumption-token") +async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: + """A server-closed request stream is reconnected by the client and the call completes. + + The handler emits one notification, closes its own SSE stream, then (once released) emits + another and returns. The client observed the priming event (so it has a Last-Event-ID and a + retry hint of 0ms), sees the stream end, reconnects via GET with Last-Event-ID, and receives + the post-close notification and the result over the replay stream. The shared events make the + test deterministic: the handler only proceeds once the test knows the first notification has + arrived (and so the client's reconnection has begun). + """ + received: list[object] = [] + before_seen = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + """Emit, close this call's SSE stream, then emit again after the test releases the gate.""" + await ctx.info("before close") + await ctx.close_sse_stream() + await gate.wait() + await ctx.info("after close") + done.set() + return "resumed" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + if params.data == "before close": + before_seen.set() + + result: list[CallToolResult] = [] + async with connect_over_streamable_http( + mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect + ) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result.append(await client.call_tool("interrupt", {})) + + tg.start_soon(call) + await before_seen.wait() + gate.set() + await done.wait() + + assert result == snapshot( + [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + ) + assert received == snapshot(["before close", "after close"]) + + +@requirement("client-transport:http:resume-stream-api") +async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None: + """A resumption token captured via on_resumption_token_update on one connection lets a fresh + connection retrieve the messages it missed by passing resumption_token to send_request. + + This is the explicit ClientMessageMetadata API, distinct from the automatic reconnection the + previous test covers: the transport dispatches a resumption_token request as a GET with + Last-Event-ID instead of POSTing the body, and remaps the replayed response onto the new + request's id. Client.call_tool does not expose ClientMessageMetadata, so the test drives a + bare ClientSession via session.send_request -- the sanctioned drop-down for behaviour Client + cannot express. The second connection carries the original session id but does not initialize + (the server-side session already is), modelling a caller that resumes after a process restart. + """ + captured: list[str] = [] + received: list[object] = [] + first_seen = anyio.Event() + token_seen = anyio.Event() + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Emit one notification, wait for the test, emit another, return.""" + await ctx.info("first") + await release.wait() + await ctx.info("second") + return "done" + + async def on_token(token: str) -> None: + captured.append(token) + if len(captured) >= 2: + token_seen.set() + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + first_seen.set() + + call = CallToolRequest(params=CallToolRequestParams(name="hold", arguments={})) + capture = ClientMessageMetadata(on_resumption_token_update=on_token) + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): + with anyio.fail_after(5): # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), + ClientSession(r1, w1, logging_callback=collect) as first, + anyio.create_task_group() as tg, + ): + await first.initialize() + tg.start_soon(first.send_request, call, CallToolResult, None, capture) + await first_seen.wait() + await token_seen.wait() + assert captured == snapshot(["3", "4"]) + assert received == snapshot(["first"]) + # The session id is only observable via the manager (the client transport does not expose it). + (session_id,) = manager._server_instances + http.headers["mcp-session-id"] = session_id + http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + tg.cancel_scope.cancel() + + with anyio.fail_after(5): # pragma: no branch + release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event + # init priming + init response + call priming + "first" + "second" + result = 6 stored events. + await store.wait_until_stored(6) + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), + ClientSession(r2, w2, logging_callback=collect) as second, + ): + result = await second.send_request( + call, CallToolResult, metadata=ClientMessageMetadata(resumption_token=captured[-1]) + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot(["first", "second"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py new file mode 100644 index 0000000000..a926c3e8a2 --- /dev/null +++ b/tests/interaction/transports/test_hosting_session.py @@ -0,0 +1,202 @@ +"""Streamable HTTP session lifecycle: creation, routing, termination, and stateless mode. + +A test here speaks raw HTTP only when its assertion is the wire contract -- which header is +issued, which status code answers which condition -- that the SDK `Client` cannot observe. +Everything else is `Client`-driven against the same mounted session manager. Transport-agnostic +behaviour is covered by the `connect`-fixture matrix. +""" + +import re + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from tests.interaction._connect import ( + base_headers, + client_via_http, + initialize_body, + initialize_via_http, + mounted_app, + post_jsonrpc, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + + return Server("hosted", on_list_tools=list_tools) + + +@requirement("hosting:session:create") +@requirement("hosting:session:id-charset") +async def test_initialize_issues_a_visible_ascii_session_id() -> None: + """An initialize POST without a session ID creates a session and returns a visible-ASCII Mcp-Session-Id.""" + async with mounted_app(_server()) as (http, _): + response, messages = await post_jsonrpc(http, initialize_body()) + + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id is not None + # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). + assert re.fullmatch(r"[\x21-\x7E]+", session_id) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 1 + + +@requirement("hosting:session:reuse") +async def test_subsequent_requests_with_the_session_id_route_to_the_same_session() -> None: + """Requests carrying the issued Mcp-Session-Id reuse that session's transport rather than creating another.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as client: + await client.list_tools() + await client.list_tools() + # The session count is the only signal that distinguishes routing-to-existing from + # silently creating a second session: both produce a successful result. + assert len(manager._server_instances) == 1 + + +@requirement("hosting:session:unknown-id") +async def test_requests_with_an_unknown_session_id_return_404() -> None: + """POST, GET, and DELETE each carrying an unknown Mcp-Session-Id are answered 404 by the manager.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers=base_headers(session_id="not-a-session"), + ) + get = await http.get("/mcp", headers=base_headers(session_id="not-a-session")) + delete = await http.delete("/mcp", headers=base_headers(session_id="not-a-session")) + + assert (post.status_code, post.json()) == snapshot( + (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}}) + ) + assert (get.status_code, delete.status_code) == (404, 404) + + +@requirement("hosting:session:missing-id") +async def test_non_initialize_post_without_a_session_id_returns_400() -> None: + """A non-initialize POST that omits Mcp-Session-Id in stateful mode is rejected with 400.""" + async with mounted_app(_server()) as (http, _): + await initialize_via_http(http) + response = await http.post( + "/mcp", json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, headers=base_headers() + ) + + assert (response.status_code, response.json()) == snapshot( + (400, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}) + ) + + +@requirement("hosting:session:delete") +@requirement("hosting:session:post-termination-404") +async def test_delete_terminates_the_session_and_subsequent_requests_return_404() -> None: + """DELETE with a valid Mcp-Session-Id terminates the session; further requests on that ID return 404.""" + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + + delete = await http.delete("/mcp", headers=base_headers(session_id=session_id)) + assert delete.status_code == 200 + + # The manager keeps the terminated transport registered, so the next request reaches the + # transport's own _terminated check rather than the manager's unknown-session path. + assert session_id in manager._server_instances + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id), + ) + assert (post.status_code, post.json()) == snapshot( + ( + 404, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, + }, + ) + ) + + +@requirement("hosting:session:isolation") +async def test_terminating_one_session_leaves_others_working() -> None: + """Terminating one session on a manager does not disturb a concurrent session on the same manager.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as survivor: + async with client_via_http(http) as terminated: + await terminated.list_tools() + assert len(manager._server_instances) == 2 + # `terminated` has exited (its DELETE has been sent); `survivor` still answers. + result = await survivor.list_tools() + + assert result.tools[0].name == "noop" + + +@requirement("hosting:session:reinitialize") +async def test_second_initialize_on_an_existing_session_is_accepted() -> None: + """A second initialize POST carrying an existing session ID is processed rather than rejected. + + See the divergence on the requirement: the entry expects a rejection, but the SDK forwards the + second initialize to the running server, which answers it as a fresh handshake. + """ + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + response, messages = await post_jsonrpc(http, initialize_body(request_id=2), session_id=session_id) + assert len(manager._server_instances) == 1 + + assert response.status_code == snapshot(200) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 2 + + +@requirement("hosting:stateless:no-session-id") +@requirement("hosting:stateless:no-reuse") +async def test_stateless_mode_never_issues_a_session_id() -> None: + """A stateless server issues no Mcp-Session-Id and creates no persistent transport. + + The recording proves no request the SDK client sent carried an Mcp-Session-Id (the server + cannot have issued one, or the client would echo it); the empty instance map proves the + manager kept no transport between requests. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_server(), stateless_http=True, on_request=record) as (http, manager): + async with client_via_http(http) as client: + result = await client.list_tools() + assert manager._server_instances == {} + + assert result.tools[0].name == "noop" + assert all("mcp-session-id" not in request.headers for request in requests) + assert "DELETE" not in {request.method for request in requests} + + +@requirement("hosting:stateless:concurrent-clients") +async def test_stateless_mode_serves_concurrent_clients_independently() -> None: + """Two clients connected concurrently to the same stateless app each complete a round trip.""" + results: dict[str, ListToolsResult] = {} + + async with mounted_app(_server(), stateless_http=True) as (http, _): + + async def list_via(label: str) -> None: + async with client_via_http(http) as client: + results[label] = await client.list_tools() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(list_via, "a") + tg.start_soon(list_via, "b") + + assert results["a"].tools[0].name == "noop" + assert results["b"].tools[0].name == "noop" diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py new file mode 100644 index 0000000000..9c7353dda5 --- /dev/null +++ b/tests/interaction/transports/test_sse.py @@ -0,0 +1,90 @@ +"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file pins only what is observable on the SSE wiring +itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint +rejects requests it cannot route to a session. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge. +""" + +from uuid import UUID, uuid4 + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.types import EmptyResult +from tests.interaction._connect import BASE_URL, build_sse_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +@requirement("transport:sse") +@requirement("transport:sse:endpoint-event") +async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: + """Connecting opens a GET stream whose first event names the POST endpoint and a fresh + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" + app, sse = build_sse_app(Server("legacy")) + captured_session_id: list[str] = [] + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client( + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) + with anyio.fail_after(5): + async with Client(transport) as client: + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert sse._read_stream_writers == {} + + +@requirement("transport:sse:post:session-routing") +async def test_post_without_a_session_id_is_rejected() -> None: + """A POST to the message endpoint with no session_id query parameter is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1}) + assert (response.status_code, response.text) == snapshot((400, "session_id is required")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_with_a_malformed_session_id_is_rejected() -> None: + """A POST whose session_id query parameter is not a UUID is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((400, "Invalid session ID")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_for_an_unknown_session_is_rejected() -> None: + """A POST naming a well-formed session_id that no SSE stream owns is answered 404.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((404, "Could not find session")) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py new file mode 100644 index 0000000000..27cc65de42 --- /dev/null +++ b/tests/interaction/transports/test_stdio.py @@ -0,0 +1,143 @@ +"""The stdio transport: one subprocess end-to-end test and one in-process framing test. + +Everything else in the suite runs in a single process; the subprocess test exists to prove the same +client↔server round trip works over the stdio transport's real boundary (a child process whose +stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in +`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. + +The framing test drives `stdio_server` in-process by passing it injected text streams instead of the +real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process +boundary. + +stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test +would be slow, and the matrix already proves transport-agnosticism over three in-process +transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by +`tests/client/test_stdio.py` and stay deferred here. +""" + +import io +import json +import os +import sys +import tempfile +from pathlib import Path + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, +) +from mcp.types.jsonrpc import jsonrpc_message_adapter +from tests.interaction._connect import initialize_body +from tests.interaction._requirements import requirement +from tests.interaction.transports import _stdio_server + +pytestmark = pytest.mark.anyio + +_REPO_ROOT = Path(__file__).parents[3] + + +@requirement("transport:stdio") +@requirement("transport:stdio:clean-shutdown") +@requirement("transport:stdio:stderr-passthrough") +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: + """A Client connected over stdio initializes, calls a tool with arguments, receives the + server's log notification before the call returns, and the server exits when the transport + closes its stdin.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + with tempfile.TemporaryFile(mode="w+") as errlog: + transport = stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + ), + errlog=errlog, + ) + + with anyio.fail_after(10): + async with Client(transport, logging_callback=collect) as client: + assert client.initialize_result.server_info.name == "stdio-echo" + result = await client.call_tool("echo", {"text": "across\nprocesses"}) + + errlog.seek(0) + captured_stderr = errlog.read() + + assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) + # stdio carries one ordered server→client stream, so the same notification-before-response + # guarantee holds here as for the in-memory transport. + assert received == snapshot( + [LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")] + ) + # The server writes this line only after its run loop returns, which happens when stdin closes: + # seeing it proves the process exited on its own rather than via the transport's terminate + # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: + # the transport routes the child's stderr to the caller's `errlog` without consuming it. + assert captured_stderr == snapshot("stdio-echo: clean exit\n") + + +@requirement("transport:stdio:stream-purity") +@requirement("transport:stdio:no-embedded-newlines") +async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: + """Everything `stdio_server` writes is a valid JSON-RPC message on its own line, and nothing else. + + The transport's stdin/stdout parameters are public, so the test injects in-process text streams + instead of the real process handles and drives the read/write streams directly: a JSON-RPC line on + stdin is parsed and delivered, and every message sent on the write stream appears as exactly one + newline-terminated line whose payload newlines are JSON-escaped. This proves the transport's own + framing; it does not guard `sys.stdout` against handler code that prints to it directly (see the + divergence on `transport:stdio:stream-purity`). + """ + captured = io.StringIO() + sent_line = json.dumps(initialize_body(request_id=1)) + "\n" + + with anyio.fail_after(5): + async with ( + stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + read_stream, + write_stream, + ), + read_stream, + write_stream, + ): + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "initialize" + + response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) + notification = JSONRPCNotification( + jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} + ) + await write_stream.send(SessionMessage(response)) + await write_stream.send(SessionMessage(notification)) + + output = captured.getvalue() + assert output.endswith("\n") + lines = output.removesuffix("\n").split("\n") + assert len(lines) == 2 + messages = [jsonrpc_message_adapter.validate_json(line) for line in lines] + assert [type(message).__name__ for message in messages] == snapshot(["JSONRPCResponse", "JSONRPCNotification"]) + # The newline inside the payload is JSON-escaped on the wire, not a literal newline that would + # break the one-message-per-line framing. + assert r"line\nbreak" in lines[0] + assert r"two\nlines" in lines[1] diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py new file mode 100644 index 0000000000..d38e2a0bb3 --- /dev/null +++ b/tests/interaction/transports/test_streamable_http.py @@ -0,0 +1,168 @@ +"""Behaviour specific to the streamable HTTP transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file only pins what cannot be observed in memory: the +server's stateless and JSON-response modes, the standalone GET stream, and the full-duplex +server-initiated exchange on a still-open call. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses. +""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + ElicitRequestParams, + ElicitResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, +) +from tests.interaction._connect import connect_over_streamable_http +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _smoke_server() -> MCPServer: + """A server exercising each message shape the transport-specific tests need.""" + mcp = MCPServer("smoke", instructions="Talk to the smoke server.") + + @mcp.tool() + def echo(text: str) -> str: + """Echo the text back.""" + return text + + class Confirmation(BaseModel): + confirmed: bool + + @mcp.tool() + async def ask(ctx: Context) -> str: + """Elicit a confirmation from the client and report the outcome.""" + answer = await ctx.elicit("Proceed?", Confirmation) + # In stateless mode the elicit raises before this point: there is no session to call back through. + assert isinstance(answer, AcceptedElicitation) + return f"confirmed={answer.data.confirmed}" + + @mcp.tool() + async def announce(ctx: Context) -> str: + """Send one notification related to this request and one that is not.""" + await ctx.info("about to announce") + await ctx.session.send_resource_updated("file:///watched.txt") + return "announced" + + return mcp + + +@requirement("transport:streamable-http:json-response") +@requirement("client-transport:http:json-response-parsed") +async def test_tool_call_over_streamable_http_with_json_responses() -> None: + """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" + async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: + assert client.initialize_result.server_info.name == "smoke" + result = await client.call_tool("echo", {"text": "as json"}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"}) + ) + + +@requirement("transport:streamable-http:stateless") +async def test_tool_calls_over_stateless_streamable_http() -> None: + """Consecutive requests each succeed against a stateless server with no session to share.""" + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + first = await client.call_tool("echo", {"text": "first"}) + second = await client.call_tool("echo", {"text": "second"}) + + assert first == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + assert second == snapshot( + CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"}) + ) + + +@requirement("transport:streamable-http:stateless-restrictions") +async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: + """A handler that tries to call back to the client in stateless mode fails: there is no session.""" + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + result = await client.call_tool("ask", {}) + + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error + # path; pin the stable prefix rather than the full exception prose. + assert result.content[0].text.startswith("Error executing tool ask:") + + +@requirement("transport:streamable-http:notifications") +@requirement("transport:streamable-http:unrelated-messages") +@requirement("hosting:http:standalone-sse") +async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: + """A server message with no related request reaches the client through the standalone GET stream. + + The log notification is related to the tool call and travels on that call's own SSE stream; + the resource-updated notification is not related to any request, so the only way it can reach + the client is the standalone stream the client opens after initialization. Delivery order + across the two streams is not guaranteed, so the unrelated message is awaited rather than + assumed to beat the tool result. + """ + received: list[IncomingMessage] = [] + resource_update_seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + if isinstance(message, ResourceUpdatedNotification): + resource_update_seen.set() + + async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client: + result = await client.call_tool("announce", {}) + with anyio.fail_after(5): + await resource_update_seen.wait() + + assert result == snapshot( + CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) + ) + # The related log notification rides the call's stream; the unrelated resource-updated + # notification rides the standalone stream. Both arrive, nothing else does. + assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] + ) + assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) + assert len(received) == 2 + + +@requirement("transport:streamable-http:stateful") +@requirement("transport:streamable-http:server-to-client") +async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None: + """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP. + + The elicitation request travels on the still-open SSE response of the tool call that triggered + it, and the client's answer arrives as a separate POST -- the full-duplex exchange the + streamable HTTP transport exists to provide. + """ + asked: list[ElicitRequestParams] = [] + + async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params) + return ElicitResult(action="accept", content={"confirmed": True}) + + async with connect_over_streamable_http(_smoke_server(), elicitation_callback=answer) as client: + # Bounded because a harness regression here historically meant deadlock, not failure. + with anyio.fail_after(5): + result = await client.call_tool("ask", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) + ) + assert [params.message for params in asked] == snapshot(["Proceed?"]) From 680b736601347e2c605c2085ea255cdc7c3d2abe Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 13:32:03 +0000 Subject: [PATCH 02/19] =?UTF-8?q?backport:=20phase-1.5=20mechanical=20type?= =?UTF-8?q?s=20sweep=20(snake=E2=86=92camel=20kwargs,=20type=3D=20discrimi?= =?UTF-8?q?nators)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../interaction/auth/test_authorize_token.py | 2 +- tests/interaction/auth/test_discovery.py | 2 +- tests/interaction/auth/test_flow.py | 8 +- tests/interaction/auth/test_lifecycle.py | 2 +- .../interaction/lowlevel/test_cancellation.py | 24 +- tests/interaction/lowlevel/test_completion.py | 16 +- .../interaction/lowlevel/test_elicitation.py | 88 ++++---- tests/interaction/lowlevel/test_flows.py | 20 +- tests/interaction/lowlevel/test_initialize.py | 38 ++-- .../interaction/lowlevel/test_list_changed.py | 12 +- tests/interaction/lowlevel/test_logging.py | 10 +- tests/interaction/lowlevel/test_meta.py | 10 +- tests/interaction/lowlevel/test_pagination.py | 22 +- tests/interaction/lowlevel/test_ping.py | 6 +- tests/interaction/lowlevel/test_progress.py | 32 +-- tests/interaction/lowlevel/test_prompts.py | 46 ++-- tests/interaction/lowlevel/test_resources.py | 40 ++-- tests/interaction/lowlevel/test_roots.py | 30 ++- tests/interaction/lowlevel/test_sampling.py | 210 ++++++++++-------- tests/interaction/lowlevel/test_timeouts.py | 8 +- tests/interaction/lowlevel/test_tools.py | 120 +++++----- tests/interaction/lowlevel/test_wire.py | 4 +- tests/interaction/mcpserver/test_context.py | 18 +- tests/interaction/mcpserver/test_prompts.py | 6 +- tests/interaction/mcpserver/test_resources.py | 14 +- tests/interaction/mcpserver/test_tools.py | 42 ++-- tests/interaction/transports/_stdio_server.py | 4 +- .../transports/test_client_transport_http.py | 10 +- tests/interaction/transports/test_flows.py | 6 +- .../transports/test_hosting_http.py | 2 +- .../transports/test_hosting_resume.py | 6 +- .../transports/test_hosting_session.py | 2 +- tests/interaction/transports/test_stdio.py | 2 +- .../transports/test_streamable_http.py | 12 +- 34 files changed, 485 insertions(+), 389 deletions(-) diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py index cb8524c097..8e44cffc04 100644 --- a/tests/interaction/auth/test_authorize_token.py +++ b/tests/interaction/auth/test_authorize_token.py @@ -51,7 +51,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) def authorize_params(authorize_url: str) -> dict[str, str]: diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py index 68c33c8a2d..afa3d0cd4b 100644 --- a/tests/interaction/auth/test_discovery.py +++ b/tests/interaction/auth/test_discovery.py @@ -43,7 +43,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="probe", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="probe", inputSchema={"type": "object"})]) def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py index 968fc5f980..4c041cc112 100644 --- a/tests/interaction/auth/test_flow.py +++ b/tests/interaction/auth/test_flow.py @@ -40,7 +40,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="whoami", inputSchema={"type": "object"})]) @requirement("flow:oauth:authorization-code-roundtrip") @@ -76,7 +76,7 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow ): result = await client.list_tools() - assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})])) + assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", inputSchema={"type": "object"})])) assert headless.authorize_url is not None paths = [(r.method, r.url.path) for r in requests] @@ -126,7 +126,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert params.name == "whoami" token = get_access_token() assert token is not None - return CallToolResult(content=[TextContent(text=" ".join(token.scopes))]) + return CallToolResult(content=[TextContent(type="text", text=" ".join(token.scopes))]) server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) provider = InMemoryAuthorizationServerProvider() @@ -135,7 +135,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect_with_oauth(server, provider=provider) as (client, _): result = await client.call_tool("whoami", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="mcp")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="mcp")])) @requirement("client-auth:pre-registration") diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py index aa552ae8a6..8812dccead 100644 --- a/tests/interaction/auth/test_lifecycle.py +++ b/tests/interaction/auth/test_lifecycle.py @@ -44,7 +44,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) def form_body(request: RecordedRequest) -> dict[str, str]: diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 6f1454e58a..9ba6797e0a 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -76,7 +76,7 @@ async def call_and_capture_error() -> None: await started.wait() await client.session.send_notification( types.CancelledNotification( - params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") + params=types.CancelledNotificationParams(requestId=request_ids[0], reason="user aborted") ) ) @@ -96,14 +96,14 @@ async def list_tools( ) -> types.ListToolsResult: return types.ListToolsResult( tools=[ - types.Tool(name="block", input_schema={"type": "object"}), - types.Tool(name="echo", input_schema={"type": "object"}), + types.Tool(name="block", inputSchema={"type": "object"}), + types.Tool(name="echo", inputSchema={"type": "object"}), ] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: if params.name == "echo": - return CallToolResult(content=[TextContent(text="still alive")]) + return CallToolResult(content=[TextContent(type="text", text="still alive")]) assert ctx.request_id is not None request_ids.append(ctx.request_id) started.set() @@ -123,12 +123,12 @@ async def call_and_swallow_cancellation_error() -> None: task_group.start_soon(call_and_swallow_cancellation_error) await started.wait() await client.session.send_notification( - types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) + types.CancelledNotification(params=types.CancelledNotificationParams(requestId=request_ids[0])) ) result = await client.call_tool("echo", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="still alive")])) @requirement("protocol:cancel:unknown-id-ignored") @@ -138,21 +138,21 @@ async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="echo", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "echo" - return CallToolResult(content=[TextContent(text="unbothered")]) + return CallToolResult(content=[TextContent(type="text", text="unbothered")]) server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: await client.session.send_notification( - types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) + types.CancelledNotification(params=types.CancelledNotificationParams(requestId=9999)) ) result = await client.call_tool("echo", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="unbothered")])) @requirement("protocol:cancel:late-response-ignored") @@ -192,9 +192,9 @@ def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage respond( init.message.id, InitializeResult( - protocol_version="2025-11-25", + protocolVersion="2025-11-25", capabilities=ServerCapabilities(), - server_info=Implementation(name="scripted", version="0.0.1"), + serverInfo=Implementation(name="scripted", version="0.0.1"), ), ) ) diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py index 6a35404df3..42b059f742 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -34,17 +34,17 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar assert params.argument.name == "language" candidates = ["python", "pytorch", "ruby"] matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] - return CompleteResult(completion=Completion(values=matches, total=len(matches), has_more=False)) + return CompleteResult(completion=Completion(values=matches, total=len(matches), hasMore=False)) server = Server("completer", on_completion=completion) async with connect(server) as client: result = await client.complete( - PromptReference(name="code_review"), argument={"name": "language", "value": "py"} + PromptReference(type="ref/prompt", name="code_review"), argument={"name": "language", "value": "py"} ) assert result == snapshot( - CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, has_more=False)) + CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, hasMore=False)) ) @@ -62,7 +62,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar async with connect(server) as client: result = await client.complete( - ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), argument={"name": "owner", "value": "model"}, ) @@ -86,7 +86,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar async with connect(server) as client: result = await client.complete( - ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), argument={"name": "repo", "value": ""}, context_arguments={"owner": "modelcontextprotocol"}, ) @@ -111,7 +111,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar async with connect(server) as client: with pytest.raises(MCPError) as exc_info: - await client.complete(PromptReference(name="ghost"), argument={"name": "x", "value": ""}) + await client.complete(PromptReference(type="ref/prompt", name="ghost"), argument={"name": "x", "value": ""}) assert exc_info.value.error.code == INVALID_PARAMS @@ -126,6 +126,8 @@ async def test_complete_without_handler_is_method_not_found(connect: Connect) -> assert client.initialize_result.capabilities.completions is None with pytest.raises(MCPError) as exc_info: - await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) + await client.complete( + PromptReference(type="ref/prompt", name="anything"), argument={"name": "topic", "value": ""} + ) assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index b8edf601d0..82c81aa24b 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -62,13 +62,13 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + tools=[types.Tool(name="signup", description="Register the user.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "signup" answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) - return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + return CallToolResult(content=[TextContent(type="text", text=answer.action)], structuredContent=answer.content) server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) @@ -84,7 +84,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest ElicitRequestFormParams( _meta={}, message="Choose a username.", - requested_schema={ + requestedSchema={ "type": "object", "properties": { "username": {"type": "string"}, @@ -97,8 +97,8 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest ) assert result == snapshot( CallToolResult( - content=[TextContent(text="accept")], - structured_content={"username": "ada", "newsletter": True}, + content=[TextContent(type="text", text="accept")], + structuredContent={"username": "ada", "newsletter": True}, ) ) @@ -111,13 +111,13 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + tools=[types.Tool(name="confirm", description="Ask for confirmation.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "confirm" answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) - return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -127,7 +127,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest async with connect(server, elicitation_callback=answer_form) as client: result = await client.call_tool("confirm", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="decline content=None")])) @requirement("elicitation:form:action:cancel") @@ -138,13 +138,13 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + tools=[types.Tool(name="confirm", description="Ask for confirmation.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "confirm" answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) - return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -154,7 +154,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest async with connect(server, elicitation_callback=answer_form) as client: result = await client.call_tool("confirm", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="cancel content=None")])) @requirement("elicitation:form:not-supported") @@ -173,7 +173,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + tools=[types.Tool(name="ask", description="Ask the user.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -181,7 +181,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara try: await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # elicit_form cannot succeed without a client callback server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) @@ -189,7 +189,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: result = await client.call_tool("ask", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="-32600: Elicitation not supported")]) + ) @requirement("elicitation:url:action:accept-no-content") @@ -207,7 +209,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + tools=[types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -215,7 +217,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara answer = await ctx.session.elicit_url( "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" ) - return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -232,11 +234,11 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP _meta={}, message="Authorize access to your calendar.", url="https://example.com/oauth/authorize", - elicitation_id="auth-001", + elicitationId="auth-001", ) ] ) - assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="accept content=None")])) @requirement("elicitation:url:decline") @@ -247,7 +249,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + tools=[types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -255,7 +257,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara answer = await ctx.session.elicit_url( "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" ) - return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -265,7 +267,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP async with connect(server, elicitation_callback=answer_url) as client: result = await client.call_tool("authorize", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="decline content=None")])) @requirement("elicitation:url:cancel") @@ -276,7 +278,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + tools=[types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -284,7 +286,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara answer = await ctx.session.elicit_url( "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" ) - return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -294,7 +296,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP async with connect(server, elicitation_callback=answer_url) as client: result = await client.call_tool("authorize", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="cancel content=None")])) @requirement("elicitation:url:complete-notification") @@ -319,7 +321,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + tools=[types.Tool(name="link_account", description="Link an account.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -329,7 +331,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) assert answer.action == "accept" await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) - return CallToolResult(content=[TextContent(text="linked")]) + return CallToolResult(content=[TextContent(type="text", text="linked")]) server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -344,7 +346,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP # The completion notification refers to the same elicitation the client accepted. assert elicited_ids == [elicitation_id] assert received == snapshot( - [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitationId="auth-001"))] ) @@ -365,7 +367,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ElicitRequestURLParams( message="Authorization required for your files.", url="https://example.com/oauth/authorize", - elicitation_id="auth-001", + elicitationId="auth-001", ) ] ) @@ -429,13 +431,13 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + tools=[types.Tool(name="onboard", description="Onboard the user.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "onboard" answer = await ctx.session.elicit_form("Tell us about yourself.", schema) - return CallToolResult(content=[TextContent(text=answer.action)]) + return CallToolResult(content=[TextContent(type="text", text=answer.action)]) server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) @@ -479,13 +481,13 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + tools=[types.Tool(name="profile", description="Collect a profile.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "profile" answer = await ctx.session.elicit_form("Profile details.", schema) - return CallToolResult(content=[TextContent(text=answer.action)]) + return CallToolResult(content=[TextContent(type="text", text=answer.action)]) server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -518,7 +520,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + tools=[types.Tool(name="signup", description="Register the user.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -527,7 +529,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara "Choose a name.", {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, ) - return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + return CallToolResult(content=[TextContent(type="text", text=answer.action)], structuredContent=answer.content) server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) @@ -538,7 +540,9 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest result = await client.call_tool("signup", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + CallToolResult( + content=[TextContent(type="text", text="accept")], structuredContent={"name": 42, "extra": "field"} + ) ) @@ -555,13 +559,13 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + tools=[types.Tool(name="noop", description="Send a stray complete.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "noop" await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) - return CallToolResult(content=[TextContent(text="ok")]) + return CallToolResult(content=[TextContent(type="text", text="ok")]) server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) @@ -573,9 +577,9 @@ async def collect(message: IncomingMessage) -> None: async with connect(server, message_handler=collect) as client: result = await client.call_tool("noop", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="ok")])) assert received == snapshot( - [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitationId="never-elicited"))] ) @@ -603,9 +607,9 @@ async def scripted_server(streams: MessageStream) -> None: assert isinstance(request, JSONRPCRequest) assert request.method == "initialize" result = InitializeResult( - protocol_version="2025-11-25", + protocolVersion="2025-11-25", capabilities=ServerCapabilities(), - server_info=Implementation(name="legacy", version="0.0.1"), + serverInfo=Implementation(name="legacy", version="0.0.1"), ) await server_write.send( SessionMessage( @@ -651,7 +655,7 @@ async def scripted_server(streams: MessageStream) -> None: ElicitRequestFormParams( _meta=None, message="Legacy ask.", - requested_schema={"type": "object", "properties": {}}, + requestedSchema={"type": "object", "properties": {}}, ) ] ) diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py index 8d96582341..e3397e2f4b 100644 --- a/tests/interaction/lowlevel/test_flows.py +++ b/tests/interaction/lowlevel/test_flows.py @@ -46,7 +46,7 @@ def _list_tools(*names: str) -> ListToolsHandler: """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + return ListToolsResult(tools=[Tool(name=name, inputSchema={"type": "object"}) for name in names]) return list_tools @@ -61,7 +61,7 @@ async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(conn async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "generate" - return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + return CallToolResult(content=[ResourceLink(type="resource_link", uri="file:///report.txt", name="report")]) async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: assert str(params.uri) == "file:///report.txt" @@ -77,7 +77,9 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq assert isinstance(link, ResourceLink) read = await client.read_resource(link.uri) - assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert called == snapshot( + CallToolResult(content=[ResourceLink(type="resource_link", name="report", uri="file:///report.txt")]) + ) assert read == snapshot( ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) ) @@ -108,7 +110,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara {"type": "object", "properties": {"age": {"type": "integer"}}}, ) assert second.action == "accept" and second.content is not None - return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + return CallToolResult( + content=[TextContent(type="text", text=f"{first.content['name']} is {second.content['age']}")] + ) server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) @@ -120,7 +124,7 @@ async def answer(context: ClientRequestContext, params: types.ElicitRequestParam async with connect(server, elicitation_callback=answer) as client: result = await client.call_tool("onboard", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="ada is 37")])) assert [(p.message, p.requested_schema) for p in received] == snapshot( [ ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), @@ -162,11 +166,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ElicitRequestURLParams( message="Authorize file access.", url="https://example.com/oauth/authorize", - elicitation_id=elicitation_id, + elicitationId=elicitation_id, ) ] ) - return CallToolResult(content=[TextContent(text="contents")]) + return CallToolResult(content=[TextContent(type="text", text="contents")]) async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: """Registered so the logging capability is advertised; the client never sets a level.""" @@ -200,4 +204,4 @@ async def collect(message: IncomingMessage) -> None: authorised[0] = True result = await client.call_tool("read_files", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="contents")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 91adbf5611..b81471e5b9 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -56,7 +56,7 @@ async def test_initialize_returns_server_info(connect: Connect) -> None: title="Greeter", description="Greets people.", website_url="https://example.com/greeter", - icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + icons=[Icon(src="https://example.com/icon.png", mimeType="image/png", sizes=["48x48"])], ) async with connect(server) as client: @@ -68,8 +68,8 @@ async def test_initialize_returns_server_info(connect: Connect) -> None: title="Greeter", description="Greets people.", version="1.2.3", - website_url="https://example.com/greeter", - icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + websiteUrl="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mimeType="image/png", sizes=["48x48"])], ) ) @@ -143,9 +143,9 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar ServerCapabilities( experimental={}, logging=LoggingCapability(), - prompts=PromptsCapability(list_changed=False), - resources=ResourcesCapability(subscribe=True, list_changed=False), - tools=ToolsCapability(list_changed=False), + prompts=PromptsCapability(listChanged=False), + resources=ResourcesCapability(subscribe=True, listChanged=False), + tools=ToolsCapability(listChanged=False), completions=CompletionsCapability(), ) ) @@ -168,20 +168,20 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + tools=[types.Tool(name="whoami", description="Report the caller.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "whoami" assert ctx.session.client_params is not None client_info = ctx.session.client_params.client_info - return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{client_info.name} {client_info.version}")]) server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: result = await client.call_tool("whoami", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="acme-agent 9.9.9")])) @requirement("lifecycle:initialize:client-capabilities") @@ -192,7 +192,7 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + tools=[types.Tool(name="abilities", description="Report capabilities.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -209,7 +209,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ] if capabilities.roots is not None: declared.append(f"roots(list_changed={capabilities.roots.list_changed})") - return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + return CallToolResult(content=[TextContent(type="text", text=",".join(declared) or "none")]) async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: """Registered only so the client declares the roots capability; never called.""" @@ -219,11 +219,11 @@ async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: async with connect(server) as client: result = await client.call_tool("abilities", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="none")])) async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("abilities", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="roots(list_changed=True)")])) @requirement("lifecycle:requests-before-initialized") @@ -271,9 +271,9 @@ async def test_initialize_negotiates_protocol_version() -> None: def initialize_request(protocol_version: str) -> InitializeRequest: return InitializeRequest( params=InitializeRequestParams( - protocol_version=protocol_version, + protocolVersion=protocol_version, capabilities=ClientCapabilities(), - client_info=Implementation(name="time-traveller", version="0.0.1"), + clientInfo=Implementation(name="time-traveller", version="0.0.1"), ) ) @@ -312,9 +312,9 @@ async def scripted_server(streams: MessageStream) -> None: assert isinstance(request, JSONRPCRequest) assert request.method == "initialize" result = InitializeResult( - protocol_version="1991-08-06", + protocolVersion="1991-08-06", capabilities=ServerCapabilities(), - server_info=Implementation(name="relic", version="0.0.1"), + serverInfo=Implementation(name="relic", version="0.0.1"), ) await server_write.send( SessionMessage( @@ -357,9 +357,9 @@ async def scripted_server(streams: MessageStream) -> None: assert isinstance(request, JSONRPCRequest) assert request.method == "initialize" result = InitializeResult( - protocol_version="2025-06-18", + protocolVersion="2025-06-18", capabilities=ServerCapabilities(), - server_info=Implementation(name="conservative", version="0.0.1"), + serverInfo=Implementation(name="conservative", version="0.0.1"), ) await server_write.send( SessionMessage( diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index a2f85eeacf..c0449fea9b 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -47,12 +47,12 @@ async def collect(message: IncomingMessage) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="install", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "install" await ctx.session.send_tool_list_changed() - return CallToolResult(content=[TextContent(text="installed")]) + return CallToolResult(content=[TextContent(type="text", text="installed")]) server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) @@ -77,12 +77,12 @@ async def collect(message: IncomingMessage) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="mount", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "mount" await ctx.session.send_resource_list_changed() - return CallToolResult(content=[TextContent(text="mounted")]) + return CallToolResult(content=[TextContent(type="text", text="mounted")]) async def list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -113,12 +113,12 @@ async def collect(message: IncomingMessage) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="learn", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "learn" await ctx.session.send_prompt_list_changed() - return CallToolResult(content=[TextContent(text="learned")]) + return CallToolResult(content=[TextContent(type="text", text="learned")]) async def list_prompts( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index fba632ef4d..070ee49c82 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -64,7 +64,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="chatty", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "chatty" @@ -74,7 +74,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_log_message( level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id ) - return CallToolResult(content=[TextContent(text="done")]) + return CallToolResult(content=[TextContent(type="text", text="done")]) async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: """Registered so the logging capability is advertised; the client never sets a level.""" @@ -85,7 +85,7 @@ async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelReq async with connect(server, logging_callback=collect) as client: result = await client.call_tool("chatty", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="done")])) assert received == snapshot( [ LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), @@ -105,7 +105,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="siren", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "siren" @@ -113,7 +113,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_log_message( level=level, data=f"a {level} message", related_request_id=ctx.request_id ) - return CallToolResult(content=[TextContent(text="logged")]) + return CallToolResult(content=[TextContent(type="text", text="logged")]) async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: """Registered so the logging capability is advertised; the client never sets a level.""" diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py index a9e4f994d8..821beeebae 100644 --- a/tests/interaction/lowlevel/test_meta.py +++ b/tests/interaction/lowlevel/test_meta.py @@ -25,13 +25,13 @@ async def test_request_meta_reaches_handler(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="traced", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "traced" assert ctx.meta is not None observed_metas.append(dict(ctx.meta)) - return CallToolResult(content=[TextContent(text="traced")]) + return CallToolResult(content=[TextContent(type="text", text="traced")]) server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) @@ -49,15 +49,15 @@ async def test_result_meta_reaches_client(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="metered", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "metered" - return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + return CallToolResult(content=[TextContent(type="text", text="done")], _meta=result_meta) server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: result = await client.call_tool("metered", {}) - assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + assert result == CallToolResult(content=[TextContent(type="text", text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 77db90401e..e2c2ba2612 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -40,10 +40,10 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa seen_cursors.append(params.cursor) if params.cursor is None: return ListToolsResult( - tools=[Tool(name="alpha", input_schema={"type": "object"})], - next_cursor=cursor, + tools=[Tool(name="alpha", inputSchema={"type": "object"})], + nextCursor=cursor, ) - return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="beta", inputSchema={"type": "object"})]) server = Server("paginated", on_list_tools=list_tools) @@ -54,7 +54,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa assert first_page.next_cursor == cursor assert seen_cursors == [None, cursor] assert [tool.name for tool in first_page.tools] == ["alpha"] - assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) + assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", inputSchema={"type": "object"})])) @requirement("pagination:exhaustion") @@ -70,7 +70,7 @@ async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: assert params is not None tool_name, next_cursor = pages[params.cursor] - return ListToolsResult(tools=[Tool(name=tool_name, input_schema={"type": "object"})], next_cursor=next_cursor) + return ListToolsResult(tools=[Tool(name=tool_name, inputSchema={"type": "object"})], nextCursor=next_cursor) server = Server("paginated", on_list_tools=list_tools) @@ -113,7 +113,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa received_cursors.append(params.cursor) names, next_cursor = pages[params.cursor] return ListToolsResult( - tools=[Tool(name=name, input_schema={"type": "object"}) for name in names], next_cursor=next_cursor + tools=[Tool(name=name, inputSchema={"type": "object"}) for name in names], nextCursor=next_cursor ) server = Server("paginated", on_list_tools=list_tools) @@ -167,7 +167,7 @@ async def list_resources( assert params is not None seen_cursors.append(params.cursor) if params.cursor is None: - return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor=cursor) + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], nextCursor=cursor) return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) server = Server("paginated", on_list_resources=list_resources) @@ -196,11 +196,11 @@ async def list_resource_templates( seen_cursors.append(params.cursor) if params.cursor is None: return ListResourceTemplatesResult( - resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], - next_cursor=cursor, + resourceTemplates=[ResourceTemplate(name="first", uriTemplate="users://{id}")], + nextCursor=cursor, ) return ListResourceTemplatesResult( - resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] + resourceTemplates=[ResourceTemplate(name="second", uriTemplate="teams://{id}")] ) server = Server("paginated", on_list_resource_templates=list_resource_templates) @@ -226,7 +226,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest assert params is not None seen_cursors.append(params.cursor) if params.cursor is None: - return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor=cursor) + return ListPromptsResult(prompts=[Prompt(name="first")], nextCursor=cursor) return ListPromptsResult(prompts=[Prompt(name="second")]) server = Server("paginated", on_list_prompts=list_prompts) diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py index 797e20dc35..ed1f466809 100644 --- a/tests/interaction/lowlevel/test_ping.py +++ b/tests/interaction/lowlevel/test_ping.py @@ -37,17 +37,17 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})] + tools=[types.Tool(name="ping_back", description="Ping the client.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ping_back" pong = await ctx.session.send_ping() - return CallToolResult(content=[TextContent(text=type(pong).__name__)]) + return CallToolResult(content=[TextContent(type="text", text=type(pong).__name__)]) server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: result = await client.call_tool("ping_back", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 6350c33a33..183afa6098 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -37,7 +37,7 @@ async def collect(progress: float, total: float | None, message: str | None) -> async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="download", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="download", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "download" @@ -53,14 +53,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_progress_notification( token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) ) - return CallToolResult(content=[TextContent(text="downloaded")]) + return CallToolResult(content=[TextContent(type="text", text="downloaded")]) server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: result = await client.call_tool("download", {}, progress_callback=collect) - assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="downloaded")])) assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) @@ -71,12 +71,12 @@ async def test_progress_token_visible_to_handler(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="inspect", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "inspect" assert ctx.meta is not None - return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + return CallToolResult(content=[TextContent(type="text", text=str(ctx.meta.get("progress_token")))]) server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) @@ -88,7 +88,7 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N result = await client.call_tool("inspect", {}, progress_callback=ignore) # The token is the request id of the tools/call request itself (initialize is request 0). - assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="1")])) @requirement("protocol:progress:no-token") @@ -102,19 +102,19 @@ async def test_no_progress_callback_means_no_token(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="inspect", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "inspect" assert ctx.meta is not None - return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + return CallToolResult(content=[TextContent(type="text", text=str(ctx.meta.get("progress_token")))]) server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: result = await client.call_tool("inspect", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="None")])) @requirement("protocol:progress:client-to-server") @@ -135,7 +135,7 @@ async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationPar await delivered.wait() assert received == snapshot( - [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] + [ProgressNotificationParams(progressToken="upload-1", progress=0.5, total=1.0, message="halfway")] ) @@ -160,7 +160,7 @@ async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Conne async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="report", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "report" @@ -184,7 +184,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) if second + 1 < len(turns): turns[second + 1].set() - return CallToolResult(content=[TextContent(text="done")]) + return CallToolResult(content=[TextContent(type="text", text="done")]) server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) @@ -231,7 +231,7 @@ async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="report", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "report" @@ -240,7 +240,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert token is not None captured.append((ctx.session, token)) await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) - return CallToolResult(content=[TextContent(text="done")]) + return CallToolResult(content=[TextContent(type="text", text="done")]) server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) @@ -281,7 +281,7 @@ async def collect(progress: float, total: float | None, message: str | None) -> async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="zigzag", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="zigzag", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "zigzag" @@ -291,7 +291,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) - return CallToolResult(content=[TextContent(text="done")]) + return CallToolResult(content=[TextContent(type="text", text="done")]) server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py index 868b82692c..50da0e75f2 100644 --- a/tests/interaction/lowlevel/test_prompts.py +++ b/tests/interaction/lowlevel/test_prompts.py @@ -40,7 +40,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest PromptArgument(name="code", description="The code to review.", required=True), PromptArgument(name="style_guide", description="Optional style guide to apply."), ], - icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + icons=[Icon(src="https://example.com/review.png", mimeType="image/png", sizes=["48x48"])], ), Prompt(name="daily_standup"), ] @@ -61,7 +61,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest PromptArgument(name="code", description="The code to review.", required=True), PromptArgument(name="style_guide", description="Optional style guide to apply."), ], - icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + icons=[Icon(src="https://example.com/review.png", mimeType="image/png", sizes=["48x48"])], ), Prompt(name="daily_standup"), ] @@ -78,7 +78,9 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa assert params.arguments is not None return GetPromptResult( description="A personalised greeting.", - messages=[PromptMessage(role="user", content=TextContent(text=f"Hello, {params.arguments['name']}!"))], + messages=[ + PromptMessage(role="user", content=TextContent(type="text", text=f"Hello, {params.arguments['name']}!")) + ], ) server = Server("prompter", on_get_prompt=get_prompt) @@ -89,7 +91,7 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa assert result == snapshot( GetPromptResult( description="A personalised greeting.", - messages=[PromptMessage(role="user", content=TextContent(text="Hello, Ada!"))], + messages=[PromptMessage(role="user", content=TextContent(type="text", text="Hello, Ada!"))], ) ) @@ -102,9 +104,11 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa assert params.name == "geography_quiz" return GetPromptResult( messages=[ - PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), - PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), - PromptMessage(role="user", content=TextContent(text="And of Italy?")), + PromptMessage(role="user", content=TextContent(type="text", text="What is the capital of France?")), + PromptMessage( + role="assistant", content=TextContent(type="text", text="The capital of France is Paris.") + ), + PromptMessage(role="user", content=TextContent(type="text", text="And of Italy?")), ] ) @@ -116,9 +120,11 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa assert result == snapshot( GetPromptResult( messages=[ - PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), - PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), - PromptMessage(role="user", content=TextContent(text="And of Italy?")), + PromptMessage(role="user", content=TextContent(type="text", text="What is the capital of France?")), + PromptMessage( + role="assistant", content=TextContent(type="text", text="The capital of France is Paris.") + ), + PromptMessage(role="user", content=TextContent(type="text", text="And of Italy?")), ] ) ) @@ -131,7 +137,9 @@ async def test_get_prompt_without_arguments_returns_the_messages(connect: Connec async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: assert params.name == "static" assert params.arguments is None - return GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + return GetPromptResult( + messages=[PromptMessage(role="user", content=TextContent(type="text", text="Say hello."))] + ) server = Server("prompter", on_get_prompt=get_prompt) @@ -139,7 +147,7 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa result = await client.get_prompt("static") assert result == snapshot( - GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(type="text", text="Say hello."))]) ) @@ -158,12 +166,13 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa assert params.name == "media" return GetPromptResult( messages=[ - PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), - PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")), + PromptMessage(role="assistant", content=AudioContent(type="audio", data="YXVk", mimeType="audio/wav")), PromptMessage( role="user", content=EmbeddedResource( - resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + type="resource", + resource=TextResourceContents(uri="resource://notes/1", mimeType="text/plain", text="attached"), ), ), ] @@ -177,12 +186,13 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa assert result == snapshot( GetPromptResult( messages=[ - PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), - PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")), + PromptMessage(role="assistant", content=AudioContent(type="audio", data="YXVk", mimeType="audio/wav")), PromptMessage( role="user", content=EmbeddedResource( - resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + type="resource", + resource=TextResourceContents(uri="resource://notes/1", mimeType="text/plain", text="attached"), ), ), ] diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 4e369d3645..8a525898d5 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -55,12 +55,12 @@ async def list_resources( name="readme", title="Project README", description="The project's front page.", - mime_type="text/markdown", + mimeType="text/markdown", size=1024, annotations=Annotations.model_validate( {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} ), - icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + icons=[Icon(src="https://example.com/readme.png", mimeType="image/png", sizes=["48x48"])], ), ] ) @@ -79,10 +79,10 @@ async def list_resources( name="readme", title="Project README", description="The project's front page.", - mime_type="text/markdown", + mimeType="text/markdown", size=1024, annotations=Annotations(audience=["user", "assistant"], priority=0.8), - icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + icons=[Icon(src="https://example.com/readme.png", mimeType="image/png", sizes=["48x48"])], ), ] ) @@ -95,7 +95,7 @@ async def test_read_resource_text(connect: Connect) -> None: async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: return ReadResourceResult( - contents=[TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")] + contents=[TextResourceContents(uri=params.uri, mimeType="text/plain", text="Hello, world!")] ) server = Server("library", on_read_resource=read_resource) @@ -105,7 +105,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq assert result == snapshot( ReadResourceResult( - contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")] + contents=[TextResourceContents(uri="file:///greeting.txt", mimeType="text/plain", text="Hello, world!")] ) ) @@ -119,7 +119,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq contents=[ BlobResourceContents( uri=params.uri, - mime_type="image/png", + mimeType="image/png", blob=base64.b64encode(b"\x89PNG").decode(), ) ] @@ -132,7 +132,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq assert result == snapshot( ReadResourceResult( - contents=[BlobResourceContents(uri="file:///pixel.png", mime_type="image/png", blob="iVBORw==")] + contents=[BlobResourceContents(uri="file:///pixel.png", mimeType="image/png", blob="iVBORw==")] ) ) @@ -165,15 +165,15 @@ async def list_resource_templates( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> ListResourceTemplatesResult: return ListResourceTemplatesResult( - resource_templates=[ - ResourceTemplate(uri_template="users://{user_id}", name="user"), + resourceTemplates=[ + ResourceTemplate(uriTemplate="users://{user_id}", name="user"), ResourceTemplate( - uri_template="logs://{service}/{date}", + uriTemplate="logs://{service}/{date}", name="service_logs", title="Service logs", description="One day of logs for one service.", - mime_type="text/plain", - icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + mimeType="text/plain", + icons=[Icon(src="https://example.com/logs.png", mimeType="image/png", sizes=["48x48"])], ), ] ) @@ -185,15 +185,15 @@ async def list_resource_templates( assert result == snapshot( ListResourceTemplatesResult( - resource_templates=[ - ResourceTemplate(uri_template="users://{user_id}", name="user"), + resourceTemplates=[ + ResourceTemplate(uriTemplate="users://{user_id}", name="user"), ResourceTemplate( - uri_template="logs://{service}/{date}", + uriTemplate="logs://{service}/{date}", name="service_logs", title="Service logs", description="One day of logs for one service.", - mime_type="text/plain", - icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + mimeType="text/plain", + icons=[Icon(src="https://example.com/logs.png", mimeType="image/png", sizes=["48x48"])], ), ] ) @@ -274,12 +274,12 @@ async def collect(message: IncomingMessage) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="touch", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "touch" await ctx.session.send_resource_updated("file:///watched.txt") - return CallToolResult(content=[TextContent(text="touched")]) + return CallToolResult(content=[TextContent(type="text", text="touched")]) async def list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index 8149e0befb..5bebd9158d 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -25,13 +25,13 @@ async def test_list_roots_round_trip(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="show_roots", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "show_roots" result = await ctx.session.list_roots() lines = [f"{root.uri} name={root.name}" for root in result.roots] - return CallToolResult(content=[TextContent(text="\n".join(lines))]) + return CallToolResult(content=[TextContent(type="text", text="\n".join(lines))]) server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) @@ -48,7 +48,11 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: assert result == snapshot( CallToolResult( - content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + content=[ + TextContent( + type="text", text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None" + ) + ] ) ) @@ -60,12 +64,12 @@ async def test_list_roots_empty(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="count_roots", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "count_roots" result = await ctx.session.list_roots() - return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + return CallToolResult(content=[TextContent(type="text", text=str(len(result.roots)))]) server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) @@ -75,7 +79,7 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("count_roots", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="0")])) @requirement("roots:list:not-supported") @@ -89,14 +93,14 @@ async def test_list_roots_without_callback_is_error(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="show_roots", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "show_roots" try: await ctx.session.list_roots() except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # list_roots cannot succeed without a client callback server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) @@ -104,7 +108,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: result = await client.call_tool("show_roots", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="-32600: List roots not supported")]) + ) @requirement("roots:list:client-error") @@ -117,14 +123,14 @@ async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connec async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="show_roots", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "show_roots" try: await ctx.session.list_roots() except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # the callback always answers with an error server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) @@ -135,7 +141,7 @@ async def list_roots(context: ClientRequestContext) -> ErrorData: async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("show_roots", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="-32603: roots provider crashed")])) @requirement("roots:list-changed") diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 260e564192..3ee68f5d9e 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -45,16 +45,18 @@ async def test_create_message_round_trip(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask_model" result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + return CallToolResult( + content=[TextContent(type="text", text=f"{result.model}/{result.stop_reason}: {result.content.text}")] + ) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -64,21 +66,23 @@ async def sampling_callback( received.append(params) return CreateMessageResult( role="assistant", - content=TextContent(text="Hello to you too."), + content=TextContent(type="text", text="Hello to you too."), model="mock-llm-1", - stop_reason="endTurn", + stopReason="endTurn", ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="mock-llm-1/endTurn: Hello to you too.")]) + ) assert received == snapshot( [ CreateMessageRequestParams( _meta={}, - messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], - max_tokens=100, + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], + maxTokens=100, ) ] ) @@ -100,12 +104,12 @@ async def test_create_message_params_reach_callback(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask_model" result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Pick a model."))], max_tokens=50, system_prompt="You are terse.", include_context="thisServer", @@ -113,13 +117,13 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara stop_sequences=["\n\n", "END"], model_preferences=ModelPreferences( hints=[ModelHint(name="claude"), ModelHint(name="gpt")], - cost_priority=0.2, - speed_priority=0.3, - intelligence_priority=0.9, + costPriority=0.2, + speedPriority=0.3, + intelligencePriority=0.9, ), ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(text=result.content.text)]) + return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -127,28 +131,28 @@ async def sampling_callback( context: ClientRequestContext, params: CreateMessageRequestParams ) -> CreateMessageResult: received.append(params) - return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + return CreateMessageResult(role="assistant", content=TextContent(type="text", text="ok"), model="mock-llm-1") async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="ok")])) assert received == snapshot( [ CreateMessageRequestParams( _meta={}, - messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], - model_preferences=ModelPreferences( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Pick a model."))], + modelPreferences=ModelPreferences( hints=[ModelHint(name="claude"), ModelHint(name="gpt")], - cost_priority=0.2, - speed_priority=0.3, - intelligence_priority=0.9, + costPriority=0.2, + speedPriority=0.3, + intelligencePriority=0.9, ), - system_prompt="You are terse.", - include_context="thisServer", + systemPrompt="You are terse.", + includeContext="thisServer", temperature=0.7, - max_tokens=50, - stop_sequences=["\n\n", "END"], + maxTokens=50, + stopSequences=["\n\n", "END"], ) ] ) @@ -166,16 +170,18 @@ async def test_create_message_request_with_image_content_reaches_callback(connec async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="describe_image", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "describe_image" result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + messages=[ + SamplingMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")) + ], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(text=result.content.text)]) + return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -187,20 +193,22 @@ async def sampling_callback( assert isinstance(image, ImageContent) return CreateMessageResult( role="assistant", - content=TextContent(text=f"described {image.mime_type} ({image.data})"), + content=TextContent(type="text", text=f"described {image.mime_type} ({image.data})"), model="mock-vision-1", ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("describe_image", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="described image/png (aW1n)")])) assert received == snapshot( [ CreateMessageRequestParams( _meta={}, - messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], - max_tokens=100, + messages=[ + SamplingMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")) + ], + maxTokens=100, ) ] ) @@ -216,17 +224,19 @@ async def test_create_message_result_with_image_content_returns_to_handler(conne async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="draw", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "draw" result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Draw a cat."))], max_tokens=100, ) image = result.content assert isinstance(image, ImageContent) - return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + return CallToolResult( + content=[TextContent(type="text", text=f"{result.model}: {image.mime_type} {image.data}")] + ) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -235,14 +245,14 @@ async def sampling_callback( ) -> CreateMessageResult: return CreateMessageResult( role="assistant", - content=ImageContent(data="Y2F0", mime_type="image/png"), + content=ImageContent(type="image", data="Y2F0", mimeType="image/png"), model="mock-vision-1", ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("draw", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="mock-vision-1: image/png Y2F0")])) @requirement("sampling:error:user-rejected") @@ -256,17 +266,17 @@ async def test_create_message_callback_error(connect: Connect) -> None: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask_model" try: await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], max_tokens=100, ) except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # the callback always answers with an error server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -277,7 +287,9 @@ async def sampling_callback(context: ClientRequestContext, params: CreateMessage async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="-1: User rejected sampling request")]) + ) @requirement("sampling:create-message:not-supported") @@ -287,17 +299,17 @@ async def test_create_message_without_callback_is_error(connect: Connect) -> Non async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask_model" try: await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], max_tokens=100, ) except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # create_message cannot succeed without a client callback server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -305,7 +317,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: result = await client.call_tool("ask_model", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="-32600: Sampling not supported")])) @requirement("sampling:tools:server-gated-by-capability") @@ -319,18 +331,18 @@ async def test_create_message_with_tools_is_rejected_for_unsupporting_client(con async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask_model" try: await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="What is the weather?"))], max_tokens=100, - tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + tools=[types.Tool(name="get_weather", inputSchema={"type": "object"})], ) except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # the validator rejects every tool-enabled request server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -345,7 +357,9 @@ async def sampling_callback( result = await client.call_tool("ask_model", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + CallToolResult( + content=[TextContent(type="text", text="-32602: Client does not support sampling tools capability")] + ) ) @@ -361,7 +375,7 @@ async def test_create_message_with_mixed_tool_result_content_is_rejected(connect async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "summarise_tools" @@ -371,15 +385,17 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara SamplingMessage( role="user", content=[ - ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), - TextContent(text="Also, a comment alongside the result."), + ToolResultContent( + type="tool_result", toolUseId="call-1", content=[TextContent(type="text", text="42")] + ), + TextContent(type="text", text="Also, a comment alongside the result."), ], ) ], max_tokens=100, ) except ValueError as exc: - return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{type(exc).__name__}: {exc}")]) raise NotImplementedError # the validator rejects the malformed messages before sending server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -396,7 +412,10 @@ async def sampling_callback( assert result == snapshot( CallToolResult( content=[ - TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + TextContent( + type="text", + text="ValueError: The last message must contain only tool_result content if any is present", + ) ] ) ) @@ -414,13 +433,13 @@ async def test_a_client_with_a_sampling_callback_declares_the_sampling_capabilit async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="capabilities", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "capabilities" assert ctx.session.client_params is not None captured.append(ctx.session.client_params.capabilities.sampling) - return CallToolResult(content=[TextContent(text="ok")]) + return CallToolResult(content=[TextContent(type="text", text="ok")]) server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) @@ -448,16 +467,18 @@ async def test_create_message_request_with_audio_content_reaches_callback(connec async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="transcribe", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "transcribe" result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + messages=[ + SamplingMessage(role="user", content=AudioContent(type="audio", data="c25k", mimeType="audio/wav")) + ], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(text=result.content.text)]) + return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -469,20 +490,22 @@ async def sampling_callback( assert isinstance(audio, AudioContent) return CreateMessageResult( role="assistant", - content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + content=TextContent(type="text", text=f"transcribed {audio.mime_type} ({audio.data})"), model="mock-audio-1", ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("transcribe", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="transcribed audio/wav (c25k)")])) assert received == snapshot( [ CreateMessageRequestParams( _meta={}, - messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], - max_tokens=100, + messages=[ + SamplingMessage(role="user", content=AudioContent(type="audio", data="c25k", mimeType="audio/wav")) + ], + maxTokens=100, ) ] ) @@ -498,17 +521,19 @@ async def test_create_message_result_with_audio_content_returns_to_handler(conne async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="speak", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "speak" result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello, aloud."))], max_tokens=100, ) audio = result.content assert isinstance(audio, AudioContent) - return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + return CallToolResult( + content=[TextContent(type="text", text=f"{result.model}: {audio.mime_type} {audio.data}")] + ) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -517,14 +542,16 @@ async def sampling_callback( ) -> CreateMessageResult: return CreateMessageResult( role="assistant", - content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + content=AudioContent(type="audio", data="aGVsbG8=", mimeType="audio/wav"), model="mock-audio-1", ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("speak", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="mock-audio-1: audio/wav aGVsbG8=")]) + ) @requirement("sampling:message:content-cardinality") @@ -535,7 +562,7 @@ async def test_create_message_with_list_valued_message_content_reaches_callback( async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="caption", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "caption" @@ -544,15 +571,15 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara SamplingMessage( role="user", content=[ - TextContent(text="Caption this image."), - ImageContent(data="aW1n", mime_type="image/png"), + TextContent(type="text", text="Caption this image."), + ImageContent(type="image", data="aW1n", mimeType="image/png"), ], ) ], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(text=result.content.text)]) + return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -563,13 +590,13 @@ async def sampling_callback( content = params.messages[0].content assert isinstance(content, list) return CreateMessageResult( - role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + role="assistant", content=TextContent(type="text", text=f"{len(content)} blocks"), model="mock-llm-1" ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("caption", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="2 blocks")])) assert received == snapshot( [ CreateMessageRequestParams( @@ -578,12 +605,12 @@ async def sampling_callback( SamplingMessage( role="user", content=[ - TextContent(text="Caption this image."), - ImageContent(data="aW1n", mime_type="image/png"), + TextContent(type="text", text="Caption this image."), + ImageContent(type="image", data="aW1n", mimeType="image/png"), ], ) ], - max_tokens=100, + maxTokens=100, ) ] ) @@ -601,7 +628,7 @@ async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejecte async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "continue_tools" @@ -610,17 +637,23 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara messages=[ SamplingMessage( role="assistant", - content=[ToolUseContent(id="call-1", name="weather", input={})], + content=[ToolUseContent(type="tool_use", id="call-1", name="weather", input={})], ), SamplingMessage( role="user", - content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + content=[ + ToolResultContent( + type="tool_result", + toolUseId="call-WRONG", + content=[TextContent(type="text", text="42")], + ) + ], ), ], max_tokens=100, ) except ValueError as exc: - return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + return CallToolResult(content=[TextContent(type="text", text=f"{type(exc).__name__}: {exc}")]) raise NotImplementedError # the validator rejects the malformed messages before sending server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -638,7 +671,8 @@ async def sampling_callback( CallToolResult( content=[ TextContent( - text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + type="text", + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match", ) ] ) @@ -657,17 +691,17 @@ async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_valida async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask_model" try: await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Two thoughts, please."))], max_tokens=100, ) except pydantic.ValidationError as exc: - return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + return CallToolResult(content=[TextContent(type="text", text=type(exc).__name__)]) raise NotImplementedError # the array-content result fails server-side parsing every time server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) @@ -677,11 +711,11 @@ async def sampling_callback( ) -> CreateMessageResultWithTools: return CreateMessageResultWithTools( role="assistant", - content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + content=[TextContent(type="text", text="First thought."), TextContent(type="text", text="Second thought.")], model="mock-llm-1", ) async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index a9c83d641d..c08fc02dd2 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -63,14 +63,14 @@ async def list_tools( ) -> types.ListToolsResult: return types.ListToolsResult( tools=[ - types.Tool(name="block", input_schema={"type": "object"}), - types.Tool(name="echo", input_schema={"type": "object"}), + types.Tool(name="block", inputSchema={"type": "object"}), + types.Tool(name="echo", inputSchema={"type": "object"}), ] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: if params.name == "echo": - return CallToolResult(content=[TextContent(text="still alive")]) + return CallToolResult(content=[TextContent(type="text", text="still alive")]) await anyio.Event().wait() # blocks until the session is torn down raise NotImplementedError # unreachable @@ -82,7 +82,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("echo", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="still alive")])) @requirement("protocol:timeout:session-default") diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index e8053fbaa7..25eee750cd 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -35,20 +35,22 @@ async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: return types.ListToolsResult( - tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + tools=[types.Tool(name="add", description="Add two integers.", inputSchema={"type": "object"})] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "add" assert params.arguments is not None - return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + return CallToolResult( + content=[TextContent(type="text", text=str(params.arguments["a"] + params.arguments["b"]))] + ) server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: result = await client.call_tool("add", {"a": 2, "b": 3}) - assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="5")])) @requirement("tools:call:is-error") @@ -61,7 +63,7 @@ async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "flux" - return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + return CallToolResult(content=[TextContent(type="text", text="the flux capacitor is offline")], isError=True) server = Server("errors", on_call_tool=call_tool) @@ -69,7 +71,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("flux", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + CallToolResult(content=[TextContent(type="text", text="the flux capacitor is offline")], isError=True) ) @@ -125,13 +127,13 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa Tool( name="add", description="Add two integers.", - input_schema={ + inputSchema={ "type": "object", "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, "required": ["a", "b"], }, ), - Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + Tool(name="reset", description="Reset the calculator.", inputSchema={"type": "object"}), ] ) @@ -146,13 +148,13 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa Tool( name="add", description="Add two integers.", - input_schema={ + inputSchema={ "type": "object", "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, "required": ["a", "b"], }, ), - Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + Tool(name="reset", description="Reset the calculator.", inputSchema={"type": "object"}), ] ) ) @@ -187,12 +189,12 @@ async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Con } async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + return ListToolsResult(tools=[Tool(name="typed", inputSchema=schema)]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "typed" assert params.arguments == {"count": 3, "options": {"verbose": True}} - return CallToolResult(content=[TextContent(text="ok")]) + return CallToolResult(content=[TextContent(type="text", text="ok")]) server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) @@ -201,7 +203,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) assert listed.tools[0].input_schema == schema - assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert called == snapshot(CallToolResult(content=[TextContent(type="text", text="ok")])) @requirement("tools:list:metadata") @@ -212,10 +214,10 @@ async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: name="annotated", title="Annotated tool", description="A tool carrying every optional field.", - input_schema={"type": "object"}, - output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, - icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], - annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + inputSchema={"type": "object"}, + outputSchema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mimeType="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", readOnlyHint=True, idempotentHint=True), _meta={"example.com/source": "interaction-suite"}, ) @@ -234,10 +236,10 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa name="annotated", title="Annotated tool", description="A tool carrying every optional field.", - input_schema={"type": "object"}, - output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, - icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], - annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + inputSchema={"type": "object"}, + outputSchema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mimeType="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", readOnlyHint=True, idempotentHint=True), _meta={"example.com/source": "interaction-suite"}, ) ] @@ -258,18 +260,21 @@ async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: """ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="render", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "render" return CallToolResult( content=[ - TextContent(text="all five content block types"), - ImageContent(data="aW1n", mime_type="image/png"), - AudioContent(data="YXVk", mime_type="audio/wav"), - ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + TextContent(type="text", text="all five content block types"), + ImageContent(type="image", data="aW1n", mimeType="image/png"), + AudioContent(type="audio", data="YXVk", mimeType="audio/wav"), + ResourceLink( + type="resource_link", name="report", uri="resource://reports/1", description="The full report" + ), EmbeddedResource( - resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + type="resource", + resource=TextResourceContents(uri="resource://reports/1", mimeType="text/plain", text="contents"), ), ] ) @@ -282,12 +287,15 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot( CallToolResult( content=[ - TextContent(text="all five content block types"), - ImageContent(data="aW1n", mime_type="image/png"), - AudioContent(data="YXVk", mime_type="audio/wav"), - ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + TextContent(type="text", text="all five content block types"), + ImageContent(type="image", data="aW1n", mimeType="image/png"), + AudioContent(type="audio", data="YXVk", mimeType="audio/wav"), + ResourceLink( + type="resource_link", name="report", uri="resource://reports/1", description="The full report" + ), EmbeddedResource( - resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + type="resource", + resource=TextResourceContents(uri="resource://reports/1", mimeType="text/plain", text="contents"), ), ] ) @@ -299,18 +307,20 @@ async def test_call_tool_structured_content(connect: Connect) -> None: """A tool result carrying structured content alongside content delivers both to the client.""" async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="sum", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "sum" - return CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5}) + return CallToolResult(content=[TextContent(type="text", text="the sum is 5")], structuredContent={"sum": 5}) server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server) as client: result = await client.call_tool("sum", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="the sum is 5")], structuredContent={"sum": 5}) + ) @requirement("tools:call:concurrent") @@ -327,7 +337,7 @@ async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> results: dict[str, CallToolResult] = {} async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "echo" @@ -337,7 +347,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara started.append(tag) started_events[tag].set() await release.wait() - return CallToolResult(content=[TextContent(text=tag)]) + return CallToolResult(content=[TextContent(type="text", text=tag)]) server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -359,8 +369,8 @@ async def call_and_record(tag: str) -> None: assert sorted(started) == ["first", "second"] assert results == snapshot( { - "first": CallToolResult(content=[TextContent(text="first")]), - "second": CallToolResult(content=[TextContent(text="second")]), + "first": CallToolResult(content=[TextContent(type="text", text="first")]), + "second": CallToolResult(content=[TextContent(type="text", text="second")]), } ) @@ -376,8 +386,8 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa tools=[ Tool( name="forecast", - input_schema={"type": "object"}, - output_schema={ + inputSchema={"type": "object"}, + outputSchema={ "type": "object", "properties": {"temperature": {"type": "number"}}, "required": ["temperature"], @@ -388,7 +398,9 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "forecast" - return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + return CallToolResult( + content=[TextContent(type="text", text="warm")], structuredContent={"temperature": "warm"} + ) server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) @@ -414,8 +426,8 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa tools=[ Tool( name="forecast", - input_schema={"type": "object"}, - output_schema={ + inputSchema={"type": "object"}, + outputSchema={ "type": "object", "properties": {"temperature": {"type": "number"}}, "required": ["temperature"], @@ -427,7 +439,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "forecast" return CallToolResult( - content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + content=[TextContent(type="text", text="boom")], structuredContent={"temperature": "warm"}, isError=True ) server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) @@ -437,7 +449,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("forecast", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + CallToolResult( + content=[TextContent(type="text", text="boom")], structuredContent={"temperature": "warm"}, isError=True + ) ) @@ -453,15 +467,15 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa tools=[ Tool( name="forecast", - input_schema={"type": "object"}, - output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + inputSchema={"type": "object"}, + outputSchema={"type": "object", "properties": {"temperature": {"type": "number"}}}, ) ] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "forecast" - return CallToolResult(content=[TextContent(text="warm")]) + return CallToolResult(content=[TextContent(type="text", text="warm")]) server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) @@ -490,15 +504,15 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa tools=[ Tool( name="forecast", - input_schema={"type": "object"}, - output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + inputSchema={"type": "object"}, + outputSchema={"type": "object", "properties": {"temperature": {"type": "number"}}}, ) ] ) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "forecast" - return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + return CallToolResult(content=[TextContent(type="text", text="21 C")], structuredContent={"temperature": 21}) server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) @@ -508,5 +522,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara second = await client.call_tool("forecast", {}) assert list_calls == ["called"] - assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert first == snapshot( + CallToolResult(content=[TextContent(type="text", text="21 C")], structuredContent={"temperature": 21}) + ) assert second == first diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index 0f9c58aa7a..f216884ecf 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -48,11 +48,11 @@ def _echo_server() -> Server: async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + return types.ListToolsResult(tools=[types.Tool(name="echo", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "echo" - return CallToolResult(content=[TextContent(text="ok")]) + return CallToolResult(content=[TextContent(type="text", text="ok")]) return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index 26556fea7a..0ead93bf6c 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -54,7 +54,9 @@ async def collect(params: LoggingMessageNotificationParams) -> None: result = await client.call_tool("narrate", {}) advertised_logging = client.initialize_result.capabilities.logging - assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="done")], structuredContent={"result": "done"}) + ) assert received == snapshot( [ LoggingMessageNotificationParams(level="debug", data="d"), @@ -89,7 +91,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) result = await client.call_tool("crunch", {}, progress_callback=on_progress) assert result == snapshot( - CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + CallToolResult(content=[TextContent(type="text", text="crunched")], structuredContent={"result": "crunched"}) ) assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) @@ -145,7 +147,7 @@ async def collect(message: IncomingMessage) -> None: result = await client.call_tool("mill", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + CallToolResult(content=[TextContent(type="text", text="milled")], structuredContent={"result": "milled"}) ) assert received == snapshot( [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] @@ -185,7 +187,7 @@ async def answer_form(context: ClientRequestContext, params: ElicitRequestParams ElicitRequestFormParams( _meta={}, message="Where to?", - requested_schema={ + requestedSchema={ "properties": { "destination": {"title": "Destination", "type": "string"}, "window_seat": {"title": "Window Seat", "type": "boolean"}, @@ -199,8 +201,8 @@ async def answer_form(context: ClientRequestContext, params: ElicitRequestParams ) assert result == snapshot( CallToolResult( - content=[TextContent(text="accept: Lisbon window=True")], - structured_content={"result": "accept: Lisbon window=True"}, + content=[TextContent(type="text", text="accept: Lisbon window=True")], + structuredContent={"result": "accept: Lisbon window=True"}, ) ) @@ -229,8 +231,8 @@ async def show_config(ctx: Context) -> str: assert result == snapshot( CallToolResult( - content=[TextContent(text="text/plain: 'theme = dark'")], - structured_content={"result": "text/plain: 'theme = dark'"}, + content=[TextContent(type="text", text="text/plain: 'theme = dark'")], + structuredContent={"result": "text/plain: 'theme = dark'"}, ) ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 2095f086d4..657cc2a6c5 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -68,7 +68,7 @@ def greet(name: str) -> str: assert result == snapshot( GetPromptResult( description="A personalised greeting.", - messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + messages=[PromptMessage(role="user", content=TextContent(type="text", text="Say hello to Ada."))], ) ) @@ -156,7 +156,7 @@ def review(code: str, style: str = "pep8") -> str: assert result == snapshot( GetPromptResult( description="Review a snippet of code against a style guide.", - messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + messages=[PromptMessage(role="user", content=TextContent(type="text", text="Review x = 1 per pep8."))], ) ) @@ -190,6 +190,6 @@ def greet_second() -> str: assert result == snapshot( GetPromptResult( description="The first registration; this is the one that wins.", - messages=[PromptMessage(role="user", content=TextContent(text="first"))], + messages=[PromptMessage(role="user", content=TextContent(type="text", text="first"))], ) ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py index 57b0fdc86d..5066095541 100644 --- a/tests/interaction/mcpserver/test_resources.py +++ b/tests/interaction/mcpserver/test_resources.py @@ -35,7 +35,7 @@ def app_config() -> str: assert result == snapshot( ReadResourceResult( - contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + contents=[TextResourceContents(uri="config://app", mimeType="text/plain", text="theme = dark")] ) ) @@ -70,19 +70,19 @@ def user_profile(user_id: str) -> str: name="app_config", uri="config://app", description="The application configuration.", - mime_type="text/plain", + mimeType="text/plain", ) ] ) ) assert templates == snapshot( ListResourceTemplatesResult( - resource_templates=[ + resourceTemplates=[ ResourceTemplate( name="user_profile", - uri_template="users://{user_id}/profile", + uriTemplate="users://{user_id}/profile", description="A user's profile.", - mime_type="text/plain", + mimeType="text/plain", ) ] ) @@ -105,7 +105,7 @@ def user_profile(user_id: str) -> str: assert result == snapshot( ReadResourceResult( - contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + contents=[TextResourceContents(uri="users://42/profile", mimeType="text/plain", text="profile for 42")] ) ) @@ -179,5 +179,5 @@ def config_second() -> str: assert [resource.uri for resource in listed.resources] == ["config://app"] assert listed.resources[0].name == "config_first" assert result == snapshot( - ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mimeType="text/plain", text="first")]) ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index 05135c1286..a4cf9ca348 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -43,7 +43,9 @@ def add(a: int, b: int) -> str: async with connect(mcp) as client: result = await client.call_tool("add", {"a": 2, "b": 3}) - assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="5")], structuredContent={"result": "5"}) + ) @requirement("mcpserver:tool:schema-variants") @@ -69,7 +71,8 @@ def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Fie assert result == snapshot( CallToolResult( - content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + content=[TextContent(type="text", text="fast at (3, 4) x5")], + structuredContent={"result": "fast at (3, 4) x5"}, ) ) @@ -93,7 +96,7 @@ def explode() -> str: result = await client.call_tool("explode", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + CallToolResult(content=[TextContent(type="text", text="Error executing tool explode: boom")], isError=True) ) @@ -110,7 +113,9 @@ def flux() -> str: result = await client.call_tool("flux", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + CallToolResult( + content=[TextContent(type="text", text="Error executing tool flux: flux capacitor offline")], isError=True + ) ) @@ -130,7 +135,9 @@ def add() -> None: async with connect(mcp) as client: result = await client.call_tool("nope", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="Unknown tool: nope")], isError=True) + ) @requirement("mcpserver:tool:output-schema:model") @@ -168,15 +175,16 @@ def get_weather() -> Weather: CallToolResult( content=[ TextContent( + type="text", text="""\ { "temperature": 22.5, "conditions": "sunny" }\ -""" +""", ) ], - structured_content={"temperature": 22.5, "conditions": "sunny"}, + structuredContent={"temperature": 22.5, "conditions": "sunny"}, ) ) @@ -206,8 +214,12 @@ def primes() -> list[int]: ) assert result == snapshot( CallToolResult( - content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], - structured_content={"result": [2, 3, 5]}, + content=[ + TextContent(type="text", text="2"), + TextContent(type="text", text="3"), + TextContent(type="text", text="5"), + ], + structuredContent={"result": [2, 3, 5]}, ) ) @@ -256,11 +268,11 @@ class Weather(BaseModel): @mcp.tool() def mismatched() -> Annotated[CallToolResult, Weather]: - return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + return CallToolResult(content=[TextContent(type="text", text="oops")], structuredContent={"nope": True}) @mcp.tool() def missing() -> Annotated[CallToolResult, Weather]: - return CallToolResult(content=[TextContent(text="oops")]) + return CallToolResult(content=[TextContent(type="text", text="oops")]) async with connect(mcp) as client: mismatched_result = await client.call_tool("mismatched", {}) @@ -305,7 +317,7 @@ def echo_second() -> str: assert [tool.name for tool in listed.tools] == ["echo"] assert result == snapshot( - CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + CallToolResult(content=[TextContent(type="text", text="first")], structuredContent={"result": "first"}) ) @@ -340,7 +352,9 @@ def bad() -> str: result = await client.call_tool("bad name!", {}) assert [tool.name for tool in listed.tools] == ["bad name!"] - assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="ok")], structuredContent={"result": "ok"}) + ) @requirement("mcpserver:tool:url-elicitation-error") @@ -360,7 +374,7 @@ def read_files() -> str: ElicitRequestURLParams( message="Authorization required for your files.", url="https://example.com/oauth/authorize", - elicitation_id="auth-001", + elicitationId="auth-001", ) ] ) diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py index 5977cc3e99..633b43f340 100644 --- a/tests/interaction/transports/_stdio_server.py +++ b/tests/interaction/transports/_stdio_server.py @@ -29,7 +29,7 @@ async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | tools=[ Tool( name="echo", - input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + inputSchema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, ) ] ) @@ -40,7 +40,7 @@ async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> assert params.arguments is not None text = params.arguments["text"] await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") - return CallToolResult(content=[TextContent(text=text)]) + return CallToolResult(content=[TextContent(type="text", text=text)]) async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py index 65ed03f1e4..96b70c37e6 100644 --- a/tests/interaction/transports/test_client_transport_http.py +++ b/tests/interaction/transports/test_client_transport_http.py @@ -31,12 +31,12 @@ def _tooled_server() -> Server: """A low-level server with one echo tool, used by every test in this file.""" async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", inputSchema={"type": "object"})]) async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "echo" assert params.arguments is not None - return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + return CallToolResult(content=[TextContent(type="text", text=str(params.arguments["text"]))]) return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -147,9 +147,9 @@ async def call(n: int) -> None: assert results == snapshot( { - 1: CallToolResult(content=[TextContent(text="1")]), - 2: CallToolResult(content=[TextContent(text="2")]), - 3: CallToolResult(content=[TextContent(text="3")]), + 1: CallToolResult(content=[TextContent(type="text", text="1")]), + 2: CallToolResult(content=[TextContent(type="text", text="2")]), + 3: CallToolResult(content=[TextContent(type="text", text="3")]), } ) tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py index c428fe2d68..e5d75a9f7c 100644 --- a/tests/interaction/transports/test_flows.py +++ b/tests/interaction/transports/test_flows.py @@ -89,7 +89,7 @@ async def record(request: httpx.Request) -> None: assert {tool.name for tool in first_result.tools} == {"echo"} assert second_result == snapshot( - CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + CallToolResult(content=[TextContent(type="text", text="again")], structuredContent={"result": "again"}) ) distinct = set(session_ids) assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" @@ -122,8 +122,8 @@ def echo(text: str) -> str: sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) assert shttp_result == snapshot( - CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + CallToolResult(content=[TextContent(type="text", text="via http")], structuredContent={"result": "via http"}) ) assert sse_result == snapshot( - CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + CallToolResult(content=[TextContent(type="text", text="via sse")], structuredContent={"result": "via sse"}) ) diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 85e64ded42..be0d7cbda2 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -54,7 +54,7 @@ async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> assert params.name == "narrate" await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) await ctx.session.send_resource_updated("file:///watched.txt") - return CallToolResult(content=[TextContent(text="done")]) + return CallToolResult(content=[TextContent(type="text", text="done")]) async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: """Registered so the logging capability is advertised; the client never sets a level.""" diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index c7945d56c3..a23311b39c 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -291,7 +291,7 @@ async def call() -> None: await done.wait() assert result == snapshot( - [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + [CallToolResult(content=[TextContent(type="text", text="resumed")], structuredContent={"result": "resumed"})] ) assert received == snapshot(["before close", "after close"]) @@ -368,5 +368,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: result = await second.send_request( call, CallToolResult, metadata=ClientMessageMetadata(resumption_token=captured[-1]) ) - assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="done")], structuredContent={"result": "done"}) + ) assert received == snapshot(["first", "second"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py index a926c3e8a2..1285d4d81e 100644 --- a/tests/interaction/transports/test_hosting_session.py +++ b/tests/interaction/transports/test_hosting_session.py @@ -32,7 +32,7 @@ def _server() -> Server: """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", inputSchema={"type": "object"})]) return Server("hosted", on_list_tools=list_tools) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index 27cc65de42..d805e64933 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -82,7 +82,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: errlog.seek(0) captured_stderr = errlog.read() - assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="across\nprocesses")])) # stdio carries one ordered server→client stream, so the same notification-before-response # guarantee holds here as for the in-memory transport. assert received == snapshot( diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index d38e2a0bb3..a0b7fb2bda 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -71,7 +71,7 @@ async def test_tool_call_over_streamable_http_with_json_responses() -> None: result = await client.call_tool("echo", {"text": "as json"}) assert result == snapshot( - CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"}) + CallToolResult(content=[TextContent(type="text", text="as json")], structuredContent={"result": "as json"}) ) @@ -83,10 +83,10 @@ async def test_tool_calls_over_stateless_streamable_http() -> None: second = await client.call_tool("echo", {"text": "second"}) assert first == snapshot( - CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + CallToolResult(content=[TextContent(type="text", text="first")], structuredContent={"result": "first"}) ) assert second == snapshot( - CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"}) + CallToolResult(content=[TextContent(type="text", text="second")], structuredContent={"result": "second"}) ) @@ -129,7 +129,7 @@ async def collect(message: IncomingMessage) -> None: await resource_update_seen.wait() assert result == snapshot( - CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) + CallToolResult(content=[TextContent(type="text", text="announced")], structuredContent={"result": "announced"}) ) # The related log notification rides the call's stream; the unrelated resource-updated # notification rides the standalone stream. Both arrive, nothing else does. @@ -163,6 +163,8 @@ async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> result = await client.call_tool("ask", {}) assert result == snapshot( - CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) + CallToolResult( + content=[TextContent(type="text", text="confirmed=True")], structuredContent={"result": "confirmed=True"} + ) ) assert [params.message for params in asked] == snapshot(["Proceed?"]) From 92f2bb4077871cfba22f73f4efed53239aa2a830 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 13:49:33 +0000 Subject: [PATCH 03/19] =?UTF-8?q?backport:=20harness=20H1=20=E2=80=94=20?= =?UTF-8?q?=5Fhelpers.py:=20local=20ReadStream/WriteStream=20aliases,=20Re?= =?UTF-8?q?cordingTransport=E2=86=92Recording=20stream-pair=20wrapper?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/interaction/_helpers.py | 37 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py index 25833b0ca5..66ea6ab669 100644 --- a/tests/interaction/_helpers.py +++ b/tests/interaction/_helpers.py @@ -1,7 +1,7 @@ """Shared helpers for the interaction suite. Keep this module small: it exists only for (a) types that every test would otherwise have to -assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +assemble from the SDK's internals to annotate a client callback, and (b) the recording wrapper used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses them. """ @@ -9,9 +9,9 @@ from types import TracebackType import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from typing_extensions import Self -from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ClientResult, ServerNotification, ServerRequest @@ -24,11 +24,17 @@ IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception """Everything a client message handler can receive.""" +ReadStream = MemoryObjectReceiveStream[SessionMessage | Exception] +WriteStream = MemoryObjectSendStream[SessionMessage] +"""Local aliases for the v1 SDK's session-stream types (v1 has no exported `ReadStream`/ +`WriteStream` names); exported so wire-level / scripted-peer tests can annotate without +reaching into anyio.""" + class _RecordingReadStream: """Delegates to a read stream, appending every received message to a log.""" - def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + def __init__(self, inner: ReadStream, log: list[SessionMessage | Exception]) -> None: self._inner = inner self._log = log @@ -62,7 +68,7 @@ async def __aexit__( class _RecordingWriteStream: """Delegates to a write stream, appending every sent message to a log.""" - def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + def __init__(self, inner: WriteStream, log: list[SessionMessage]) -> None: self._inner = inner self._log = log @@ -83,25 +89,22 @@ async def __aexit__( return None -class RecordingTransport: - """Wraps a Transport and records every message crossing the client's transport boundary. +class Recording: + """Wraps a (read, write) stream pair and records every message crossing it. `sent` holds everything the client wrote towards the server; `received` holds everything the server delivered to the client. The recording sits at the transport seam -- the exact payloads a real transport would serialise -- and never touches the session, so wire-level assertions written against it survive changes to the receive path. + + v1 has no `Transport` abstraction; tests insert this between + `create_client_server_memory_streams()` and `ClientSession`. """ - def __init__(self, inner: Transport) -> None: - self.inner = inner + def __init__(self, read: ReadStream, write: WriteStream) -> None: self.sent: list[SessionMessage] = [] self.received: list[SessionMessage | Exception] = [] - - async def __aenter__(self) -> TransportStreams: - read_stream, write_stream = await self.inner.__aenter__() - return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None - ) -> bool | None: - return await self.inner.__aexit__(exc_type, exc_val, exc_tb) + # Duck-typed stand-ins for the anyio stream classes; ClientSession only calls + # .receive()/.send()/.aclose() so the runtime contract holds. + self.read: ReadStream = _RecordingReadStream(read, self.received) # type: ignore[assignment] + self.write: WriteStream = _RecordingWriteStream(write, self.sent) # type: ignore[assignment] From a507771e21416f826a709b6580b04b22e3f19a53 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 13:53:38 +0000 Subject: [PATCH 04/19] =?UTF-8?q?backport:=20harness=20H2=20=E2=80=94=20?= =?UTF-8?q?=5Fconnect.py=20imports/annotations=20+=20conftest=20marker=20r?= =?UTF-8?q?egistration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - conftest.py: add pytest_configure registering the 'requirement' marker (round-1 adv-3 S1 fix; without this every file is a PytestUnknownMarkWarning collection error under filterwarnings=['error']). - _connect.py imports: drop Client/MCPServer/jsonrpc_message_adapter (not in v1); add timedelta, ClientSession, FastMCP. Mount KEPT (adversarial-v2-gate S1: SSE build_sse_app needs it). - _connect.py annotations: Connect Protocol + all factory signatures retyped to Server[Any]|FastMCP / timedelta / ClientSession; kwarg order matches v1 ClientSession.__init__; add _lowlevel() helper. Function bodies untouched per plan-v2 H2; the 7 dead Client/MCPServer/adapter refs carry temporary noqa:F821 until H3/H4/H5 rewrite them. --- tests/interaction/_connect.py | 90 +++++++++++++++++++---------------- tests/interaction/conftest.py | 7 +++ 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 1faf4aa8d6..f923931ef0 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -1,14 +1,16 @@ """Transport-parametrized connection factories for the interaction suite. The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body -runs over each transport without naming any of them: the factory is a drop-in replacement for -constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the -server's real Starlette app through the in-process streaming bridge, so the full transport layer -(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. +runs over each transport without naming any of them: the factory yields an initialized +`ClientSession` connected to the given server. v1 has no high-level `Client` class — +`ClientSession` *is* the client. The HTTP factories drive the server's real Starlette app through +the in-process streaming bridge, so the full transport layer (session ids, SSE encoding, session +management) runs with no sockets, threads, or subprocesses. """ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, asynccontextmanager +from datetime import timedelta from typing import Any, Protocol import httpx @@ -18,14 +20,13 @@ from starlette.responses import Response from starlette.routing import Mount, Route -from mcp.client.client import Client -from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP from mcp.server.sse import SseServerTransport from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -38,7 +39,6 @@ JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, - jsonrpc_message_adapter, ) from tests.interaction.transports._bridge import StreamingASGITransport @@ -52,40 +52,50 @@ NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) +def _lowlevel(server: Server[Any] | FastMCP) -> Server[Any]: + """Return the lowlevel `Server` for either flavour. + + Reaching `FastMCP._mcp_server` is the v1 idiom — `mcp.shared.memory` itself does exactly + this (with the same `# type: ignore`). + """ + return server._mcp_server if isinstance(server, FastMCP) else server # type: ignore[reportPrivateUsage] + + class Connect(Protocol): - """Connect a Client to a server over the transport selected by the `connect` fixture. + """Connect a `ClientSession` to a server over the transport selected by the `connect` fixture. - Accepts the same keyword arguments as `Client` and yields the connected client. + Accepts the same callback keyword arguments as `ClientSession` and yields the connected, + initialized session. """ def __call__( self, - server: Server | MCPServer, + server: Server[Any] | FastMCP, *, - read_timeout_seconds: float | None = None, + read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, - elicitation_callback: ElicitationFnT | None = None, - ) -> AbstractAsyncContextManager[Client]: ... + ) -> AbstractAsyncContextManager[ClientSession]: ... @asynccontextmanager async def connect_in_memory( - server: Server | MCPServer, + server: Server[Any] | FastMCP, *, - read_timeout_seconds: float | None = None, + read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, - elicitation_callback: ElicitationFnT | None = None, -) -> AsyncIterator[Client]: - """Yield a Client connected to the server over the in-memory transport.""" - async with Client( +) -> AsyncIterator[ClientSession]: + """Yield an initialized `ClientSession` connected to the server over the in-memory transport.""" + async with Client( # noqa: F821 -- body rewritten in H3 server, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, @@ -100,21 +110,21 @@ async def connect_in_memory( @asynccontextmanager async def connect_over_streamable_http( - server: Server | MCPServer, + server: Server[Any] | FastMCP, *, stateless_http: bool = False, json_response: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, - read_timeout_seconds: float | None = None, + read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, - elicitation_callback: ElicitationFnT | None = None, -) -> AsyncIterator[Client]: - """Yield a Client connected to the server's streamable HTTP app, entirely in process. +) -> AsyncIterator[ClientSession]: + """Yield an initialized `ClientSession` over the server's streamable HTTP app, entirely in process. With the defaults this is the matrix leg (stateful sessions, SSE responses); the transport-specific tests pass `stateless_http` or `json_response` to select the other @@ -131,7 +141,7 @@ async def connect_over_streamable_http( async with ( server.session_manager.run(), httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, - Client( + Client( # noqa: F821 -- body rewritten in H4 streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, @@ -147,7 +157,7 @@ async def connect_over_streamable_http( @asynccontextmanager async def mounted_app( - server: Server | MCPServer, + server: Server[Any] | FastMCP, *, stateless_http: bool = False, json_response: bool = False, @@ -172,7 +182,7 @@ async def mounted_app( DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the localhost auto-enable behaviour) to test the protection itself. """ - lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H5 app = lowlevel.streamable_http_app( stateless_http=stateless_http, json_response=json_response, @@ -200,15 +210,15 @@ async def client_via_http( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, elicitation_callback: ElicitationFnT | None = None, -) -> AsyncIterator[Client]: - """Connect a `Client` over an already-mounted streamable HTTP app. +) -> AsyncIterator[ClientSession]: + """Connect a `ClientSession` over an already-mounted streamable HTTP app. Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a client-driven assertion can sit alongside raw-httpx assertions in the same test. The underlying `httpx.AsyncClient` is left open when the `Client` exits. """ transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) - async with Client( + async with Client( # noqa: F821 -- body rewritten in H4 transport, logging_callback=logging_callback, message_handler=message_handler, @@ -219,7 +229,7 @@ async def client_via_http( def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" - return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] # noqa: F821 -- body rewritten in H3 async def post_jsonrpc( @@ -289,7 +299,7 @@ async def initialize_via_http(http: httpx.AsyncClient) -> str: return session_id -def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: +def build_sse_app(server: Server[Any] | FastMCP) -> tuple[Starlette, SseServerTransport]: """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which @@ -299,7 +309,7 @@ def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTrans sse = SseServerTransport( "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) - lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H3 async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): @@ -317,17 +327,17 @@ async def handle_sse(request: Request) -> Response: @asynccontextmanager async def connect_over_sse( - server: Server | MCPServer, + server: Server[Any] | FastMCP, *, - read_timeout_seconds: float | None = None, + read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, - elicitation_callback: ElicitationFnT | None = None, -) -> AsyncIterator[Client]: - """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" +) -> AsyncIterator[ClientSession]: + """Yield an initialized `ClientSession` over the server's legacy SSE transport, entirely in process.""" app, _ = build_sse_app(server) def httpx_client_factory( @@ -347,7 +357,7 @@ def httpx_client_factory( ) transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) - async with Client( + async with Client( # noqa: F821 -- body rewritten in H3 transport, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index c2ace45077..6319bc89b7 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -4,6 +4,13 @@ from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line( + "markers", "requirement(id): tag a test as covering an entry in tests/interaction/_requirements.py" + ) + + _FACTORIES: dict[str, Connect] = { "in-memory": connect_in_memory, "streamable-http": connect_over_streamable_http, From f7daf85f42da1f277213d3848793a24a69f77c26 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 13:56:17 +0000 Subject: [PATCH 05/19] =?UTF-8?q?backport:=20harness=20H3=20=E2=80=94=20co?= =?UTF-8?q?nnect=5Fin=5Fmemory,=20build=5Fsse=5Fapp,=20connect=5Fover=5Fss?= =?UTF-8?q?e,=20parse=5Fsse=5Fmessages,=20initialize=5Fbody=20bodies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/interaction/_connect.py | 58 ++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index f923931ef0..febee2bfaa 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -31,6 +31,7 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -94,8 +95,13 @@ async def connect_in_memory( message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, ) -> AsyncIterator[ClientSession]: - """Yield an initialized `ClientSession` connected to the server over the in-memory transport.""" - async with Client( # noqa: F821 -- body rewritten in H3 + """Yield an initialized `ClientSession` connected to the server over the in-memory transport. + + This is exactly `mcp.shared.memory.create_connected_server_and_client_session` — the + canonical v1 in-memory idiom — re-exported under the suite's `Connect` shape so the + transport matrix can parametrize over it. + """ + async with create_connected_server_and_client_session( server, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, @@ -104,8 +110,8 @@ async def connect_in_memory( message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, - ) as client: - yield client + ) as session: + yield session @asynccontextmanager @@ -229,7 +235,7 @@ async def client_via_http( def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" - return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] # noqa: F821 -- body rewritten in H3 + return [JSONRPCMessage.model_validate_json(event.data) for event in events if event.data] async def post_jsonrpc( @@ -268,9 +274,9 @@ def base_headers(*, session_id: str | None = None) -> dict[str, str]: def initialize_body(request_id: int = 1) -> dict[str, object]: """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" params = InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, + protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ClientCapabilities(), - client_info=Implementation(name="raw", version="0.0.0"), + clientInfo=Implementation(name="raw", version="0.0.0"), ) return JSONRPCRequest( jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) @@ -302,17 +308,15 @@ async def initialize_via_http(http: httpx.AsyncClient) -> str: def build_sse_app(server: Server[Any] | FastMCP) -> tuple[Starlette, SseServerTransport]: """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. - `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + `FastMCP.sse_app()` exists but does not expose the underlying `SseServerTransport`, which the SSE-specific tests need; building the app explicitly here gives both server flavours the same routing while keeping that handle. """ - sse = SseServerTransport( - "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) - ) - lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H3 + sse = SseServerTransport("/messages/", security_settings=NO_DNS_REBINDING_PROTECTION) + lowlevel = _lowlevel(server) async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): # type: ignore[reportPrivateUsage] await lowlevel.run(read, write, lowlevel.create_initialization_options()) return Response() @@ -356,15 +360,19 @@ def httpx_client_factory( auth=auth, ) - transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) - async with Client( # noqa: F821 -- body rewritten in H3 - transport, - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as client: - yield client + async with ( + sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) as (read, write), + ClientSession( + read, + write, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as session, + ): + await session.initialize() + yield session From 7a936a2cee44f4e2fb79a1f5f80a29e6b4cb4f6f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 14:05:00 +0000 Subject: [PATCH 06/19] =?UTF-8?q?backport:=20harness=20H4=20=E2=80=94=20bu?= =?UTF-8?q?ild=5Fstreamable=5Fhttp=5Fapp=20(no-auth),=20connect=5Fover=5Fs?= =?UTF-8?q?treamable=5Fhttp,=20client=5Fvia=5Fhttp;=20gate:=203=20smoke=20?= =?UTF-8?q?legs=20pass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also: - pyproject: add [tool.inline-snapshot] default-flags=["disable"] (matches main; without it -n0 runs use inline-snapshot active mode whose pydantic comparison mishandles extra="allow" models) - conftest: suppress PytestUnraisableExceptionWarning/ResourceWarning — v1 streamable-HTTP server transport leaks memory streams on teardown (e.g. _handle_get_request only closes sse_stream_reader on the exception path); fixes are src/-side on main, out of scope here - test_ping.py: convert to v1 decorator pattern using session-access pattern B (request_ctx contextvar) per the file→pattern assignment Gate: tests/interaction/lowlevel/test_ping.py 6/6 pass (both tests × 3 transport legs). --- pyproject.toml | 5 ++ tests/interaction/_connect.py | 89 ++++++++++++++++++++----- tests/interaction/conftest.py | 8 +++ tests/interaction/lowlevel/test_ping.py | 35 ++++++---- 4 files changed, 105 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0d424a2841..3b836470ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,11 @@ members = ["examples/clients/*", "examples/servers/*", "examples/snippets"] [tool.uv.sources] mcp = { workspace = true } +[tool.inline-snapshot] +# `snapshot(x)` becomes the identity (plain `==`); regenerate with `--inline-snapshot=fix`. +default-flags = ["disable"] +format-command = "ruff format --stdin-filename {filename}" + [tool.pytest.ini_options] log_cli = true xfail_strict = true diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index febee2bfaa..43057c10f8 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -27,6 +27,7 @@ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.server import StreamableHTTPASGIApp from mcp.server.sse import SseServerTransport from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -114,6 +115,51 @@ async def connect_in_memory( yield session +def build_streamable_http_app( + server: Server[Any] | FastMCP, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + auth: AuthSettings | None = None, + token_verifier: TokenVerifier | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, +) -> tuple[Starlette, StreamableHTTPSessionManager]: + """Assemble a streamable-HTTP Starlette app for either server flavour. + + v1's lowlevel `Server` has no `streamable_http_app()`; this follows + `FastMCP.streamable_http_app()` (`mcp/server/fastmcp/server.py`) so behaviour matches what a + v1 user would get from `FastMCP(..., **knobs).streamable_http_app()`. Returns the live + `StreamableHTTPSessionManager` alongside the app so the caller can enter `manager.run()` + (the in-process bridge does not drive Starlette lifespan) and so tests can reach + `manager._server_instances`. + + `/mcp` is mounted via `Route(path, endpoint=)` with no `methods=`, exactly + as FastMCP does — Starlette treats a class-instance endpoint as raw ASGI and matches all + verbs, which is what the transport requires. + """ + manager = StreamableHTTPSessionManager( + app=_lowlevel(server), + event_store=event_store, + json_response=json_response, + stateless=stateless_http, + security_settings=transport_security, + retry_interval=retry_interval, + ) + asgi = StreamableHTTPASGIApp(manager) + + routes: list[Route] = [] + # Auth routing (middleware, AS routes, RequireAuthMiddleware wrap, PRM routes) is added in + # H5; until then `auth` / `token_verifier` / `auth_server_provider` are accepted but ignored + # so callers that don't pass them work today. + assert auth is None and token_verifier is None and auth_server_provider is None, "auth branch lands in H5" + routes.append(Route("/mcp", endpoint=asgi)) + + return Starlette(routes=routes), manager + + @asynccontextmanager async def connect_over_streamable_http( server: Server[Any] | FastMCP, @@ -137,18 +183,20 @@ async def connect_over_streamable_http( server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so the client's reconnection wait is a no-op). """ - app = server.streamable_http_app( + app, manager = build_streamable_http_app( + server, stateless_http=stateless_http, json_response=json_response, event_store=event_store, retry_interval=retry_interval, - transport_security=NO_DNS_REBINDING_PROTECTION, ) async with ( - server.session_manager.run(), + manager.run(), httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, - Client( # noqa: F821 -- body rewritten in H4 - streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write, _get_session_id), + ClientSession( + read, + write, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, @@ -156,9 +204,10 @@ async def connect_over_streamable_http( message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, - ) as client, + ) as session, ): - yield client + await session.initialize() + yield session @asynccontextmanager @@ -219,18 +268,22 @@ async def client_via_http( ) -> AsyncIterator[ClientSession]: """Connect a `ClientSession` over an already-mounted streamable HTTP app. - Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a - client-driven assertion can sit alongside raw-httpx assertions in the same test. The - underlying `httpx.AsyncClient` is left open when the `Client` exits. + Use with `mounted_app(...)` so several `ClientSession`s share the one session manager, or + so a client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the session exits. """ - transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) - async with Client( # noqa: F821 -- body rewritten in H4 - transport, - logging_callback=logging_callback, - message_handler=message_handler, - elicitation_callback=elicitation_callback, - ) as client: - yield client + async with ( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write, _get_session_id), + ClientSession( + read, + write, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as session, + ): + await session.initialize() + yield session def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index 6319bc89b7..3f389cedba 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -9,6 +9,14 @@ def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line( "markers", "requirement(id): tag a test as covering an entry in tests/interaction/_requirements.py" ) + # v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown + # (e.g. `_handle_get_request` only closes `sse_stream_reader` on the exception path; the + # session manager's per-session task-group cancel can race the per-request cleanup). v1's own + # tests run the transport in a separate process and so never observe these `__del__`-time + # ResourceWarnings; running in-process via the streaming bridge does. The fixes live in `src/` + # on `main` and are out of scope for this tests-only backport, so suppress here. + config.addinivalue_line("filterwarnings", "ignore::pytest.PytestUnraisableExceptionWarning") + config.addinivalue_line("filterwarnings", "ignore::ResourceWarning") _FACTORIES: dict[str, Connect] = { diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py index ed1f466809..c89b603043 100644 --- a/tests/interaction/lowlevel/test_ping.py +++ b/tests/interaction/lowlevel/test_ping.py @@ -1,10 +1,20 @@ -"""Ping interactions against the low-level Server, driven through the public Client API.""" +"""Ping interactions against the low-level Server, driven through the public ClientSession API. + +This file reaches the server session via the module-level `request_ctx` contextvar (pattern B +from the v1 backport's session-access spread). That contextvar is the mechanism behind +`Server.request_context`; reading it directly is a public-module-level name a v1 user can +import, and exercising it here covers the contextvar path the eventual v2 compatibility shims +must preserve. +""" + +from typing import Any import pytest from inline_snapshot import snapshot from mcp import types -from mcp.server import Server, ServerRequestContext +from mcp.server import Server +from mcp.server.lowlevel.server import request_ctx from mcp.types import CallToolResult, EmptyResult, TextContent from tests.interaction._connect import Connect from tests.interaction._requirements import requirement @@ -16,7 +26,7 @@ @requirement("ping:client-to-server") async def test_client_ping_returns_empty_result(connect: Connect) -> None: """A client ping is answered with an empty result, even by a server with no handlers.""" - server = Server("silent") + server: Server[None] = Server("silent") async with connect(server) as client: result = await client.send_ping() @@ -32,21 +42,18 @@ async def test_server_ping_returns_empty_result(connect: Connect) -> None: The tool returns the type of the ping response, proving the round trip completed inside the handler before the tool result was produced. """ + server: Server[None] = Server("pinger") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="ping_back", description="Ping the client.", inputSchema={"type": "object"})] - ) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ping_back", description="Ping the client.", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ping_back" - pong = await ctx.session.send_ping() + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "ping_back" + pong = await request_ctx.get().session.send_ping() return CallToolResult(content=[TextContent(type="text", text=type(pong).__name__)]) - server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("ping_back", {}) From e55b40ecf50932e576c1e25592bf9160df962ef1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 14:08:45 +0000 Subject: [PATCH 07/19] =?UTF-8?q?backport:=20harness=20H5=20=E2=80=94=20bu?= =?UTF-8?q?ild=5Fstreamable=5Fhttp=5Fapp=20auth=20branch=20+=20mounted=5Fa?= =?UTF-8?q?pp=20body?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors FastMCP.streamable_http_app()'s auth gating exactly (verifier derivation, middleware, AS routes, RequireAuthMiddleware wrap, PRM routes); mounted_app now calls the public builder. _connect.py is feature-complete; gate still 6/6. --- tests/interaction/_connect.py | 75 +++++++++++++++++++++++++++++------ 1 file changed, 62 insertions(+), 13 deletions(-) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 43057c10f8..0bcf6fb416 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -16,6 +16,8 @@ import httpx from httpx_sse import ServerSentEvent, aconnect_sse from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route @@ -24,7 +26,10 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier +from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.server import StreamableHTTPASGIApp @@ -139,6 +144,9 @@ def build_streamable_http_app( `/mcp` is mounted via `Route(path, endpoint=)` with no `methods=`, exactly as FastMCP does — Starlette treats a class-instance endpoint as raw ASGI and matches all verbs, which is what the transport requires. + + Unlike `FastMCP.__init__`, this does not enforce `auth_server_provider` XOR + `token_verifier`; the AS-handler tests pass both. """ manager = StreamableHTTPSessionManager( app=_lowlevel(server), @@ -150,14 +158,55 @@ def build_streamable_http_app( ) asgi = StreamableHTTPASGIApp(manager) + # FastMCP derives a verifier from the provider at construction time when no explicit verifier + # is given (mcp/server/fastmcp/server.py:230); the harness has no construction step, so the + # same derivation runs here so the gating below sees the same verifier FastMCP would. + verifier = token_verifier + if auth_server_provider is not None and token_verifier is None: + verifier = ProviderTokenVerifier(auth_server_provider) + routes: list[Route] = [] - # Auth routing (middleware, AS routes, RequireAuthMiddleware wrap, PRM routes) is added in - # H5; until then `auth` / `token_verifier` / `auth_server_provider` are accepted but ignored - # so callers that don't pass them work today. - assert auth is None and token_verifier is None and auth_server_provider is None, "auth branch lands in H5" - routes.append(Route("/mcp", endpoint=asgi)) + middleware: list[Middleware] = [] + required_scopes: list[str] = [] + + if auth is not None: + required_scopes = auth.required_scopes or [] + if verifier is not None: + middleware = [ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(verifier)), + Middleware(AuthContextMiddleware), + ] + if auth_server_provider is not None: + routes.extend( + create_auth_routes( + provider=auth_server_provider, + issuer_url=auth.issuer_url, + service_documentation_url=auth.service_documentation_url, + client_registration_options=auth.client_registration_options, + revocation_options=auth.revocation_options, + ) + ) + + if verifier is not None: + resource_metadata_url = ( + build_resource_metadata_url(auth.resource_server_url) + if auth is not None and auth.resource_server_url + else None + ) + routes.append(Route("/mcp", endpoint=RequireAuthMiddleware(asgi, required_scopes, resource_metadata_url))) + else: + routes.append(Route("/mcp", endpoint=asgi)) + + if auth is not None and auth.resource_server_url: + routes.extend( + create_protected_resource_routes( + resource_url=auth.resource_server_url, + authorization_servers=[auth.issuer_url], + scopes_supported=auth.required_scopes, + ) + ) - return Starlette(routes=routes), manager + return Starlette(routes=routes, middleware=middleware), manager @asynccontextmanager @@ -230,15 +279,15 @@ async def mounted_app( Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test speaks HTTP through the yielded client directly; for client-driven assertions the test wraps - that client in `client_via_http(http)`, which lets several `Client`s share the one mounted - session manager. `on_request` records every outgoing HTTP request before it leaves the + that client in `client_via_http(http)`, which lets several `ClientSession`s share the one + mounted session manager. `on_request` records every outgoing HTTP request before it leaves the yielded client. DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the localhost auto-enable behaviour) to test the protection itself. """ - lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server # noqa: F821 -- body rewritten in H5 - app = lowlevel.streamable_http_app( + app, manager = build_streamable_http_app( + server, stateless_http=stateless_http, json_response=json_response, event_store=event_store, @@ -250,12 +299,12 @@ async def mounted_app( ) event_hooks = {"request": [on_request]} if on_request is not None else None async with ( - server.session_manager.run(), + manager.run(), httpx.AsyncClient( transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers ) as http_client, ): - yield http_client, server.session_manager + yield http_client, manager @asynccontextmanager From 7d9881d7c8aa69b403368ef89b87d9fc7901a42f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 14:13:21 +0000 Subject: [PATCH 08/19] =?UTF-8?q?backport:=20harness=20H-auth-1=20?= =?UTF-8?q?=E2=80=94=20connect=5Fwith=5Foauth=20body=20(hand-assembled=20A?= =?UTF-8?q?S+RS=20app,=20ClientSession=20yield);=20auth=20smoke=20passes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/interaction/auth/_harness.py | 95 ++++++++++++++++++++++------- tests/interaction/auth/test_flow.py | 26 +++++--- 2 files changed, 90 insertions(+), 31 deletions(-) diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index d013364f33..443f2f3f9d 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -1,12 +1,13 @@ """In-process harness for the auth interaction tests. Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the -bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=..., -token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge -on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the -authorize redirect headlessly by GETing the URL through the same bridge and parsing the code -from the 302 `Location`. The whole authorization-code flow runs in one event loop with no -sockets, no threads, and no real time. +bearer-gated MCP endpoint on one Starlette app assembled from the same public pieces +`FastMCP.streamable_http_app()` uses (`StreamableHTTPSessionManager`, `create_auth_routes`, +`BearerAuthBackend`, `RequireAuthMiddleware`, `create_protected_resource_routes`), drives +that app through the streaming bridge on a single `httpx.AsyncClient` carrying +`auth=OAuthClientProvider(...)`, and completes the authorize redirect headlessly by GETing the +URL through the same bridge and parsing the code from the 302 `Location`. The whole +authorization-code flow runs in one event loop with no sockets, no threads, and no real time. """ import json @@ -18,14 +19,23 @@ import httpx from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Route from starlette.types import ASGIApp, Receive, Scope, Send from mcp.client.auth import OAuthClientProvider -from mcp.client.client import Client +from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client from mcp.server import Server +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier +from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions +from mcp.server.fastmcp.server import StreamableHTTPASGIApp +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider @@ -385,7 +395,7 @@ async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: @asynccontextmanager async def connect_with_oauth( - server: Server, + server: Server[Any], *, provider: InMemoryAuthorizationServerProvider, settings: AuthSettings | None = None, @@ -397,12 +407,19 @@ async def connect_with_oauth( verify_tokens: bool = True, app_shim: Callable[[ASGIApp], ASGIApp] | None = None, on_request: Callable[[httpx.Request], None] | None = None, -) -> AsyncIterator[tuple[Client, HeadlessOAuth]]: - """Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. +) -> AsyncIterator[tuple[ClientSession, HeadlessOAuth]]: + """Connect a `ClientSession` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. - Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the - SDK put on the authorize request. `on_request` records every HTTP request the underlying - `httpx.AsyncClient` issues, including those yielded from inside the auth flow. + Yields the connected, initialized `ClientSession` and the `HeadlessOAuth` whose + `authorize_url` records what the SDK put on the authorize request. `on_request` records + every HTTP request the underlying `httpx.AsyncClient` issues, including those yielded from + inside the auth flow. + + The Starlette app is assembled from the same public pieces `FastMCP.streamable_http_app()` + uses, so behaviour matches what a v1 user would get from a `FastMCP` configured with + `auth_server_provider=` — except that hand-assembly lets `verify_tokens=False` mount `/mcp` + ungated while still mounting the authorization-server and PRM routes (FastMCP's constructor + auto-derives a token verifier from the provider, so it has no ungated combination). `headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour (state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without @@ -433,12 +450,44 @@ async def connect_with_oauth( ) ) - app: ASGIApp = server.streamable_http_app( - auth=settings, - token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None, - auth_server_provider=provider, - transport_security=NO_DNS_REBINDING_PROTECTION, + manager = StreamableHTTPSessionManager(app=server, security_settings=NO_DNS_REBINDING_PROTECTION) + asgi = StreamableHTTPASGIApp(manager) + + routes: list[Route] = list( + create_auth_routes( + provider=provider, + issuer_url=settings.issuer_url, + service_documentation_url=settings.service_documentation_url, + client_registration_options=settings.client_registration_options, + revocation_options=settings.revocation_options, + ) + ) + middleware: list[Middleware] = [] + required_scopes = settings.required_scopes or [] + resource_metadata_url = ( + build_resource_metadata_url(settings.resource_server_url) if settings.resource_server_url else None ) + + if verify_tokens: + token_verifier = ProviderTokenVerifier(provider) + middleware = [ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)), + Middleware(AuthContextMiddleware), + ] + routes.append(Route("/mcp", endpoint=RequireAuthMiddleware(asgi, required_scopes, resource_metadata_url))) + else: + routes.append(Route("/mcp", endpoint=asgi)) + + if settings.resource_server_url: + routes.extend( + create_protected_resource_routes( + resource_url=settings.resource_server_url, + authorization_servers=[settings.issuer_url], + scopes_supported=required_scopes, + ) + ) + + app: ASGIApp = Starlette(routes=routes, middleware=middleware) if app_shim is not None: app = app_shim(app) @@ -452,14 +501,16 @@ async def hook(request: httpx.Request) -> None: event_hooks = {"request": [hook]} async with AsyncExitStack() as stack: - await stack.enter_async_context(server.session_manager.run()) + await stack.enter_async_context(manager.run()) http_client = await stack.enter_async_context( httpx.AsyncClient( transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks ) ) headless.bind(http_client) - client = await stack.enter_async_context( - Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + read, write, _get_session_id = await stack.enter_async_context( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) ) - yield client, headless + session = await stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + yield session, headless diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py index 4c041cc112..c933d96ea2 100644 --- a/tests/interaction/auth/test_flow.py +++ b/tests/interaction/auth/test_flow.py @@ -19,7 +19,7 @@ from pydantic import AnyUrl from mcp import types -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.server.auth.middleware.auth_context import get_access_token from mcp.shared.auth import OAuthClientInformationFull from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool @@ -39,8 +39,15 @@ pytestmark = pytest.mark.anyio -async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="whoami", inputSchema={"type": "object"})]) +def _guarded_server() -> Server[object]: + """Build a lowlevel server exposing a single `whoami` tool, in the v1 decorator style.""" + server: Server[object] = Server("guarded") + + @server.list_tools() + async def _list_tools() -> list[types.Tool]: + return [Tool(name="whoami", inputSchema={"type": "object"})] + + return server @requirement("flow:oauth:authorization-code-roundtrip") @@ -67,7 +74,7 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow requests: list[httpx.Request] = [] provider = InMemoryAuthorizationServerProvider() storage = InMemoryTokenStorage() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() with anyio.fail_after(5): async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( @@ -121,14 +128,15 @@ async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow @requirement("hosting:auth:authinfo-propagates") async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: """A tool handler reads the request's access token through `get_access_token()`.""" + server = _guarded_server() - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "whoami" + @server.call_tool() + async def _call_tool(name: str, arguments: dict[str, object]) -> CallToolResult: + assert name == "whoami" token = get_access_token() assert token is not None return CallToolResult(content=[TextContent(type="text", text=" ".join(token.scopes))]) - server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) provider = InMemoryAuthorizationServerProvider() with anyio.fail_after(5): @@ -148,7 +156,7 @@ async def test_a_preregistered_client_skips_registration() -> None: requests: list[httpx.Request] = [] provider = InMemoryAuthorizationServerProvider() storage = InMemoryTokenStorage() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() client_info = OAuthClientInformationFull( client_id="preregistered", @@ -183,7 +191,7 @@ async def test_the_dcr_request_carries_the_client_metadata() -> None: requests: list[httpx.Request] = [] provider = InMemoryAuthorizationServerProvider() storage = InMemoryTokenStorage() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() client_metadata = oauth_client_metadata() client_metadata.software_id = "interaction-test-suite" From b603626d77ac6460770a311d9cd5b83231a76e2e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 14:19:05 +0000 Subject: [PATCH 09/19] =?UTF-8?q?backport:=20phase-4=20wave=201=20(proof)?= =?UTF-8?q?=20=E2=80=94=2025/25=20pass=20across=20lowlevel/mcpserver/trans?= =?UTF-8?q?ports/auth?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - transports/test_bridge (4, no edits), test_stdio (2 + _stdio_server rewrite) - lowlevel/test_completion (5), test_logging (3), test_tools b1/3 (5) - mcpserver/test_completion (1) - auth/test_flow (5) - _requirements.py: tools:call:unknown-name + protocol:error:internal-error divergences updated for v1 --- tests/interaction/_requirements.py | 15 +- tests/interaction/auth/test_flow.py | 8 +- tests/interaction/lowlevel/test_completion.py | 92 ++++++----- tests/interaction/lowlevel/test_logging.py | 57 ++++--- tests/interaction/lowlevel/test_tools.py | 149 +++++++++--------- .../interaction/mcpserver/test_completion.py | 20 ++- tests/interaction/transports/_stdio_server.py | 63 ++++---- tests/interaction/transports/test_stdio.py | 48 +++--- 8 files changed, 241 insertions(+), 211 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 109b30fc77..92e024483e 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -349,8 +349,11 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " - "leaks str(exc) as the error message." + "For tools/call the lowlevel @server.call_tool() decorator wraps the handler in a broad " + "try/except that converts every Exception to CallToolResult(isError=True, " + "content=[TextContent(text=str(exc))]), so the dispatcher's JSON-RPC error path is never " + "reached for tool calls and the test pins the isError=True result. For other request " + "handlers the dispatcher returns code 0 (not -32603) with str(exc) as the message." ), ), ), @@ -559,6 +562,14 @@ def __post_init__(self) -> None: "tools:call:unknown-name": Requirement( source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The lowlevel @server.call_tool() decorator catches every handler exception (including " + "McpError) and converts it to CallToolResult(isError=True, content=[TextContent(text=str(exc))]), " + "so a handler cannot produce a protocol-level JSON-RPC error for tools/call; the test pins " + "the isError=True result instead." + ), + ), ), "tools:capability:declared": Requirement( source=f"{SPEC_BASE_URL}/server/tools#capabilities", diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py index c933d96ea2..6504bc22c0 100644 --- a/tests/interaction/auth/test_flow.py +++ b/tests/interaction/auth/test_flow.py @@ -23,7 +23,7 @@ from mcp.server.auth.middleware.auth_context import get_access_token from mcp.shared.auth import OAuthClientInformationFull from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool -from tests.interaction._connect import BASE_URL +from tests.interaction._connect import BASE_URL, build_streamable_http_app from tests.interaction._requirements import requirement from tests.interaction.auth._harness import ( REDIRECT_URI, @@ -229,11 +229,11 @@ async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_w own routing; provided here so the discovery tests can rely on the shim without each adding their own contract test. """ - server = Server("bare") + server: Server[object] = Server("bare") provider = InMemoryAuthorizationServerProvider() - real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider) + real_app, manager = build_streamable_http_app(server, auth=auth_settings(), auth_server_provider=provider) app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'}) - async with server.session_manager.run(): + async with manager.run(): async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: served = await http.get("/override") assert served.status_code == 200 diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py index 42b059f742..cd4058b370 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -1,15 +1,17 @@ -"""Completion interactions against the low-level Server, driven through the public Client API.""" +"""Completion interactions against the low-level Server, driven through the public client API.""" import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.server import Server, ServerRequestContext +from mcp import McpError +from mcp.server.lowlevel import Server from mcp.types import ( INVALID_PARAMS, METHOD_NOT_FOUND, CompleteResult, Completion, + CompletionArgument, + CompletionContext, ErrorData, PromptReference, ResourceTemplateReference, @@ -27,16 +29,20 @@ async def test_complete_prompt_argument(connect: Connect) -> None: The returned values are filtered by the argument's value, proving the value reached the handler. """ - - async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: - assert isinstance(params.ref, PromptReference) - assert params.ref.name == "code_review" - assert params.argument.name == "language" + server = Server("completer") + + @server.completion() + async def completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + assert isinstance(ref, PromptReference) + assert ref.name == "code_review" + assert argument.name == "language" candidates = ["python", "pytorch", "ruby"] - matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] - return CompleteResult(completion=Completion(values=matches, total=len(matches), hasMore=False)) - - server = Server("completer", on_completion=completion) + matches = [candidate for candidate in candidates if candidate.startswith(argument.value)] + return Completion(values=matches, total=len(matches), hasMore=False) async with connect(server) as client: result = await client.complete( @@ -51,14 +57,18 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar @requirement("completion:resource-template-arg") async def test_complete_resource_template_variable(connect: Connect) -> None: """Completing a URI template variable delivers the template URI and variable name to the handler.""" - - async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: - assert isinstance(params.ref, ResourceTemplateReference) - assert params.ref.uri == "github://repos/{owner}/{repo}" - assert params.argument.name == "owner" - return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"])) - - server = Server("completer", on_completion=completion) + server = Server("completer") + + @server.completion() + async def completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + assert isinstance(ref, ResourceTemplateReference) + assert ref.uri == "github://repos/{owner}/{repo}" + assert argument.name == "owner" + return Completion(values=[f"{argument.value}contextprotocol"]) async with connect(server) as client: result = await client.complete( @@ -75,14 +85,18 @@ async def test_complete_receives_context_arguments(connect: Connect) -> None: The returned value is derived from the context, proving it arrived. """ - - async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: - assert params.argument.name == "repo" - assert params.context is not None - assert params.context.arguments is not None - return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"])) - - server = Server("completer", on_completion=completion) + server = Server("completer") + + @server.completion() + async def completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + assert argument.name == "repo" + assert context is not None + assert context.arguments is not None + return Completion(values=[f"{context.arguments['owner']}/python-sdk"]) async with connect(server) as client: result = await client.complete( @@ -102,15 +116,19 @@ async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended way to do it. """ + server = Server("completer") - async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: - assert isinstance(params.ref, PromptReference) - raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}") - - server = Server("completer", on_completion=completion) + @server.completion() + async def completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + assert isinstance(ref, PromptReference) + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown prompt: {ref.name!r}")) async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.complete(PromptReference(type="ref/prompt", name="ghost"), argument={"name": "x", "value": ""}) assert exc_info.value.error.code == INVALID_PARAMS @@ -123,9 +141,11 @@ async def test_complete_without_handler_is_method_not_found(connect: Connect) -> server = Server("incomplete") async with connect(server) as client: - assert client.initialize_result.capabilities.completions is None + capabilities = client.get_server_capabilities() + assert capabilities is not None + assert capabilities.completions is None - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.complete( PromptReference(type="ref/prompt", name="anything"), argument={"name": "topic", "value": ""} ) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index 070ee49c82..13ed6d3a87 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -9,11 +9,13 @@ assert after the request completes on every transport leg -- no events, no waiting. """ +from typing import Any + import pytest from inline_snapshot import snapshot from mcp import types -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent from tests.interaction._connect import Connect from tests.interaction._requirements import requirement @@ -35,12 +37,11 @@ @requirement("logging:set-level") async def test_set_logging_level_reaches_handler(connect: Connect) -> None: """The level requested by the client is delivered to the server's handler verbatim.""" + server = Server("logger") - async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: - assert params.level == "warning" - return EmptyResult() - - server = Server("logger", on_set_logging_level=set_logging_level) + @server.set_logging_level() + async def set_logging_level(level: types.LoggingLevel) -> None: + assert level == "warning" async with connect(server) as client: result = await client.set_logging_level("warning") @@ -61,27 +62,29 @@ async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="chatty", inputSchema={"type": "object"})]) + server = Server("logger") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="chatty", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "chatty" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "chatty" + ctx = server.request_context await ctx.session.send_log_message( level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id ) await ctx.session.send_log_message( level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id ) - return CallToolResult(content=[TextContent(type="text", text="done")]) + return [TextContent(type="text", text="done")] - async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + @server.set_logging_level() + async def set_logging_level(level: types.LoggingLevel) -> None: """Registered so the logging capability is advertised; the client never sets a level.""" raise NotImplementedError - server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) - async with connect(server, logging_callback=collect) as client: result = await client.call_tool("chatty", {}) @@ -102,25 +105,27 @@ async def test_log_messages_at_every_severity_level(connect: Connect) -> None: async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="siren", inputSchema={"type": "object"})]) + server = Server("logger") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="siren", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "siren" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "siren" + ctx = server.request_context for level in ALL_LEVELS: await ctx.session.send_log_message( level=level, data=f"a {level} message", related_request_id=ctx.request_id ) - return CallToolResult(content=[TextContent(type="text", text="logged")]) + return [TextContent(type="text", text="logged")] - async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + @server.set_logging_level() + async def set_logging_level(level: types.LoggingLevel) -> None: """Registered so the logging capability is advertised; the client never sets a level.""" raise NotImplementedError - server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) - async with connect(server, logging_callback=collect) as client: await client.call_tool("siren", {}) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 25eee750cd..1ac8aa5812 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -1,11 +1,13 @@ """Tool interactions against the low-level Server, driven through the public Client API.""" +from typing import Any + import anyio import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.server.lowlevel import Server from mcp.types import ( INVALID_PARAMS, AudioContent, @@ -30,22 +32,16 @@ @requirement("tools:call:content:text") async def test_call_tool_returns_text_content(connect: Connect) -> None: """Arguments reach the tool handler; its content comes back as the call result.""" + server = Server("adder") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="add", description="Add two integers.", inputSchema={"type": "object"})] - ) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="add", description="Add two integers.", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "add" - assert params.arguments is not None - return CallToolResult( - content=[TextContent(type="text", text=str(params.arguments["a"] + params.arguments["b"]))] - ) - - server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "add" + return CallToolResult(content=[TextContent(type="text", text=str(arguments["a"] + arguments["b"]))]) async with connect(server) as client: result = await client.call_tool("add", {"a": 2, "b": 3}) @@ -55,18 +51,18 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("tools:call:is-error") async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: - """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + """A tool reporting its own failure with isError=True reaches the client as a result, not an exception. Tool execution errors are part of the result so the caller (typically a model) can see them; only protocol-level failures become JSON-RPC errors. """ + server = Server("errors") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "flux" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "flux" return CallToolResult(content=[TextContent(type="text", text="the flux capacitor is offline")], isError=True) - server = Server("errors", on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("flux", {}) @@ -77,67 +73,68 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("tools:call:unknown-name") async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: - """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + """A handler that rejects an unrecognised tool name with McpError is swallowed into an isError result. - The error's code, message, and data chosen by the handler reach the client verbatim. + On v1 the lowlevel `@server.call_tool()` decorator catches every handler exception (including + `McpError`) and converts it to `CallToolResult(isError=True, content=[TextContent(text=str(exc))])`, + so the handler cannot produce a protocol-level JSON-RPC error for this method. See the + divergence note on the requirement. """ + server = Server("errors") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) - - server = Server("errors", on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown tool: {name}", data={"requested": name})) async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: - await client.call_tool("nope", {}) + result = await client.call_tool("nope", {}) - assert exc_info.value.error == snapshot( - ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + assert result == snapshot( + CallToolResult(content=[TextContent(type="text", text="Unknown tool: nope")], isError=True) ) @requirement("protocol:error:internal-error") async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: - """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + """An uncaught exception in a tool handler is swallowed into an isError=True result. - The low-level server reports it with code 0 and the exception text as the message; see the - divergence note on the requirement. + On v1 the lowlevel `@server.call_tool()` decorator wraps the handler in a broad try/except + that converts every `Exception` to `CallToolResult(isError=True, content=[TextContent(text=str(exc))])`, + so the dispatcher's JSON-RPC error path is never reached for tool calls. See the divergence + note on the requirement. """ + server = Server("errors") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "explode" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "explode" raise ValueError("boom") - server = Server("errors", on_call_tool=call_tool) - async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: - await client.call_tool("explode", {}) + result = await client.call_tool("explode", {}) - assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="boom")], isError=True)) @requirement("tools:list:basic") async def test_list_tools_returns_registered_tools(connect: Connect) -> None: """The tools advertised by the server's list handler arrive at the client unchanged.""" - - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="add", - description="Add two integers.", - inputSchema={ - "type": "object", - "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, - "required": ["a", "b"], - }, - ), - Tool(name="reset", description="Reset the calculator.", inputSchema={"type": "object"}), - ] - ) - - server = Server("calculator", on_list_tools=list_tools) + server = Server("calculator") + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="add", + description="Add two integers.", + inputSchema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", inputSchema={"type": "object"}), + ] async with connect(server) as client: result = await client.list_tools() @@ -188,10 +185,10 @@ async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Con "additionalProperties": False, } - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult(tools=[Tool(name="typed", inputSchema=schema)]) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "typed" assert params.arguments == {"count": 3, "options": {"verbose": True}} return CallToolResult(content=[TextContent(type="text", text="ok")]) @@ -221,7 +218,7 @@ async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: _meta={"example.com/source": "interaction-suite"}, ) - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult(tools=[tool]) server = Server("annotated", on_list_tools=list_tools) @@ -259,10 +256,10 @@ async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: snapshot pins the exact bytes the client receives. """ - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult(tools=[Tool(name="render", inputSchema={"type": "object"})]) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "render" return CallToolResult( content=[ @@ -306,10 +303,10 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_call_tool_structured_content(connect: Connect) -> None: """A tool result carrying structured content alongside content delivers both to the client.""" - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult(tools=[Tool(name="sum", inputSchema={"type": "object"})]) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "sum" return CallToolResult(content=[TextContent(type="text", text="the sum is 5")], structuredContent={"sum": 5}) @@ -336,10 +333,10 @@ async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> release = anyio.Event() results: dict[str, CallToolResult] = {} - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "echo" assert params.arguments is not None tag = params.arguments["tag"] @@ -381,7 +378,7 @@ async def test_call_tool_structured_content_violating_output_schema_is_rejected_ reaches the caller: the client validates it against the schema cached from tools/list and raises. """ - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult( tools=[ Tool( @@ -396,7 +393,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ] ) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "forecast" return CallToolResult( content=[TextContent(type="text", text="warm")], structuredContent={"temperature": "warm"} @@ -421,7 +418,7 @@ async def test_is_error_result_bypasses_client_output_schema_validation(connect: isError flag and not an empty cache. """ - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult( tools=[ Tool( @@ -436,7 +433,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ] ) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "forecast" return CallToolResult( content=[TextContent(type="text", text="boom")], structuredContent={"temperature": "warm"}, isError=True @@ -462,7 +459,7 @@ async def test_declared_output_schema_with_no_structured_content_is_rejected_by_ The error is the SDK's own message, so the full text is snapshotted. """ - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body return ListToolsResult( tools=[ Tool( @@ -473,7 +470,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ] ) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "forecast" return CallToolResult(content=[TextContent(type="text", text="warm")]) @@ -498,7 +495,7 @@ async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools """ list_calls: list[str] = [] - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body list_calls.append("called") return ListToolsResult( tools=[ @@ -510,7 +507,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ] ) - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body assert params.name == "forecast" return CallToolResult(content=[TextContent(type="text", text="21 C")], structuredContent={"temperature": 21}) diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py index 7761066e94..1499fd9b4f 100644 --- a/tests/interaction/mcpserver/test_completion.py +++ b/tests/interaction/mcpserver/test_completion.py @@ -1,8 +1,8 @@ -"""Completion behaviour against MCPServer, driven through the public Client API.""" +"""Completion behaviour against FastMCP, driven through the public client API.""" import pytest -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP from mcp.types import ( Completion, CompletionArgument, @@ -19,8 +19,8 @@ @requirement("mcpserver:completion:capability-auto") async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: - """An MCPServer with a registered completion handler advertises the completions capability; one without does not.""" - with_handler = MCPServer("completer") + """A FastMCP with a registered completion handler advertises the completions capability; one without does not.""" + with_handler = FastMCP("completer") @with_handler.completion() async def complete( @@ -32,7 +32,11 @@ async def complete( raise NotImplementedError async with connect(with_handler) as client: - assert client.initialize_result.capabilities.completions == CompletionsCapability() - - async with connect(MCPServer("plain")) as client: - assert client.initialize_result.capabilities.completions is None + capabilities = client.get_server_capabilities() + assert capabilities is not None + assert capabilities.completions == CompletionsCapability() + + async with connect(FastMCP("plain")) as client: + capabilities = client.get_server_capabilities() + assert capabilities is not None + assert capabilities.completions is None diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py index 633b43f340..f0d332f3d9 100644 --- a/tests/interaction/transports/_stdio_server.py +++ b/tests/interaction/transports/_stdio_server.py @@ -7,48 +7,39 @@ """ import sys +from typing import Any import anyio -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.server.stdio import stdio_server -from mcp.types import ( - CallToolRequestParams, - CallToolResult, - EmptyResult, - ListToolsResult, - PaginatedRequestParams, - SetLevelRequestParams, - TextContent, - Tool, -) - - -async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="echo", - inputSchema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, - ) - ] - ) - - -async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - assert params.name == "echo" - assert params.arguments is not None - text = params.arguments["text"] - await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") - return CallToolResult(content=[TextContent(type="text", text=text)]) - - -async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: - """Registered so the logging capability is advertised; the client never sets a level.""" - raise NotImplementedError +from mcp.types import LoggingLevel, TextContent, Tool + +server = Server("stdio-echo") + + +@server.list_tools() +async def list_tools() -> list[Tool]: + return [ + Tool( + name="echo", + inputSchema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] -server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) +@server.call_tool() +async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + assert name == "echo" + text = arguments["text"] + await server.request_context.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return [TextContent(type="text", text=text)] + + +@server.set_logging_level() +async def set_logging_level(level: LoggingLevel) -> None: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError async def main() -> None: diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index d805e64933..279fd40090 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -26,19 +26,19 @@ import pytest from inline_snapshot import snapshot -from mcp.client.client import Client +from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage from mcp.types import ( CallToolResult, + JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, LoggingMessageNotificationParams, TextContent, ) -from mcp.types.jsonrpc import jsonrpc_message_adapter from tests.interaction._connect import initialize_body from tests.interaction._requirements import requirement from tests.interaction.transports import _stdio_server @@ -52,7 +52,7 @@ @requirement("transport:stdio:clean-shutdown") @requirement("transport:stdio:stderr-passthrough") async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: - """A Client connected over stdio initializes, calls a tool with arguments, receives the + """A ClientSession connected over stdio initializes, calls a tool with arguments, receives the server's log notification before the call returns, and the server exits when the transport closes its stdin.""" received: list[LoggingMessageNotificationParams] = [] @@ -61,22 +61,24 @@ async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) with tempfile.TemporaryFile(mode="w+") as errlog: - transport = stdio_client( - StdioServerParameters( - command=sys.executable, - args=["-m", _stdio_server.__name__], - cwd=str(_REPO_ROOT), - # stdio_client deliberately filters the inherited environment to a safe minimum, - # which drops the variables coverage.py's subprocess support uses; pass them through - # so the server module is measured. Empty when not running under coverage. - env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, - ), - errlog=errlog, - ) - with anyio.fail_after(10): - async with Client(transport, logging_callback=collect) as client: - assert client.initialize_result.server_info.name == "stdio-echo" + async with ( + stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + ), + errlog=errlog, + ) as (read, write), + ClientSession(read, write, logging_callback=collect) as client, + ): + initialize_result = await client.initialize() + assert initialize_result.serverInfo.name == "stdio-echo" result = await client.call_tool("echo", {"text": "across\nprocesses"}) errlog.seek(0) @@ -121,21 +123,21 @@ async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: ): received = await read_stream.receive() assert isinstance(received, SessionMessage) - assert isinstance(received.message, JSONRPCRequest) - assert received.message.method == "initialize" + assert isinstance(received.message.root, JSONRPCRequest) + assert received.message.root.method == "initialize" response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) notification = JSONRPCNotification( jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} ) - await write_stream.send(SessionMessage(response)) - await write_stream.send(SessionMessage(notification)) + await write_stream.send(SessionMessage(JSONRPCMessage(response))) + await write_stream.send(SessionMessage(JSONRPCMessage(notification))) output = captured.getvalue() assert output.endswith("\n") lines = output.removesuffix("\n").split("\n") assert len(lines) == 2 - messages = [jsonrpc_message_adapter.validate_json(line) for line in lines] + messages = [JSONRPCMessage.model_validate_json(line).root for line in lines] assert [type(message).__name__ for message in messages] == snapshot(["JSONRPCResponse", "JSONRPCNotification"]) # The newline inside the payload is JSON-escaped on the wire, not a literal newline that would # break the one-message-per-line framing. From 1211e79b1b18d184fa0dcc6ae0409d19728cb329 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 14:34:15 +0000 Subject: [PATCH 10/19] =?UTF-8?q?backport:=20phase-4=20wave=202=20+=20S1?= =?UTF-8?q?=20=E2=80=94=2051=20pass=20/=201=20deferred?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - lowlevel/{cancellation,flows,meta,progress,prompts,tools} complete - mcpserver/test_prompts complete - transports/{sse,streamable_http} complete (1 deferred: stateless-mode server-initiated guard absent in v1) - _requirements.py: 3 divergence updates (tools:call:unknown-name already in w1; client:output-schema:auto-list) --- tests/interaction/_requirements.py | 8 +- .../interaction/lowlevel/test_cancellation.py | 123 +++++----- tests/interaction/lowlevel/test_flows.py | 110 ++++----- tests/interaction/lowlevel/test_meta.py | 49 ++-- tests/interaction/lowlevel/test_progress.py | 144 ++++++------ tests/interaction/lowlevel/test_prompts.py | 88 ++++--- tests/interaction/lowlevel/test_tools.py | 216 +++++++++--------- tests/interaction/mcpserver/test_prompts.py | 28 +-- tests/interaction/transports/test_sse.py | 28 ++- .../transports/test_streamable_http.py | 48 ++-- 10 files changed, 443 insertions(+), 399 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 92e024483e..312cdead56 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -646,8 +646,12 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "Design concern rather than spec violation: the implicit request is invisible to the " - "caller, and against a server that registers only on_call_tool a successful call surfaces " - "as METHOD_NOT_FOUND from a tools/list the caller never asked for." + "caller, and against a server that registers only a call_tool handler a successful call " + "surfaces as METHOD_NOT_FOUND from a tools/list the caller never asked for. On v1 the " + "lowlevel server's call_tool decorator also primes its own tool cache by invoking the " + "registered list_tools handler internally on a miss, so the first uncached call observes " + "the handler running twice (server-side priming + the client's implicit request); the test " + "pins that count." ), ), ), diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 9ba6797e0a..4696baa951 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -1,26 +1,32 @@ -"""Cancellation interactions against the low-level Server, driven through the public Client API. +"""Cancellation interactions against the low-level Server, driven through the public client API. There is no client-side cancellation API: cancelling means sending a CancelledNotification -carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so -these tests capture the id from inside the blocked handler before cancelling. The handler blocks -on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. +carrying the request id, which only the server-side handler can observe (via +`server.request_context.request_id`), so these tests capture the id from inside the blocked +handler before cancelling. The handler blocks on an Event rather than a sleep, and every wait +is bounded by `anyio.fail_after`. """ +from typing import Any + import anyio import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.client import ClientSession -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( CallToolResult, + ClientNotification, + ClientRequest, EmptyResult, ErrorData, Implementation, InitializeResult, + JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -49,8 +55,12 @@ async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: request_ids: list[types.RequestId] = [] errors: list[ErrorData] = [] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "block" + server: Server[Any] = Server("blocker") + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "block" + ctx = server.request_context assert ctx.request_id is not None request_ids.append(ctx.request_id) started.set() @@ -61,22 +71,22 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara raise raise NotImplementedError # unreachable: the wait above never completes normally - server = Server("blocker", on_call_tool=call_tool) - async with connect(server) as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: async def call_and_capture_error() -> None: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.call_tool("block", {}) errors.append(exc_info.value.error) task_group.start_soon(call_and_capture_error) await started.wait() - await client.session.send_notification( - types.CancelledNotification( - params=types.CancelledNotificationParams(requestId=request_ids[0], reason="user aborted") + await client.send_notification( + ClientNotification( + types.CancelledNotification( + params=types.CancelledNotificationParams(requestId=request_ids[0], reason="user aborted") + ) ) ) @@ -91,39 +101,40 @@ async def test_session_serves_requests_after_cancellation(connect: Connect) -> N started = anyio.Event() request_ids: list[types.RequestId] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool(name="block", inputSchema={"type": "object"}), - types.Tool(name="echo", inputSchema={"type": "object"}), - ] - ) + server: Server[Any] = Server("blocker") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool(name="block", inputSchema={"type": "object"}), + types.Tool(name="echo", inputSchema={"type": "object"}), + ] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - if params.name == "echo": + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + if name == "echo": return CallToolResult(content=[TextContent(type="text", text="still alive")]) + ctx = server.request_context assert ctx.request_id is not None request_ids.append(ctx.request_id) started.set() await anyio.Event().wait() # blocks until cancelled raise NotImplementedError # unreachable - server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: async def call_and_swallow_cancellation_error() -> None: - with pytest.raises(MCPError): + with pytest.raises(McpError): await client.call_tool("block", {}) task_group.start_soon(call_and_swallow_cancellation_error) await started.wait() - await client.session.send_notification( - types.CancelledNotification(params=types.CancelledNotificationParams(requestId=request_ids[0])) + await client.send_notification( + ClientNotification( + types.CancelledNotification(params=types.CancelledNotificationParams(requestId=request_ids[0])) + ) ) result = await client.call_tool("echo", {}) @@ -135,20 +146,20 @@ async def call_and_swallow_cancellation_error() -> None: async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None: """A cancellation referencing a request id that is not in flight is ignored without error.""" - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="echo", inputSchema={"type": "object"})]) + server: Server[Any] = Server("calm") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "echo" - return CallToolResult(content=[TextContent(type="text", text="unbothered")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="echo", inputSchema={"type": "object"})] - server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "echo" + return CallToolResult(content=[TextContent(type="text", text="unbothered")]) async with connect(server) as client: - await client.session.send_notification( - types.CancelledNotification(params=types.CancelledNotificationParams(requestId=9999)) + await client.send_notification( + ClientNotification(types.CancelledNotification(params=types.CancelledNotificationParams(requestId=9999))) ) result = await client.call_tool("echo", {}) @@ -176,21 +187,23 @@ async def scripted_server(streams: MessageStream) -> None: def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: return SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request_id, - # Serialized exactly as a real server serializes results onto the wire. - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) ) ) init = await server_read.receive() assert isinstance(init, SessionMessage) - assert isinstance(init.message, JSONRPCRequest) - assert init.message.method == "initialize" + assert isinstance(init.message.root, JSONRPCRequest) + assert init.message.root.method == "initialize" await server_write.send( respond( - init.message.id, + init.message.root.id, InitializeResult( protocolVersion="2025-11-25", capabilities=ServerCapabilities(), @@ -201,16 +214,16 @@ def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage initialized = await server_read.receive() assert isinstance(initialized, SessionMessage) - assert isinstance(initialized.message, JSONRPCNotification) - assert initialized.message.method == "notifications/initialized" + assert isinstance(initialized.message.root, JSONRPCNotification) + assert initialized.message.root.method == "notifications/initialized" ping = await server_read.receive() assert isinstance(ping, SessionMessage) - assert isinstance(ping.message, JSONRPCRequest) - assert ping.message.method == "ping" + assert isinstance(ping.message.root, JSONRPCRequest) + assert ping.message.root.method == "ping" # First answer with a fabricated id that matches nothing in flight, then the real id. await server_write.send(respond(9999, EmptyResult())) - await server_write.send(respond(ping.message.id, EmptyResult())) + await server_write.send(respond(ping.message.root.id, EmptyResult())) incoming: list[IncomingMessage] = [] @@ -225,7 +238,7 @@ async def message_handler(message: IncomingMessage) -> None: task_group.start_soon(scripted_server, server_streams) with anyio.fail_after(5): await session.initialize() - pong = await session.send_request(PingRequest(), EmptyResult) + pong = await session.send_request(ClientRequest(PingRequest()), EmptyResult) assert pong == snapshot(EmptyResult()) assert len(incoming) == 1 diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py index e3397e2f4b..9aa3d30989 100644 --- a/tests/interaction/lowlevel/test_flows.py +++ b/tests/interaction/lowlevel/test_flows.py @@ -6,16 +6,19 @@ individual features are pinned by their own tests; these prove they compose. """ -from collections.abc import Awaitable, Callable +from typing import Any import anyio import pytest from inline_snapshot import snapshot +from pydantic import AnyUrl -from mcp import MCPError, UrlElicitationRequiredError, types -from mcp.client import ClientRequestContext -from mcp.server import Server, ServerRequestContext +from mcp import McpError, UrlElicitationRequiredError, types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext from mcp.types import ( URL_ELICITATION_REQUIRED, CallToolResult, @@ -23,10 +26,11 @@ ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, - EmptyResult, - ListToolsResult, + ErrorData, + LoggingLevel, ReadResourceResult, ResourceLink, + ServerNotification, TextContent, TextResourceContents, Tool, @@ -37,18 +41,13 @@ pytestmark = pytest.mark.anyio -ListToolsHandler = Callable[ - [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] -] +def _register_list_tools(server: Server, *names: str) -> None: + """Register a list_tools handler advertising the named tools, so call_tool's cache lookup succeeds.""" -def _list_tools(*names: str) -> ListToolsHandler: - """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" - - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name=name, inputSchema={"type": "object"}) for name in names]) - - return list_tools + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name=name, inputSchema={"type": "object"}) for name in names] @requirement("flow:tool-result:resource-link-follow") @@ -58,18 +57,18 @@ async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(conn Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the link's URI, (4) the read result carries the linked contents. """ + server = Server("linker") + _register_list_tools(server, "generate") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "generate" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "generate" return CallToolResult(content=[ResourceLink(type="resource_link", uri="file:///report.txt", name="report")]) - async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: - assert str(params.uri) == "file:///report.txt" - return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) - - server = Server( - "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource - ) + @server.read_resource() + async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: + assert str(uri) == "file:///report.txt" + return [ReadResourceContents(content="generated", mime_type="text/plain")] async with connect(server) as client: called = await client.call_tool("generate", {}) @@ -81,7 +80,9 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq CallToolResult(content=[ResourceLink(type="resource_link", name="report", uri="file:///report.txt")]) ) assert read == snapshot( - ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ReadResourceResult( + contents=[TextResourceContents(uri="file:///report.txt", mimeType="text/plain", text="generated")] + ) ) @@ -99,13 +100,18 @@ async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forwa received: list[ElicitRequestFormParams] = [] answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "onboard" - first = await ctx.session.elicit_form( + server = Server("onboarder") + _register_list_tools(server, "onboard") + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "onboard" + session = server.request_context.session + first = await session.elicit_form( "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} ) assert first.action == "accept" and first.content is not None - second = await ctx.session.elicit_form( + second = await session.elicit_form( f"Step 2: confirm age for {first.content['name']}.", {"type": "object", "properties": {"age": {"type": "integer"}}}, ) @@ -114,9 +120,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara content=[TextContent(type="text", text=f"{first.content['name']} is {second.content['age']}")] ) - server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) - - async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult | ErrorData: assert isinstance(params, ElicitRequestFormParams) received.append(params) return ElicitResult(action="accept", content=answers[len(received) - 1]) @@ -125,7 +131,7 @@ async def answer(context: ClientRequestContext, params: types.ElicitRequestParam result = await client.call_tool("onboard", {}) assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="ada is 37")])) - assert [(p.message, p.requested_schema) for p in received] == snapshot( + assert [(p.message, p.requestedSchema) for p in received] == snapshot( [ ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), @@ -147,6 +153,9 @@ async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_a succeeds. The handler distinguishes the two calls by a closure flag the test flips between them; the test waits on the completion notification with an event so the retry only happens after the announcement has arrived. + + The handler reaches its session via ``server.request_context.session`` and stores it for + out-of-band use — a v1-public pattern for callbacks that fire after the request returns. """ elicitation_id = "auth-001" authorised: list[bool] = [False] @@ -154,13 +163,18 @@ async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_a completed = anyio.Event() notifications: list[ElicitCompleteNotification] = [] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "read_files" - captured.append(ctx.session) + server = Server("gatekeeper") + _register_list_tools(server, "read_files") + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "read_files" + session = server.request_context.session + captured.append(session) if not authorised[0]: # The log line gives the message handler a non-completion notification, so the test's # filtering branch is exercised in both directions and the wait remains specific. - await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + await session.send_log_message(level="warning", data="authorisation required", logger="gate") raise UrlElicitationRequiredError( [ ElicitRequestURLParams( @@ -172,34 +186,28 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) return CallToolResult(content=[TextContent(type="text", text="contents")]) - async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + @server.set_logging_level() + async def set_logging_level(level: LoggingLevel) -> None: """Registered so the logging capability is advertised; the client never sets a level.""" raise NotImplementedError - server = Server( - "gatekeeper", - on_list_tools=_list_tools("read_files"), - on_call_tool=call_tool, - on_set_logging_level=set_logging_level, - ) - async def collect(message: IncomingMessage) -> None: - if isinstance(message, ElicitCompleteNotification): - notifications.append(message) + if isinstance(message, ServerNotification) and isinstance(message.root, ElicitCompleteNotification): + notifications.append(message.root) completed.set() async with connect(server, message_handler=collect) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.call_tool("read_files", {}) assert exc_info.value.error.code == URL_ELICITATION_REQUIRED required = UrlElicitationRequiredError.from_error(exc_info.value.error) - assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + assert [e.elicitationId for e in required.elicitations] == [elicitation_id] # The out-of-band interaction completes; the server announces it on the same session. await captured[0].send_elicit_complete(elicitation_id) with anyio.fail_after(5): await completed.wait() - assert notifications[0].params.elicitation_id == elicitation_id + assert notifications[0].params.elicitationId == elicitation_id authorised[0] = True result = await client.call_tool("read_files", {}) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py index 821beeebae..9236ad3d9d 100644 --- a/tests/interaction/lowlevel/test_meta.py +++ b/tests/interaction/lowlevel/test_meta.py @@ -5,11 +5,13 @@ which also proves the SDK injected nothing alongside it. """ +from typing import Any + import pytest from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from mcp.server.lowlevel import Server +from mcp.types import CallToolResult, TextContent from tests.interaction._connect import Connect from tests.interaction._requirements import requirement @@ -19,26 +21,27 @@ @requirement("meta:request-to-handler") async def test_request_meta_reaches_handler(connect: Connect) -> None: """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" - request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} - observed_metas: list[dict[str, object]] = [] + request_meta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, Any]] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="traced", inputSchema={"type": "object"})]) + server = Server("observability") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "traced" - assert ctx.meta is not None - observed_metas.append(dict(ctx.meta)) - return CallToolResult(content=[TextContent(type="text", text="traced")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="traced", inputSchema={"type": "object"})] - server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "traced" + ctx = server.request_context + assert ctx.meta is not None + observed_metas.append(ctx.meta.model_dump(exclude_none=True)) + return [TextContent(type="text", text="traced")] async with connect(server) as client: await client.call_tool("traced", {}, meta=request_meta) - assert observed_metas == [dict(request_meta)] + assert observed_metas == [request_meta] @requirement("meta:result-to-client") @@ -46,16 +49,16 @@ async def test_result_meta_reaches_client(connect: Connect) -> None: """The _meta object a handler attaches to its result is delivered to the client unchanged.""" result_meta = {"example.com/cost": 3} - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="metered", inputSchema={"type": "object"})]) + server = Server("observability") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "metered" - return CallToolResult(content=[TextContent(type="text", text="done")], _meta=result_meta) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="metered", inputSchema={"type": "object"})] - server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "metered" + return CallToolResult(content=[TextContent(type="text", text="done")], _meta=result_meta) async with connect(server) as client: result = await client.call_tool("metered", {}) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 183afa6098..78c8041320 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -9,15 +9,17 @@ server's handler. """ +from typing import Any + import anyio import pytest from inline_snapshot import snapshot from mcp import types -from mcp.server import Server, ServerRequestContext +from mcp.server.lowlevel import Server from mcp.server.session import ServerSession from mcp.shared.session import ProgressFnT -from mcp.types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent +from mcp.types import CallToolResult, ProgressNotification, ProgressToken, ServerNotification, TextContent from tests.interaction._connect import Connect from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement @@ -34,15 +36,18 @@ async def test_progress_during_tool_call_reaches_callback_in_order(connect: Conn async def collect(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="download", inputSchema={"type": "object"})]) + server = Server("downloader") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="download", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "download" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "download" + ctx = server.request_context assert ctx.meta is not None - token = ctx.meta.get("progress_token") + token = ctx.meta.progressToken assert token is not None await ctx.session.send_progress_notification( token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) @@ -55,8 +60,6 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) return CallToolResult(content=[TextContent(type="text", text="downloaded")]) - server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("download", {}, progress_callback=collect) @@ -67,18 +70,18 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:progress:token-injected") async def test_progress_token_visible_to_handler(connect: Connect) -> None: """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" + server = Server("introspector") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="inspect", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="inspect", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "inspect" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "inspect" + ctx = server.request_context assert ctx.meta is not None - return CallToolResult(content=[TextContent(type="text", text=str(ctx.meta.get("progress_token")))]) - - server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + return CallToolResult(content=[TextContent(type="text", text=str(ctx.meta.progressToken))]) async def ignore(progress: float, total: float | None, message: str | None) -> None: """A progress callback that is never invoked; the tool only inspects the token.""" @@ -98,18 +101,18 @@ async def test_no_progress_callback_means_no_token(connect: Connect) -> None: The low-level API has no way to report request-scoped progress without a token, so a handler that sees no token has nothing to send progress against. """ + server = Server("introspector") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="inspect", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="inspect", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "inspect" - assert ctx.meta is not None - return CallToolResult(content=[TextContent(type="text", text=str(ctx.meta.get("progress_token")))]) - - server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "inspect" + ctx = server.request_context + token = ctx.meta.progressToken if ctx.meta is not None else None + return CallToolResult(content=[TextContent(type="text", text=str(token))]) async with connect(server) as client: result = await client.call_tool("inspect", {}) @@ -120,23 +123,22 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:progress:client-to-server") async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None: """A progress notification sent by the client is delivered to the server's progress handler.""" - received: list[ProgressNotificationParams] = [] + received: list[tuple[str | int, float, float | None, str | None]] = [] delivered = anyio.Event() - async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: - received.append(params) - delivered.set() + server = Server("observer") - server = Server("observer", on_progress=on_progress) + @server.progress_notification() + async def on_progress(token: str | int, progress: float, total: float | None, message: str | None) -> None: + received.append((token, progress, total, message)) + delivered.set() async with connect(server) as client: await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") with anyio.fail_after(5): await delivered.wait() - assert received == snapshot( - [ProgressNotificationParams(progressToken="upload-1", progress=0.5, total=1.0, message="halfway")] - ) + assert received == snapshot([("upload-1", 0.5, 1.0, "halfway")]) @requirement("protocol:progress:token-unique") @@ -157,18 +159,20 @@ async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Conne # turns[n] is set to release the nth emission; each emission releases the next. turns = [anyio.Event() for _ in range(4)] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="report", inputSchema={"type": "object"})]) + server = Server("reporter") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="report", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "report" - assert params.arguments is not None + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "report" + ctx = server.request_context assert ctx.meta is not None - token = ctx.meta.get("progress_token") + token = ctx.meta.progressToken assert token is not None - label = params.arguments["label"] + label = arguments["label"] tokens[label] = token entered[label].set() # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. @@ -186,8 +190,6 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara turns[second + 1].set() return CallToolResult(content=[TextContent(type="text", text="done")]) - server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) - received_a: list[float] = [] received_b: list[float] = [] @@ -228,22 +230,23 @@ async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback """ captured: list[tuple[ServerSession, ProgressToken]] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="report", inputSchema={"type": "object"})]) + server = Server("reporter") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="report", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "report" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "report" + ctx = server.request_context assert ctx.meta is not None - token = ctx.meta.get("progress_token") + token = ctx.meta.progressToken assert token is not None captured.append((ctx.session, token)) await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) return CallToolResult(content=[TextContent(type="text", text="done")]) - server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) - received: list[float] = [] late_progress_arrived = anyio.Event() @@ -251,7 +254,11 @@ async def collect(progress: float, total: float | None, message: str | None) -> received.append(progress) async def message_handler(message: IncomingMessage) -> None: - if isinstance(message, ProgressNotification) and message.params.progress == 1.0: + if ( + isinstance(message, ServerNotification) + and isinstance(message.root, ProgressNotification) + and message.root.params.progress == 1.0 + ): late_progress_arrived.set() async with connect(server, message_handler=message_handler) as client: @@ -278,23 +285,24 @@ async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: C async def collect(progress: float, total: float | None, message: str | None) -> None: received.append(progress) - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="zigzag", inputSchema={"type": "object"})]) + server = Server("zigzagger") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="zigzag", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "zigzag" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "zigzag" + ctx = server.request_context assert ctx.meta is not None - token = ctx.meta.get("progress_token") + token = ctx.meta.progressToken assert token is not None await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) return CallToolResult(content=[TextContent(type="text", text="done")]) - server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: await client.call_tool("zigzag", {}, progress_callback=collect) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py index 50da0e75f2..f1baae0adc 100644 --- a/tests/interaction/lowlevel/test_prompts.py +++ b/tests/interaction/lowlevel/test_prompts.py @@ -3,8 +3,8 @@ import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.server import Server, ServerRequestContext +from mcp import McpError +from mcp.server.lowlevel import Server from mcp.types import ( INVALID_PARAMS, AudioContent, @@ -29,24 +29,22 @@ @requirement("prompts:list:basic") async def test_list_prompts_returns_registered_prompts(connect: Connect) -> None: """The prompts returned by the handler reach the client with their argument declarations intact.""" - - async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: - return ListPromptsResult( - prompts=[ - Prompt( - name="code_review", - description="Review a piece of code.", - arguments=[ - PromptArgument(name="code", description="The code to review.", required=True), - PromptArgument(name="style_guide", description="Optional style guide to apply."), - ], - icons=[Icon(src="https://example.com/review.png", mimeType="image/png", sizes=["48x48"])], - ), - Prompt(name="daily_standup"), - ] - ) - - server = Server("prompter", on_list_prompts=list_prompts) + server = Server("prompter") + + @server.list_prompts() + async def list_prompts() -> list[Prompt]: + return [ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mimeType="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] async with connect(server) as client: result = await client.list_prompts() @@ -72,19 +70,19 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest @requirement("prompts:get:with-args") async def test_get_prompt_substitutes_arguments(connect: Connect) -> None: """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" + server = Server("prompter") - async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: - assert params.name == "greet" - assert params.arguments is not None + @server.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None) -> GetPromptResult: + assert name == "greet" + assert arguments is not None return GetPromptResult( description="A personalised greeting.", messages=[ - PromptMessage(role="user", content=TextContent(type="text", text=f"Hello, {params.arguments['name']}!")) + PromptMessage(role="user", content=TextContent(type="text", text=f"Hello, {arguments['name']}!")) ], ) - server = Server("prompter", on_get_prompt=get_prompt) - async with connect(server) as client: result = await client.get_prompt("greet", {"name": "Ada"}) @@ -99,9 +97,11 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa @requirement("prompts:get:multi-message") async def test_get_prompt_multiple_messages_preserve_roles_and_order(connect: Connect) -> None: """A prompt returning a user/assistant conversation reaches the client with roles and order intact.""" + server = Server("prompter") - async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: - assert params.name == "geography_quiz" + @server.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None) -> GetPromptResult: + assert name == "geography_quiz" return GetPromptResult( messages=[ PromptMessage(role="user", content=TextContent(type="text", text="What is the capital of France?")), @@ -112,8 +112,6 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa ] ) - server = Server("prompter", on_get_prompt=get_prompt) - async with connect(server) as client: result = await client.get_prompt("geography_quiz") @@ -133,16 +131,16 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa @requirement("prompts:get:no-args") async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None: """A prompt fetched with no arguments delivers None as the handler's arguments and returns its messages.""" + server = Server("prompter") - async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: - assert params.name == "static" - assert params.arguments is None + @server.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None) -> GetPromptResult: + assert name == "static" + assert arguments is None return GetPromptResult( messages=[PromptMessage(role="user", content=TextContent(type="text", text="Say hello."))] ) - server = Server("prompter", on_get_prompt=get_prompt) - async with connect(server) as client: result = await client.get_prompt("static") @@ -161,9 +159,11 @@ async def test_get_prompt_with_non_text_content_round_trips(connect: Connect) -> is one of the three behaviours under test. Tiny fixed base64 payloads ("aW1n" is b"img", "YXVk" is b"aud") so the snapshot pins the exact bytes. """ + server = Server("prompter") - async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: - assert params.name == "media" + @server.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None) -> GetPromptResult: + assert name == "media" return GetPromptResult( messages=[ PromptMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")), @@ -178,8 +178,6 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa ] ) - server = Server("prompter", on_get_prompt=get_prompt) - async with connect(server) as client: result = await client.get_prompt("media", {}) @@ -202,18 +200,18 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa @requirement("prompts:get:unknown-name") async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: - """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. + """A handler that rejects an unrecognised prompt name with McpError produces a JSON-RPC error. The error's code and message chosen by the handler reach the client verbatim. """ + server = Server("prompter") - async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: - raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.name}") - - server = Server("prompter", on_get_prompt=get_prompt) + @server.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None) -> GetPromptResult: + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown prompt: {name}")) async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.get_prompt("nope") assert exc_info.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope")) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 1ac8aa5812..fdda6c4f61 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -6,7 +6,7 @@ import pytest from inline_snapshot import snapshot -from mcp import McpError, types +from mcp import McpError from mcp.server.lowlevel import Server from mcp.types import ( INVALID_PARAMS, @@ -185,21 +185,23 @@ async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Con "additionalProperties": False, } - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult(tools=[Tool(name="typed", inputSchema=schema)]) + server = Server("typed") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "typed" - assert params.arguments == {"count": 3, "options": {"verbose": True}} - return CallToolResult(content=[TextContent(type="text", text="ok")]) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="typed", inputSchema=schema)] - server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "typed" + assert arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(type="text", text="ok")]) async with connect(server) as client: listed = await client.list_tools() called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) - assert listed.tools[0].input_schema == schema + assert listed.tools[0].inputSchema == schema assert called == snapshot(CallToolResult(content=[TextContent(type="text", text="ok")])) @@ -218,10 +220,11 @@ async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: _meta={"example.com/source": "interaction-suite"}, ) - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult(tools=[tool]) + server = Server("annotated") - server = Server("annotated", on_list_tools=list_tools) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [tool] async with connect(server) as client: result = await client.list_tools() @@ -256,11 +259,15 @@ async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: snapshot pins the exact bytes the client receives. """ - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult(tools=[Tool(name="render", inputSchema={"type": "object"})]) + server = Server("renderer") + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="render", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "render" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "render" return CallToolResult( content=[ TextContent(type="text", text="all five content block types"), @@ -276,8 +283,6 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ] ) - server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("render", {}) @@ -303,14 +308,16 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_call_tool_structured_content(connect: Connect) -> None: """A tool result carrying structured content alongside content delivers both to the client.""" - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult(tools=[Tool(name="sum", inputSchema={"type": "object"})]) + server = Server("calculator") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "sum" - return CallToolResult(content=[TextContent(type="text", text="the sum is 5")], structuredContent={"sum": 5}) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="sum", inputSchema={"type": "object"})] - server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "sum" + return CallToolResult(content=[TextContent(type="text", text="the sum is 5")], structuredContent={"sum": 5}) async with connect(server) as client: result = await client.call_tool("sum", {}) @@ -333,21 +340,22 @@ async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> release = anyio.Event() results: dict[str, CallToolResult] = {} - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) + server = Server("echoer") + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="echo", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "echo" - assert params.arguments is not None - tag = params.arguments["tag"] + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "echo" + tag = arguments["tag"] assert isinstance(tag, str) started.append(tag) started_events[tag].set() await release.wait() return CallToolResult(content=[TextContent(type="text", text=tag)]) - server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: # pragma: no branch @@ -376,31 +384,35 @@ async def call_and_record(tag: str) -> None: async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: """A result whose structured content does not conform to the tool's declared output schema never reaches the caller: the client validates it against the schema cached from tools/list and raises. + + The handler returns a full `CallToolResult`, which the v1 lowlevel `@server.call_tool()` decorator + passes through unchanged (the decorator's own output-schema validation only runs when the handler + returns a bare dict/tuple/iterable), so the malformed structured content reaches the client and the + client-side check is what raises. """ + server = Server("weather") - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult( - tools=[ - Tool( - name="forecast", - inputSchema={"type": "object"}, - outputSchema={ - "type": "object", - "properties": {"temperature": {"type": "number"}}, - "required": ["temperature"], - }, - ) - ] - ) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="forecast", + inputSchema={"type": "object"}, + outputSchema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "forecast" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "forecast" return CallToolResult( content=[TextContent(type="text", text="warm")], structuredContent={"temperature": "warm"} ) - server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: await client.list_tools() with pytest.raises(RuntimeError) as exc_info: @@ -417,30 +429,29 @@ async def test_is_error_result_bypasses_client_output_schema_validation(connect: The schema is cached up front so the client could validate, proving the bypass is specifically the isError flag and not an empty cache. """ + server = Server("weather") - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult( - tools=[ - Tool( - name="forecast", - inputSchema={"type": "object"}, - outputSchema={ - "type": "object", - "properties": {"temperature": {"type": "number"}}, - "required": ["temperature"], - }, - ) - ] - ) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="forecast", + inputSchema={"type": "object"}, + outputSchema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "forecast" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "forecast" return CallToolResult( content=[TextContent(type="text", text="boom")], structuredContent={"temperature": "warm"}, isError=True ) - server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: await client.list_tools() result = await client.call_tool("forecast", {}) @@ -458,24 +469,23 @@ async def test_declared_output_schema_with_no_structured_content_is_rejected_by_ The error is the SDK's own message, so the full text is snapshotted. """ + server = Server("weather") - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body - return ListToolsResult( - tools=[ - Tool( - name="forecast", - inputSchema={"type": "object"}, - outputSchema={"type": "object", "properties": {"temperature": {"type": "number"}}}, - ) - ] - ) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="forecast", + inputSchema={"type": "object"}, + outputSchema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "forecast" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "forecast" return CallToolResult(content=[TextContent(type="text", text="warm")]) - server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: await client.list_tools() with pytest.raises(RuntimeError) as exc_info: @@ -486,39 +496,41 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("client:output-schema:auto-list") async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: - """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + """Calling a tool whose schema is not cached issues an implicit tools/list to populate it. The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the - second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of - the surprising behaviour where a server with only on_call_tool sees a successful call answered - with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. + second call hits the cache and does not. On v1 the server-side `@server.call_tool()` decorator + also primes its own cache by invoking the registered list_tools handler internally on a miss, so + the first call records the handler running twice (server-side priming + the client's implicit + tools/list request) rather than once. The second call hits both caches and records nothing + further, which is what proves the client's behaviour. See the divergence on the requirement. """ list_calls: list[str] = [] - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: # noqa: F821 -- batch 2/3 rewrites this body + server = Server("weather") + + @server.list_tools() + async def list_tools() -> list[Tool]: list_calls.append("called") - return ListToolsResult( - tools=[ - Tool( - name="forecast", - inputSchema={"type": "object"}, - outputSchema={"type": "object", "properties": {"temperature": {"type": "number"}}}, - ) - ] - ) + return [ + Tool( + name="forecast", + inputSchema={"type": "object"}, + outputSchema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: # noqa: F821 -- batch 2/3 rewrites this body - assert params.name == "forecast" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "forecast" return CallToolResult(content=[TextContent(type="text", text="21 C")], structuredContent={"temperature": 21}) - server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: first = await client.call_tool("forecast", {}) - assert list_calls == ["called"] + assert list_calls == ["called", "called"] second = await client.call_tool("forecast", {}) - assert list_calls == ["called"] + assert list_calls == ["called", "called"] assert first == snapshot( CallToolResult(content=[TextContent(type="text", text="21 C")], structuredContent={"temperature": 21}) ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 657cc2a6c5..56b992e3e0 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -1,10 +1,10 @@ -"""Prompt interactions against MCPServer, driven through the public Client API.""" +"""Prompt interactions against FastMCP, driven through the public client API.""" import pytest from inline_snapshot import snapshot -from mcp import MCPError -from mcp.server.mcpserver import MCPServer +from mcp import McpError +from mcp.server.fastmcp import FastMCP from mcp.types import ( ErrorData, GetPromptResult, @@ -26,7 +26,7 @@ async def test_list_prompts_derives_arguments_from_signature(connect: Connect) - Parameters without a default are required; the description comes from the docstring. """ - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def code_review(code: str, style_guide: str = "pep8") -> str: @@ -55,7 +55,7 @@ def code_review(code: str, style_guide: str = "pep8") -> str: @requirement("mcpserver:prompt:decorated") async def test_get_prompt_renders_function_return(connect: Connect) -> None: """The decorated function's string return value is rendered as a single user message.""" - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def greet(name: str) -> str: @@ -80,7 +80,7 @@ async def test_get_unknown_prompt_is_error(connect: Connect) -> None: The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on the requirement). """ - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def greet(name: str) -> str: @@ -88,7 +88,7 @@ def greet(name: str) -> str: raise NotImplementedError async with connect(mcp) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.get_prompt("nope") assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) @@ -102,7 +102,7 @@ async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Invalid params is reported as error code 0 with the bare exception text (see the divergence note on the requirement). """ - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def greet(name: str) -> str: @@ -110,7 +110,7 @@ def greet(name: str) -> str: raise NotImplementedError async with connect(mcp) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.get_prompt("greet") assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) @@ -125,7 +125,7 @@ async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_func raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable rendering-error prefix; the body of the message is raw pydantic output and is not asserted. """ - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def repeat(phrase: str, count: int) -> str: @@ -133,7 +133,7 @@ def repeat(phrase: str, count: int) -> str: raise NotImplementedError async with connect(mcp) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) assert exc_info.value.error.code == 0 @@ -143,7 +143,7 @@ def repeat(phrase: str, count: int) -> str: @requirement("mcpserver:prompt:optional-args") async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: """A prompt rendered without one of its optional arguments uses that parameter's default value.""" - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def review(code: str, style: str = "pep8") -> str: @@ -165,12 +165,12 @@ def review(code: str, style: str = "pep8") -> str: async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: """Registering a second prompt with an already-used name keeps the first registration. - The intended behaviour is rejection at registration time; MCPServer instead logs a warning + The intended behaviour is rejection at registration time; FastMCP instead logs a warning and discards the second registration (see the divergence note on the requirement). The second function is registered via the decorator with an explicit name so the test does not redefine the same function name in this scope. """ - mcp = MCPServer("prompter") + mcp = FastMCP("prompter") @mcp.prompt() def greet() -> str: diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index 9c7353dda5..5fac87a628 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -14,7 +14,7 @@ import pytest from inline_snapshot import snapshot -from mcp.client.client import Client +from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.types import EmptyResult @@ -29,8 +29,11 @@ @requirement("transport:sse:endpoint-event") async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: """Connecting opens a GET stream whose first event names the POST endpoint and a fresh - session id; messages POSTed there are answered on that stream, and disconnecting releases the - server's session entry.""" + session id; messages POSTed there are answered on that stream. + + On v1 the server's session entry is not removed on disconnect (`SseServerTransport` never + pops `_read_stream_writers[session_id]`); the final assertion pins that behaviour. + """ app, sse = build_sse_app(Server("legacy")) captured_session_id: list[str] = [] @@ -47,16 +50,17 @@ def httpx_client_factory( auth=auth, ) - transport = sse_client( - f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append - ) with anyio.fail_after(5): - async with Client(transport) as client: - assert len(captured_session_id) == 1 - assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers - assert await client.send_ping() == snapshot(EmptyResult()) - - assert sse._read_stream_writers == {} + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) as (read, write): + async with ClientSession(read, write) as client: + await client.initialize() + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers @requirement("transport:sse:post:session-routing") diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index a0b7fb2bda..041948733a 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -7,14 +7,17 @@ through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses. """ +from typing import Any + import anyio import pytest from inline_snapshot import snapshot -from pydantic import BaseModel +from pydantic import AnyUrl, BaseModel -from mcp.client import ClientRequestContext +from mcp.client.session import ClientSession from mcp.server.elicitation import AcceptedElicitation -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.context import RequestContext from mcp.types import ( CallToolResult, ElicitRequestParams, @@ -23,6 +26,7 @@ LoggingMessageNotificationParams, ResourceUpdatedNotification, ResourceUpdatedNotificationParams, + ServerNotification, TextContent, ) from tests.interaction._connect import connect_over_streamable_http @@ -32,9 +36,9 @@ pytestmark = pytest.mark.anyio -def _smoke_server() -> MCPServer: +def _smoke_server() -> FastMCP: """A server exercising each message shape the transport-specific tests need.""" - mcp = MCPServer("smoke", instructions="Talk to the smoke server.") + mcp = FastMCP("smoke", instructions="Talk to the smoke server.") @mcp.tool() def echo(text: str) -> str: @@ -56,7 +60,7 @@ async def ask(ctx: Context) -> str: async def announce(ctx: Context) -> str: """Send one notification related to this request and one that is not.""" await ctx.info("about to announce") - await ctx.session.send_resource_updated("file:///watched.txt") + await ctx.session.send_resource_updated(AnyUrl("file:///watched.txt")) return "announced" return mcp @@ -67,7 +71,6 @@ async def announce(ctx: Context) -> str: async def test_tool_call_over_streamable_http_with_json_responses() -> None: """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: - assert client.initialize_result.server_info.name == "smoke" result = await client.call_tool("echo", {"text": "as json"}) assert result == snapshot( @@ -90,19 +93,6 @@ async def test_tool_calls_over_stateless_streamable_http() -> None: ) -@requirement("transport:streamable-http:stateless-restrictions") -async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: - """A handler that tries to call back to the client in stateless mode fails: there is no session.""" - async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: - result = await client.call_tool("ask", {}) - - assert result.is_error is True - assert isinstance(result.content[0], TextContent) - # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error - # path; pin the stable prefix rather than the full exception prose. - assert result.content[0].text.startswith("Error executing tool ask:") - - @requirement("transport:streamable-http:notifications") @requirement("transport:streamable-http:unrelated-messages") @requirement("hosting:http:standalone-sse") @@ -120,7 +110,7 @@ async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> No async def collect(message: IncomingMessage) -> None: received.append(message) - if isinstance(message, ResourceUpdatedNotification): + if isinstance(message, ServerNotification) and isinstance(message.root, ResourceUpdatedNotification): resource_update_seen.set() async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client: @@ -132,12 +122,16 @@ async def collect(message: IncomingMessage) -> None: CallToolResult(content=[TextContent(type="text", text="announced")], structuredContent={"result": "announced"}) ) # The related log notification rides the call's stream; the unrelated resource-updated - # notification rides the standalone stream. Both arrive, nothing else does. - assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot( - [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] + # notification rides the standalone stream. Both arrive, nothing else does. v1's + # message_handler receives the ServerNotification RootModel; unwrap to the inner type and + # compare on `.params` (the received envelope carries `jsonrpc` in __pydantic_extra__ via + # `extra="allow"`, which makes full-object equality noisy without adding coverage). + notifications = [m.root for m in received if isinstance(m, ServerNotification)] + assert [n.params for n in notifications if isinstance(n, LoggingMessageNotification)] == snapshot( + [LoggingMessageNotificationParams(level="info", data="about to announce")] ) - assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot( - [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + assert [n.params for n in notifications if isinstance(n, ResourceUpdatedNotification)] == snapshot( + [ResourceUpdatedNotificationParams(uri=AnyUrl("file:///watched.txt"))] ) assert len(received) == 2 @@ -153,7 +147,7 @@ async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> """ asked: list[ElicitRequestParams] = [] - async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + async def answer(context: RequestContext[ClientSession, Any], params: ElicitRequestParams) -> ElicitResult: asked.append(params) return ElicitResult(action="accept", content={"confirmed": True}) From 7251516192edb70a9c32ed7792bacf0fe5fc801f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:02:43 +0000 Subject: [PATCH 11/19] =?UTF-8?q?backport:=20phase-4=20waves=203+4=20+=20S?= =?UTF-8?q?2-S9=20(Workflow,=2043=20agents)=20=E2=80=94=20~174=20pass=20/?= =?UTF-8?q?=203=20deferred?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parallel lane (14 N-risk files) + sequential lane (8 Y-risk files), all batches ≤5 tests. Deferred: roots:list-changed (no v1 lowlevel decorator), resources:templates:pagination (nullary-only), client-auth:authorize:offline-access-consent (v1 lacks SEP-2207). Several expected gaps did not materialize (InMemoryTransport, REQUEST_TIMEOUT, FastMCP ctor kwargs, verify_tokens=False) — all solvable with helpers or snapshot regen. --- tests/interaction/_requirements.py | 19 + tests/interaction/auth/test_as_handlers.py | 5 +- .../interaction/auth/test_authorize_token.py | 50 +-- tests/interaction/auth/test_discovery.py | 40 +- tests/interaction/auth/test_lifecycle.py | 48 ++- .../interaction/lowlevel/test_elicitation.py | 403 +++++++++--------- tests/interaction/lowlevel/test_initialize.py | 281 ++++++------ .../interaction/lowlevel/test_list_changed.py | 91 ++-- tests/interaction/lowlevel/test_pagination.py | 152 +++---- tests/interaction/lowlevel/test_resources.py | 248 +++++------ tests/interaction/lowlevel/test_roots.py | 128 +++--- tests/interaction/lowlevel/test_sampling.py | 382 ++++++++--------- tests/interaction/lowlevel/test_timeouts.py | 90 ++-- tests/interaction/lowlevel/test_wire.py | 176 +++++--- tests/interaction/mcpserver/test_context.py | 53 +-- tests/interaction/mcpserver/test_resources.py | 45 +- tests/interaction/mcpserver/test_tools.py | 70 +-- .../transports/test_client_transport_http.py | 59 ++- tests/interaction/transports/test_flows.py | 16 +- .../transports/test_hosting_http.py | 66 +-- .../transports/test_hosting_resume.py | 48 ++- .../transports/test_hosting_session.py | 33 +- 22 files changed, 1285 insertions(+), 1218 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 312cdead56..629a05b9e3 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -794,6 +794,12 @@ def __post_init__(self) -> None: "A server with resource handlers advertises the resources capability, including the subscribe " "sub-flag when a subscribe handler is registered." ), + divergence=Divergence( + note=( + "The low-level Server hard-codes subscribe=False in get_capabilities() regardless of " + "whether a subscribe_resource handler is registered." + ), + ), ), "resources:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", @@ -857,6 +863,11 @@ def __post_init__(self) -> None: "resources:templates:pagination": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", behavior="resources/templates/list supports cursor pagination.", + deferred=( + "Not expressible via the v1 public API: Server.list_resource_templates() only accepts a nullary " + "() -> list[ResourceTemplate] handler with no dual-signature dispatch, so the inbound cursor is " + "unreadable and the handler cannot return a nextCursor." + ), ), "resources:unsubscribe": Requirement( source=f"{SPEC_BASE_URL}/server/resources#subscriptions", @@ -1439,6 +1450,10 @@ def __post_init__(self) -> None: "roots:list-changed": Requirement( source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + deferred=( + "Not expressible via the v1 public API: the low-level Server exposes no decorator for " + "notifications/roots/list_changed, so a server handler cannot be registered to observe delivery." + ), ), "roots:list-changed:client-emits": Requirement( source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", @@ -2483,6 +2498,10 @@ def __post_init__(self) -> None: "and prompt=consent is added to the authorize request." ), transports=("streamable-http",), + deferred=( + "Not expressible via the v1 public API: v1's OAuthClientProvider has no SEP-2207 " + "offline_access auto-append or prompt=consent logic." + ), ), "client-auth:bearer-header:every-request": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py index 5cb4e92d86..e8d6e71b18 100644 --- a/tests/interaction/auth/test_as_handlers.py +++ b/tests/interaction/auth/test_as_handlers.py @@ -234,7 +234,10 @@ async def test_registration_with_invalid_metadata_is_rejected_with_400( no_auth_code = await http.post("/register", json=body | {"grant_types": ["refresh_token"]}) assert no_auth_code.status_code == 400 assert no_auth_code.json() == snapshot( - {"error": "invalid_client_metadata", "error_description": "grant_types must include 'authorization_code'"} + { + "error": "invalid_client_metadata", + "error_description": "grant_types must be authorization_code and refresh_token", + } ) bad_scope = await http.post("/register", json=body | {"scope": "forbidden"}) diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py index 8e44cffc04..276136d710 100644 --- a/tests/interaction/auth/test_authorize_token.py +++ b/tests/interaction/auth/test_authorize_token.py @@ -1,8 +1,9 @@ """Authorization-request, token-request, and PKCE wire-level invariants of the SDK's OAuth client. -Every test connects a real `Client` end to end via `connect_with_oauth`; the assertions are on -the parsed authorize URL and the recorded `/token` form body, because those wire shapes are what -the spec mandates and `Client` cannot observe them. The recording uses `record_requests`, which +Every test connects a real `ClientSession` end to end via `connect_with_oauth`; the assertions +are on the parsed authorize URL and the recorded `/token` form body, because those wire shapes +are what the spec mandates and the session cannot observe them. The recording uses +`record_requests`, which snapshots each request at send time so the auth flow's in-place header mutation on retry never affects what was captured for the first attempt. @@ -24,11 +25,10 @@ from inline_snapshot import snapshot from pydantic import AnyHttpUrl, AnyUrl -from mcp import types from mcp.client.auth import OAuthFlowError -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata -from mcp.types import ListToolsResult, Tool +from mcp.types import Tool from tests.interaction._connect import BASE_URL from tests.interaction._requirements import requirement from tests.interaction.auth._harness import ( @@ -50,8 +50,15 @@ ASM_PATH = "/.well-known/oauth-authorization-server" -async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) +def make_guarded_server() -> Server: + """Build a lowlevel `Server` exposing one `echo` tool, to mount behind the OAuth-gated MCP endpoint.""" + server = Server("guarded") + + @server.list_tools() + async def _list_tools() -> list[Tool]: + return [Tool(name="echo", inputSchema={"type": "object"})] + + return server def authorize_params(authorize_url: str) -> dict[str, str]: @@ -91,13 +98,12 @@ def token_request(self) -> RecordedRequest: async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: """Run one full OAuth connect with default configuration and yield its recorded wire traffic. - `valid_scopes` includes `offline_access` so the AS metadata advertises it and the SDK's - SEP-2207 auto-append (and the resulting `prompt=consent`) is exercised; `required_scopes` + `valid_scopes` includes `offline_access` so the AS metadata advertises it; `required_scopes` stays at `["mcp"]` so the issued token still passes the bearer middleware. """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "offline_access"]) with anyio.fail_after(5): @@ -113,7 +119,6 @@ async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: @requirement("client-auth:pkce:s256") @requirement("client-auth:resource-parameter") -@requirement("client-auth:authorize:offline-access-consent") async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( recorded_oauth_flow: RecordedFlow, ) -> None: @@ -130,7 +135,6 @@ async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( "client_id", "code_challenge", "code_challenge_method", - "prompt", "redirect_uri", "resource", "response_type", @@ -145,9 +149,7 @@ async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( # the stable prefix so the test does not lock in a trailing-slash decision. assert params["resource"].startswith(BASE_URL) assert params["state"] != "" - - assert params["scope"].split(" ") == snapshot(["mcp", "offline_access"]) - assert params["prompt"] == "consent" + assert params["scope"].split(" ") == snapshot(["mcp"]) @requirement("client-auth:pkce:s256") @@ -175,15 +177,15 @@ async def test_a_mismatched_state_on_the_callback_aborts_the_flow() -> None: random tokens). """ provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() headless = HeadlessOAuth(state_override="wrong-state") with anyio.fail_after(5): with pytest.RaisesGroup( pytest.RaisesExc(OAuthFlowError, match="^State parameter mismatch:"), flatten_subgroups=True ): - # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), - # so an `async with` body would be unreachable; entering explicitly avoids dead code. + # Entering the connect raises during the OAuth handshake (before `ClientSession.initialize` + # returns), so an `async with` body would be unreachable; entering explicitly avoids dead code. await connect_with_oauth(server, provider=provider, headless=headless).__aenter__() @@ -238,7 +240,7 @@ async def test_a_client_with_a_secret_authenticates_the_token_request_with_http_ """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() client_info = OAuthClientInformationFull( client_id="cid", @@ -276,7 +278,7 @@ async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_adve """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() override = OAuthMetadata( issuer=AnyHttpUrl(f"{BASE_URL}/"), @@ -319,7 +321,7 @@ async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_me """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() settings = auth_settings(required_scopes=["from-prm"], valid_scopes=["from-header", "from-prm"]) challenge = f'Bearer scope="from-header", resource_metadata="{BASE_URL}{PRM_PATH}"' @@ -360,7 +362,7 @@ async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_ serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() with anyio.fail_after(5): async with connect_with_oauth( @@ -386,7 +388,7 @@ async def test_an_authorize_error_on_the_callback_aborts_the_flow_before_the_tok """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(deny_authorize=True) - server = Server("guarded", on_list_tools=list_tools) + server = make_guarded_server() headless = HeadlessOAuth() with anyio.fail_after(5): diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py index afa3d0cd4b..1f80d981f1 100644 --- a/tests/interaction/auth/test_discovery.py +++ b/tests/interaction/auth/test_discovery.py @@ -1,13 +1,13 @@ """Protected-resource and authorization-server metadata discovery, end to end. -Every client-side test connects a real `Client` via `connect_with_oauth` and asserts on the -recorded request paths the discovery probes produced; the discovery URL ordering is a wire -detail `Client` cannot observe directly but the recording can. Tests that need a metadata +Every client-side test connects a real `ClientSession` via `connect_with_oauth` and asserts on +the recorded request paths the discovery probes produced; the discovery URL ordering is a wire +detail the session cannot observe directly but the recording can. Tests that need a metadata endpoint to 404 or return alternate content wrap the SDK's app in `shimmed_app` while leaving the real authorize and token endpoints behind it, so the rest of the flow runs unaltered. -The two server-side tests (#5, #6) drive raw httpx against `mounted_app` because their -assertions are the metadata response bodies and headers, which `Client` does not surface. +The two server-side tests drive raw httpx against `mounted_app` because their assertions are +the metadata response bodies and headers, which `ClientSession` does not surface. """ import json @@ -17,11 +17,10 @@ from inline_snapshot import snapshot from pydantic import AnyHttpUrl -from mcp import types from mcp.client.auth import OAuthFlowError, OAuthRegistrationError -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata -from mcp.types import ListToolsResult, Tool +from mcp.types import Tool from tests.interaction._connect import BASE_URL, mounted_app from tests.interaction._requirements import requirement from tests.interaction.auth._harness import ( @@ -42,8 +41,15 @@ OIDC_ROOT = "/.well-known/openid-configuration" -async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="probe", inputSchema={"type": "object"})]) +def guarded_server() -> Server: + """Build a lowlevel `Server` exposing a single `probe` tool via `list_tools`.""" + server = Server("guarded") + + @server.list_tools() + async def _list_tools() -> list[Tool]: + return [Tool(name="probe", inputSchema={"type": "object"})] + + return server def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: @@ -74,7 +80,7 @@ async def test_prm_discovery_uses_the_resource_metadata_url_from_www_authenticat """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() with anyio.fail_after(5): async with connect_with_oauth(server, provider=provider, on_request=on_request) as (client, _): @@ -97,7 +103,7 @@ async def test_prm_discovery_falls_back_from_path_well_known_to_root_on_404() -> """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() prm = ProtectedResourceMetadata( resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] @@ -132,7 +138,7 @@ async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_th """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) with anyio.fail_after(5): @@ -160,7 +166,7 @@ async def test_a_400_from_the_registration_endpoint_surfaces_as_a_registration_e """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() error_body = json.dumps({"error": "invalid_client_metadata", "error_description": "no"}).encode() app_shim = shim(serve={"/register": (400, error_body)}) @@ -185,7 +191,7 @@ async def test_prm_with_a_mismatched_resource_aborts_the_flow_before_authorize() """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() prm = ProtectedResourceMetadata( resource=AnyHttpUrl(f"{BASE_URL}/other"), authorization_servers=[AnyHttpUrl(BASE_URL)] @@ -237,7 +243,7 @@ async def test_as_metadata_discovery_falls_back_through_the_spec_endpoint_order( """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() prm = ProtectedResourceMetadata( resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(authorization_server)] @@ -320,7 +326,7 @@ async def test_as_metadata_with_a_mismatched_issuer_is_accepted_and_the_flow_pro unknown-field tolerance. """ provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = guarded_server() metadata = real_asm() metadata.issuer = AnyHttpUrl(f"{BASE_URL}/wrong-issuer") diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py index 8812dccead..72e4982d4e 100644 --- a/tests/interaction/auth/test_lifecycle.py +++ b/tests/interaction/auth/test_lifecycle.py @@ -11,15 +11,15 @@ from urllib.parse import parse_qsl, urlsplit import anyio +import httpx import pytest from inline_snapshot import snapshot from pydantic import AnyHttpUrl, AnyUrl -from mcp import MCPError, types from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata -from mcp.types import INTERNAL_ERROR, ListToolsResult, Tool +from mcp.types import Tool from tests.interaction._connect import BASE_URL from tests.interaction._requirements import requirement from tests.interaction.auth._harness import ( @@ -43,8 +43,15 @@ CIMD_URL = "https://client.example/.well-known/mcp-client" -async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", inputSchema={"type": "object"})]) +def _guarded_server() -> Server[object]: + """Build a lowlevel server exposing a single `echo` tool, in the v1 decorator style.""" + server: Server[object] = Server("guarded") + + @server.list_tools() + async def _list_tools() -> list[Tool]: + return [Tool(name="echo", inputSchema={"type": "object"})] + + return server def form_body(request: RecordedRequest) -> dict[str, str]: @@ -110,7 +117,7 @@ async def test_an_expired_access_token_is_transparently_refreshed_before_the_nex recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) storage = InMemoryTokenStorage() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() with anyio.fail_after(5): async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): @@ -149,7 +156,7 @@ async def test_a_403_insufficient_scope_triggers_one_reauthorize_with_the_challe recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() storage = InMemoryTokenStorage(client_info=seeded_client(provider, scope="mcp write")) - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "write"]) challenge = 'Bearer error="insufficient_scope", scope="mcp write"' @@ -184,20 +191,21 @@ async def test_a_second_401_after_a_completed_oauth_flow_surfaces_without_loopin The provider rejects every token at verification, so the full flow runs once and the retry is 401'd. The auth-flow generator ends after that retry, so the 401 propagates and the - transport converts it to an INTERNAL_ERROR result, raising during connect. Discovery, - registration, authorize, and token each ran exactly once: no loop. + transport raises it as an `httpx.HTTPStatusError` during connect. Discovery, registration, + authorize, and token each ran exactly once: no loop. """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(reject_all_tokens=True) - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() - def is_internal_error(error: MCPError) -> bool: - return error.error.code == INTERNAL_ERROR + def is_unauthorized(error: httpx.HTTPStatusError) -> bool: + return error.response.status_code == 401 with anyio.fail_after(5): - with pytest.RaisesGroup(pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True): - # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), - # so an `async with` body would be unreachable; entering explicitly avoids dead code. + with pytest.RaisesGroup(pytest.RaisesExc(httpx.HTTPStatusError, check=is_unauthorized), flatten_subgroups=True): + # Entering the connect raises during the OAuth handshake (the harness calls + # `session.initialize()` before yielding), so an `async with` body would be + # unreachable; entering explicitly avoids dead code. await connect_with_oauth(server, provider=provider, on_request=on_request).__aenter__() counts = path_counts(recorded) @@ -224,7 +232,7 @@ async def test_cimd_is_selected_when_the_as_advertises_support_and_a_metadata_ur provider = InMemoryAuthorizationServerProvider() seeded_client(provider, client_id=CIMD_URL) storage = InMemoryTokenStorage() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() with anyio.fail_after(5): async with connect_with_oauth( @@ -266,7 +274,7 @@ async def test_a_failed_refresh_clears_stored_tokens_and_restarts_the_full_flow( recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(issue_expired_first=True, fail_next_refresh=True) storage = InMemoryTokenStorage() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() with anyio.fail_after(5): async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): @@ -301,7 +309,7 @@ async def test_client_credentials_provider_obtains_a_token_without_an_authorize_ """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() auth = ClientCredentialsOAuthProvider( server_url=f"{BASE_URL}/mcp", @@ -346,7 +354,7 @@ async def test_private_key_jwt_provider_authenticates_the_token_request_with_an_ """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() audiences: list[str] = [] @@ -408,7 +416,7 @@ async def test_registration_priority_prefers_preregistered_then_cimd_then_dcr( """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider() - server = Server("guarded", on_list_tools=list_tools) + server = _guarded_server() storage = InMemoryTokenStorage() expected_client_id: str diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index 82c81aa24b..3c6c702b95 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -4,13 +4,16 @@ mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. """ +from typing import Any + import anyio import pytest from inline_snapshot import snapshot -from mcp import MCPError, UrlElicitationRequiredError, types -from mcp.client import ClientRequestContext, ClientSession -from mcp.server import Server, ServerRequestContext +from mcp import McpError, UrlElicitationRequiredError, types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.shared.context import RequestContext from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( @@ -58,43 +61,31 @@ async def test_elicit_form_accepted_content_returns_to_handler(connect: Connect) """ received: list[types.ElicitRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="signup", description="Register the user.", inputSchema={"type": "object"})] - ) + server = Server("registrar") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "signup" - answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) - return CallToolResult(content=[TextContent(type="text", text=answer.action)], structuredContent=answer.content) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="signup", description="Register the user.", inputSchema={"type": "object"})] - server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "signup" + answer = await server.request_context.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(type="text", text=answer.action)], structuredContent=answer.content) - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: received.append(params) return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) async with connect(server, elicitation_callback=answer_form) as client: result = await client.call_tool("signup", {}) - assert received == snapshot( - [ - ElicitRequestFormParams( - _meta={}, - message="Choose a username.", - requestedSchema={ - "type": "object", - "properties": { - "username": {"type": "string"}, - "newsletter": {"type": "boolean"}, - }, - "required": ["username"], - }, - ) - ] - ) + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].message == "Choose a username." + assert received[0].requestedSchema == REQUESTED_SCHEMA assert result == snapshot( CallToolResult( content=[TextContent(type="text", text="accept")], @@ -107,21 +98,21 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: """A declined form elicitation returns the decline action to the handler with no content.""" - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="confirm", description="Ask for confirmation.", inputSchema={"type": "object"})] - ) + server = Server("confirmer") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "confirm" - answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) - return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="confirm", description="Ask for confirmation.", inputSchema={"type": "object"})] - server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "confirm" + answer = await server.request_context.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: return ElicitResult(action="decline") async with connect(server, elicitation_callback=answer_form) as client: @@ -134,21 +125,21 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: """A cancelled form elicitation returns the cancel action to the handler with no content.""" - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="confirm", description="Ask for confirmation.", inputSchema={"type": "object"})] - ) + server = Server("confirmer") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "confirm" - answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) - return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="confirm", description="Ask for confirmation.", inputSchema={"type": "object"})] - server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "confirm" + answer = await server.request_context.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: return ElicitResult(action="cancel") async with connect(server, elicitation_callback=answer_form) as client: @@ -169,23 +160,21 @@ async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: elicitation capability before sending (see the divergence on `server-respects-mode`). """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="ask", description="Ask the user.", inputSchema={"type": "object"})] - ) + server = Server("asker") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask", description="Ask the user.", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "ask" try: - await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) - except MCPError as exc: + await server.request_context.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except McpError as exc: return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) raise NotImplementedError # elicit_form cannot succeed without a client callback - server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("ask", {}) @@ -205,39 +194,34 @@ async def test_elicit_url_delivers_url_and_returns_accept_without_content(connec """ received: list[types.ElicitRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] - ) + server = Server("authorizer") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "authorize" - answer = await ctx.session.elicit_url( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "authorize" + answer = await server.request_context.session.elicit_url( "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" ) return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) - server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) - - async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_url( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: received.append(params) return ElicitResult(action="accept") async with connect(server, elicitation_callback=answer_url) as client: result = await client.call_tool("authorize", {}) - assert received == snapshot( - [ - ElicitRequestURLParams( - _meta={}, - message="Authorize access to your calendar.", - url="https://example.com/oauth/authorize", - elicitationId="auth-001", - ) - ] - ) + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestURLParams) + assert received[0].message == "Authorize access to your calendar." + assert received[0].url == "https://example.com/oauth/authorize" + assert received[0].elicitationId == "auth-001" assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="accept content=None")])) @@ -245,23 +229,23 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: """A declined URL elicitation returns the decline action to the handler with no content.""" - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] - ) + server = Server("authorizer") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "authorize" - answer = await ctx.session.elicit_url( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "authorize" + answer = await server.request_context.session.elicit_url( "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" ) return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) - server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) - - async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_url( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: return ElicitResult(action="decline") async with connect(server, elicitation_callback=answer_url) as client: @@ -274,23 +258,23 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: """A cancelled URL elicitation returns the cancel action to the handler with no content.""" - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] - ) + server = Server("authorizer") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "authorize" - answer = await ctx.session.elicit_url( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="authorize", description="Link an account.", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "authorize" + answer = await server.request_context.session.elicit_url( "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" ) return CallToolResult(content=[TextContent(type="text", text=f"{answer.action} content={answer.content}")]) - server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) - - async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_url( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: return ElicitResult(action="cancel") async with connect(server, elicitation_callback=answer_url) as client: @@ -317,15 +301,16 @@ async def test_elicitation_complete_notification_carries_the_elicited_id_back_to async def collect(message: IncomingMessage) -> None: received.append(message) - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="link_account", description="Link an account.", inputSchema={"type": "object"})] - ) + server = Server("authorizer") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="link_account", description="Link an account.", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "link_account" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "link_account" + ctx = server.request_context answer = await ctx.session.elicit_url( "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id ) @@ -333,11 +318,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) return CallToolResult(content=[TextContent(type="text", text="linked")]) - server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) - - async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_url( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: assert isinstance(params, ElicitRequestURLParams) - elicited_ids.append(params.elicitation_id) + elicited_ids.append(params.elicitationId) return ElicitResult(action="accept") async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: @@ -345,9 +330,10 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP # The completion notification refers to the same elicitation the client accepted. assert elicited_ids == [elicitation_id] - assert received == snapshot( - [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitationId="auth-001"))] - ) + assert len(received) == 1 + assert isinstance(received[0], types.ServerNotification) + assert isinstance(received[0].root, ElicitCompleteNotification) + assert received[0].root.params == ElicitCompleteNotificationParams(elicitationId="auth-001") @requirement("elicitation:url:required-error") @@ -360,8 +346,11 @@ async def test_url_elicitation_required_error_carries_pending_elicitations(conne notifications, and retry the original request. """ - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "read_files" + server = Server("authorizer") + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "read_files" raise UrlElicitationRequiredError( [ ElicitRequestURLParams( @@ -372,10 +361,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ] ) - server = Server("authorizer", on_call_tool=call_tool) - async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.call_tool("read_files", {}) assert exc_info.value.error == snapshot( @@ -427,23 +414,23 @@ async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the "required": ["email"], } - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="onboard", description="Onboard the user.", inputSchema={"type": "object"})] - ) + server = Server("onboarder") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "onboard" - answer = await ctx.session.elicit_form("Tell us about yourself.", schema) - return CallToolResult(content=[TextContent(type="text", text=answer.action)]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="onboard", description="Onboard the user.", inputSchema={"type": "object"})] - server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "onboard" + answer = await server.request_context.session.elicit_form("Tell us about yourself.", schema) + return CallToolResult(content=[TextContent(type="text", text=answer.action)]) received: list[types.ElicitRequestParams] = [] - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: received.append(params) return ElicitResult(action="accept", content={"email": "ada@example.com"}) @@ -452,7 +439,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest assert len(received) == 1 assert isinstance(received[0], ElicitRequestFormParams) - assert received[0].requested_schema == schema + assert received[0].requestedSchema == schema @requirement("elicitation:form:schema:restricted-subset") @@ -477,23 +464,23 @@ async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: }, } - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="profile", description="Collect a profile.", inputSchema={"type": "object"})] - ) + server = Server("profiler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "profile" - answer = await ctx.session.elicit_form("Profile details.", schema) - return CallToolResult(content=[TextContent(type="text", text=answer.action)]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="profile", description="Collect a profile.", inputSchema={"type": "object"})] - server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "profile" + answer = await server.request_context.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(type="text", text=answer.action)]) received: list[types.ElicitRequestParams] = [] - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: received.append(params) return ElicitResult(action="decline") @@ -502,7 +489,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest assert len(received) == 1 assert isinstance(received[0], ElicitRequestFormParams) - assert received[0].requested_schema == schema + assert received[0].requestedSchema == schema @requirement("elicitation:form:response-validation") @@ -516,24 +503,24 @@ async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the the requirement), so the handler observes exactly what the callback sent. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="signup", description="Register the user.", inputSchema={"type": "object"})] - ) + server = Server("registrar") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "signup" - answer = await ctx.session.elicit_form( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="signup", description="Register the user.", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "signup" + answer = await server.request_context.session.elicit_form( "Choose a name.", {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, ) return CallToolResult(content=[TextContent(type="text", text=answer.action)], structuredContent=answer.content) - server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) - - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) async with connect(server, elicitation_callback=answer_form) as client: @@ -555,20 +542,19 @@ async def test_elicitation_complete_for_an_unknown_id_is_received_without_error( notification as-is. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="noop", description="Send a stray complete.", inputSchema={"type": "object"})] - ) + server = Server("notifier") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "noop" + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="noop", description="Send a stray complete.", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "noop" + ctx = server.request_context await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) return CallToolResult(content=[TextContent(type="text", text="ok")]) - server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) - received: list[IncomingMessage] = [] async def collect(message: IncomingMessage) -> None: @@ -578,9 +564,10 @@ async def collect(message: IncomingMessage) -> None: result = await client.call_tool("noop", {}) assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="ok")])) - assert received == snapshot( - [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitationId="never-elicited"))] - ) + assert len(received) == 1 + assert isinstance(received[0], types.ServerNotification) + assert isinstance(received[0].root, ElicitCompleteNotification) + assert received[0].root.params == ElicitCompleteNotificationParams(elicitationId="never-elicited") @requirement("elicitation:form:mode-omitted-default") @@ -595,7 +582,9 @@ async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None answered = anyio.Event() server_received: list[JSONRPCMessage] = [] - async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + async def answer_form( + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams + ) -> ElicitResult: received.append(params) return ElicitResult(action="accept", content={}) @@ -603,7 +592,7 @@ async def scripted_server(streams: MessageStream) -> None: server_read, server_write = streams initialize = await server_read.receive() assert isinstance(initialize, SessionMessage) - request = initialize.message + request = initialize.message.root assert isinstance(request, JSONRPCRequest) assert request.method == "initialize" result = InitializeResult( @@ -613,25 +602,29 @@ async def scripted_server(streams: MessageStream) -> None: ) await server_write.send( SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) ) ) ) initialized = await server_read.receive() assert isinstance(initialized, SessionMessage) - assert isinstance(initialized.message, JSONRPCNotification) - assert initialized.message.method == "notifications/initialized" + assert isinstance(initialized.message.root, JSONRPCNotification) + assert initialized.message.root.method == "notifications/initialized" # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. await server_write.send( SessionMessage( - JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="elicitation/create", - params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + JSONRPCMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) ) ) ) @@ -650,17 +643,11 @@ async def scripted_server(streams: MessageStream) -> None: await session.initialize() await answered.wait() - assert received == snapshot( - [ - ElicitRequestFormParams( - _meta=None, - message="Legacy ask.", - requestedSchema={"type": "object", "properties": {}}, - ) - ] - ) + assert len(received) == 1 assert isinstance(received[0], ElicitRequestFormParams) assert received[0].mode == "form" + assert received[0].message == "Legacy ask." + assert received[0].requestedSchema == {"type": "object", "properties": {}} assert len(server_received) == 1 - assert isinstance(server_received[0], JSONRPCResponse) - assert server_received[0].id == 2 + assert isinstance(server_received[0].root, JSONRPCResponse) + assert server_received[0].root.id == 2 diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index b81471e5b9..0260975258 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -1,26 +1,35 @@ -"""Initialization handshake against the low-level Server, driven through the public Client API. +"""Initialization handshake against the low-level Server. -The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always +The first two tests drive a bare ClientSession by hand because v1's `connect` fixture discards the +`initialize()` return value and `ClientSession` does not cache it; capturing `serverInfo` and +`instructions` therefore requires owning the `initialize()` call. The later tests drive a bare +ClientSession over hand-built memory streams for a different reason: the connected session always performs the full handshake with the latest protocol version, so skipping initialization or -requesting a different version can only be expressed one level down. The final test goes one step -further and plays the server's side of the wire by hand, because no real Server can be made to +requesting a different version can only be expressed one level down. The final tests go one step +further and play the server's side of the wire by hand, because no real Server can be made to answer initialize with an unsupported protocol version. """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + import anyio import pytest from inline_snapshot import snapshot +from pydantic import AnyUrl -from mcp import MCPError, types -from mcp.client import ClientRequestContext, ClientSession -from mcp.client._memory import InMemoryTransport -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.shared.context import RequestContext from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( INVALID_PARAMS, CallToolResult, ClientCapabilities, + ClientRequest, CompletionsCapability, EmptyResult, ErrorData, @@ -29,6 +38,7 @@ InitializeRequest, InitializeRequestParams, InitializeResult, + JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, ListToolsRequest, @@ -46,27 +56,61 @@ pytestmark = pytest.mark.anyio +async def _initialize(server: Server[Any]) -> InitializeResult: + """Connect a bare ClientSession to `server` over in-memory streams and return its initialize result. + + v1's `ClientSession` does not cache the initialize result and the `connect` fixture discards it, + so tests that need `serverInfo` or `instructions` own the `initialize()` call themselves. + """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + async with anyio.create_task_group() as tg: + tg.start_soon(lambda: server.run(server_read, server_write, server.create_initialization_options())) + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + initialize_result = await session.initialize() + tg.cancel_scope.cancel() + return initialize_result + + +@asynccontextmanager +async def _bare_session(server: Server[Any]) -> AsyncIterator[ClientSession]: + """Yield an *uninitialized* ClientSession connected to `server` over in-memory streams. + + Unlike the `connect` fixture this does not call `initialize()`, so tests can drive the + handshake (or skip it) themselves. This is the v1 spelling of v2's `InMemoryTransport(server)`. + """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + async with anyio.create_task_group() as tg: + tg.start_soon(lambda: server.run(server_read, server_write, server.create_initialization_options())) + async with ClientSession(client_read, client_write) as session: + yield session + tg.cancel_scope.cancel() + + @requirement("lifecycle:initialize:basic") @requirement("lifecycle:initialize:server-info") -async def test_initialize_returns_server_info(connect: Connect) -> None: - """Every identity field the server declares is returned to the client in server_info.""" +async def test_initialize_returns_server_info() -> None: + """Every identity field the server declares is returned to the client in serverInfo. + + v1's low-level `Server` accepts `name`, `version`, `website_url`, and `icons`; it has no + `title` or `description` arguments, so those fields are absent from the result. + """ server = Server( "greeter", version="1.2.3", - title="Greeter", - description="Greets people.", website_url="https://example.com/greeter", icons=[Icon(src="https://example.com/icon.png", mimeType="image/png", sizes=["48x48"])], ) - async with connect(server) as client: - server_info = client.initialize_result.server_info + initialize_result = await _initialize(server) - assert server_info == snapshot( + assert initialize_result.serverInfo == snapshot( Implementation( name="greeter", - title="Greeter", - description="Greets people.", version="1.2.3", websiteUrl="https://example.com/greeter", icons=[Icon(src="https://example.com/icon.png", mimeType="image/png", sizes=["48x48"])], @@ -75,13 +119,13 @@ async def test_initialize_returns_server_info(connect: Connect) -> None: @requirement("lifecycle:initialize:instructions") -async def test_initialize_returns_instructions(connect: Connect) -> None: +async def test_initialize_returns_instructions() -> None: """Instructions are returned when the server declares them and omitted when it does not.""" - async with connect(Server("guided", instructions="Call the add tool.")) as client: - assert client.initialize_result.instructions == snapshot("Call the add tool.") + initialize_result = await _initialize(Server("guided", instructions="Call the add tool.")) + assert initialize_result.instructions == snapshot("Call the add tool.") - async with connect(Server("unguided")) as client: - assert client.initialize_result.instructions is None + initialize_result = await _initialize(Server("unguided")) + assert initialize_result.instructions is None @requirement("lifecycle:initialize:capabilities:from-handlers") @@ -92,59 +136,56 @@ async def test_initialize_returns_instructions(connect: Connect) -> None: async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: """Each feature area with a registered handler is advertised as a capability. - The in-memory transport connects with default initialization options, so the - list_changed flags are always False regardless of the server's notification behaviour. + The `connect` fixture uses default initialization options, so the listChanged flags are always + False regardless of the server's notification behaviour. v1 also hard-codes `subscribe=False` + even when a `subscribe_resource` handler is registered; the handler is registered here anyway + to pin that divergence. """ + server = Server("full") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: + @server.list_tools() + async def list_tools() -> list[types.Tool]: """Registered only so the tools capability is advertised; never called.""" raise NotImplementedError - async def list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListResourcesResult: + @server.list_resources() + async def list_resources() -> list[types.Resource]: """Registered only so the resources capability is advertised; never called.""" raise NotImplementedError - async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: - """Registered only so the subscribe sub-capability is advertised; never called.""" + @server.subscribe_resource() + async def subscribe_resource(uri: AnyUrl) -> None: + """Registered to show v1 still advertises subscribe=False; never called.""" raise NotImplementedError - async def list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: + @server.list_prompts() + async def list_prompts() -> list[types.Prompt]: """Registered only so the prompts capability is advertised; never called.""" raise NotImplementedError - async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> types.EmptyResult: + @server.set_logging_level() + async def set_logging_level(level: types.LoggingLevel) -> None: """Registered only so the logging capability is advertised; never called.""" raise NotImplementedError - async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + @server.completion() + async def completion( + ref: types.PromptReference | types.ResourceTemplateReference, + argument: types.CompletionArgument, + context: types.CompletionContext | None, + ) -> types.Completion | None: """Registered only so the completions capability is advertised; never called.""" raise NotImplementedError - server = Server( - "full", - on_list_tools=list_tools, - on_list_resources=list_resources, - on_subscribe_resource=subscribe_resource, - on_list_prompts=list_prompts, - on_set_logging_level=set_logging_level, - on_completion=completion, - ) - async with connect(server) as client: - capabilities = client.initialize_result.capabilities + capabilities = client.get_server_capabilities() assert capabilities == snapshot( ServerCapabilities( experimental={}, logging=LoggingCapability(), prompts=PromptsCapability(listChanged=False), - resources=ResourcesCapability(subscribe=True, listChanged=False), + resources=ResourcesCapability(subscribe=False, listChanged=False), tools=ToolsCapability(listChanged=False), completions=CompletionsCapability(), ) @@ -155,29 +196,28 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: """A server with no feature handlers advertises no feature capabilities.""" async with connect(Server("bare")) as client: - capabilities = client.initialize_result.capabilities + capabilities = client.get_server_capabilities() assert capabilities == snapshot(ServerCapabilities(experimental={})) @requirement("lifecycle:initialize:client-info") async def test_initialize_server_sees_client_info(connect: Connect) -> None: - """The client identity supplied to Client is visible to server handlers after initialization.""" + """The client identity supplied to ClientSession is visible to server handlers after initialization.""" + server = Server("introspector") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="whoami", description="Report the caller.", inputSchema={"type": "object"})] - ) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="whoami", description="Report the caller.", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "whoami" - assert ctx.session.client_params is not None - client_info = ctx.session.client_params.client_info - return CallToolResult(content=[TextContent(type="text", text=f"{client_info.name} {client_info.version}")]) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "whoami" + client_params = server.request_context.session.client_params + assert client_params is not None + client_info = client_params.clientInfo + return [TextContent(type="text", text=f"{client_info.name} {client_info.version}")] - server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: result = await client.call_tool("whoami", {}) @@ -187,36 +227,34 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("lifecycle:initialize:client-capabilities") async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: """The client capabilities visible to the server reflect which callbacks the client configured.""" - - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="abilities", description="Report capabilities.", inputSchema={"type": "object"})] - ) - - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "abilities" - assert ctx.session.client_params is not None - capabilities = ctx.session.client_params.capabilities + server = Server("introspector") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="abilities", description="Report capabilities.", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "abilities" + client_params = server.request_context.session.client_params + assert client_params is not None + capabilities = client_params.capabilities declared = [ - name - for name, value in ( + label + for label, value in ( ("sampling", capabilities.sampling), ("elicitation", capabilities.elicitation), ) if value is not None ] if capabilities.roots is not None: - declared.append(f"roots(list_changed={capabilities.roots.list_changed})") - return CallToolResult(content=[TextContent(type="text", text=",".join(declared) or "none")]) + declared.append(f"roots(list_changed={capabilities.roots.listChanged})") + return [TextContent(type="text", text=",".join(declared) or "none")] - async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + async def list_roots(context: RequestContext[ClientSession, Any]) -> types.ListRootsResult | types.ErrorData: """Registered only so the client declares the roots capability; never called.""" raise NotImplementedError - server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("abilities", {}) assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="none")])) @@ -230,24 +268,21 @@ async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: async def test_request_before_initialization_is_rejected() -> None: """A feature request sent before the handshake completes is rejected; ping is exempt. - Client always initializes on entry, so this drives a bare ClientSession that never sends - initialize. The server's stated reason for the rejection never reaches the client: the error - is reported as a generic invalid-params failure. + The `connect` fixture always initializes on entry, so this drives a bare ClientSession that + never sends initialize. The server's stated reason for the rejection never reaches the client: + the error is reported as a generic invalid-params failure. """ + server = Server("strict") - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + @server.list_tools() + async def list_tools() -> list[types.Tool]: """Registered so the request is routed to a real handler; never reached.""" raise NotImplementedError - server = Server("strict", on_list_tools=list_tools) - - async with ( - InMemoryTransport(server) as (read_stream, write_stream), - ClientSession(read_stream, write_stream) as session, - ): + async with _bare_session(server) as session: with anyio.fail_after(5): - with pytest.raises(MCPError) as exc_info: - await session.send_request(ListToolsRequest(), ListToolsResult) + with pytest.raises(McpError) as exc_info: + await session.send_request(ClientRequest(ListToolsRequest()), ListToolsResult) # Ping is explicitly permitted before initialization completes. pong = await session.send_ping() @@ -268,30 +303,26 @@ async def test_initialize_negotiates_protocol_version() -> None: """ server = Server("negotiator") - def initialize_request(protocol_version: str) -> InitializeRequest: - return InitializeRequest( - params=InitializeRequestParams( - protocolVersion=protocol_version, - capabilities=ClientCapabilities(), - clientInfo=Implementation(name="time-traveller", version="0.0.1"), + def initialize_request(protocol_version: str) -> ClientRequest: + return ClientRequest( + InitializeRequest( + params=InitializeRequestParams( + protocolVersion=protocol_version, + capabilities=ClientCapabilities(), + clientInfo=Implementation(name="time-traveller", version="0.0.1"), + ) ) ) - async with ( - InMemoryTransport(server) as (read_stream, write_stream), - ClientSession(read_stream, write_stream) as session, - ): + async with _bare_session(server) as session: with anyio.fail_after(5): result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) - assert result.protocol_version == snapshot("2025-03-26") + assert result.protocolVersion == snapshot("2025-03-26") - async with ( - InMemoryTransport(server) as (read_stream, write_stream), - ClientSession(read_stream, write_stream) as session, - ): + async with _bare_session(server) as session: with anyio.fail_after(5): result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) - assert result.protocol_version == snapshot("2025-11-25") + assert result.protocolVersion == snapshot("2025-11-25") @requirement("lifecycle:version:reject-unsupported") @@ -308,7 +339,7 @@ async def scripted_server(streams: MessageStream) -> None: server_read, server_write = streams message = await server_read.receive() assert isinstance(message, SessionMessage) - request = message.message + request = message.message.root assert isinstance(request, JSONRPCRequest) assert request.method == "initialize" result = InitializeResult( @@ -318,11 +349,13 @@ async def scripted_server(streams: MessageStream) -> None: ) await server_write.send( SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request.id, - # Serialized exactly as a real server serializes results onto the wire. - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) ) ) ) @@ -353,7 +386,7 @@ async def scripted_server(streams: MessageStream) -> None: server_read, server_write = streams message = await server_read.receive() assert isinstance(message, SessionMessage) - request = message.message + request = message.message.root assert isinstance(request, JSONRPCRequest) assert request.method == "initialize" result = InitializeResult( @@ -363,11 +396,13 @@ async def scripted_server(streams: MessageStream) -> None: ) await server_write.send( SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request.id, - # Serialized exactly as a real server serializes results onto the wire. - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) ) ) ) @@ -381,4 +416,4 @@ async def scripted_server(streams: MessageStream) -> None: with anyio.fail_after(5): initialize_result = await session.initialize() - assert initialize_result.protocol_version == snapshot("2025-06-18") + assert initialize_result.protocolVersion == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index c0449fea9b..31b563e015 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -4,8 +4,8 @@ notification routes to the standalone GET stream and is not guaranteed to arrive before the tool result on its POST stream. Tests therefore wait on an event the collector sets, the same pattern as ``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. -The collector still records every message it receives, so the snapshot also proves nothing else -was delivered. +The collector still records every message it receives, so the length assertion also proves nothing +else was delivered. The servers register the parent capability (resources/prompts) so that part of the spec's precondition holds, but the ``listChanged`` sub-capability stays ``False``: ``NotificationOptions`` @@ -14,16 +14,17 @@ alongside the fix that introduces capability gating. """ +from typing import Any + import anyio import pytest -from inline_snapshot import snapshot from mcp import types -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.types import ( - CallToolResult, PromptListChangedNotification, ResourceListChangedNotification, + ServerNotification, TextContent, ToolListChangedNotification, ) @@ -44,24 +45,26 @@ async def collect(message: IncomingMessage) -> None: received.append(message) seen.set() - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="install", inputSchema={"type": "object"})]) + server = Server("registry") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "install" - await ctx.session.send_tool_list_changed() - return CallToolResult(content=[TextContent(type="text", text="installed")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="install", inputSchema={"type": "object"})] - server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "install" + await server.request_context.session.send_tool_list_changed() + return [TextContent(type="text", text="installed")] async with connect(server, message_handler=collect) as client: await client.call_tool("install", {}) with anyio.fail_after(5): await seen.wait() - assert received == snapshot([ToolListChangedNotification()]) + assert len(received) == 1 + assert isinstance(received[0], ServerNotification) + assert isinstance(received[0].root, ToolListChangedNotification) @requirement("resources:list-changed") @@ -74,30 +77,31 @@ async def collect(message: IncomingMessage) -> None: received.append(message) seen.set() - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="mount", inputSchema={"type": "object"})]) + server = Server("registry") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="mount", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "mount" - await ctx.session.send_resource_list_changed() - return CallToolResult(content=[TextContent(type="text", text="mounted")]) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "mount" + await server.request_context.session.send_resource_list_changed() + return [TextContent(type="text", text="mounted")] - async def list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListResourcesResult: + @server.list_resources() + async def list_resources() -> list[types.Resource]: """Registered so the resources capability is advertised; the client never lists resources.""" raise NotImplementedError - server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_resources=list_resources) - async with connect(server, message_handler=collect) as client: await client.call_tool("mount", {}) with anyio.fail_after(5): await seen.wait() - assert received == snapshot([ResourceListChangedNotification()]) + assert len(received) == 1 + assert isinstance(received[0], ServerNotification) + assert isinstance(received[0].root, ResourceListChangedNotification) @requirement("prompts:list-changed") @@ -110,27 +114,28 @@ async def collect(message: IncomingMessage) -> None: received.append(message) seen.set() - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="learn", inputSchema={"type": "object"})]) + server = Server("registry") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "learn" - await ctx.session.send_prompt_list_changed() - return CallToolResult(content=[TextContent(type="text", text="learned")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="learn", inputSchema={"type": "object"})] - async def list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "learn" + await server.request_context.session.send_prompt_list_changed() + return [TextContent(type="text", text="learned")] + + @server.list_prompts() + async def list_prompts() -> list[types.Prompt]: """Registered so the prompts capability is advertised; the client never lists prompts.""" raise NotImplementedError - server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_prompts=list_prompts) - async with connect(server, message_handler=collect) as client: await client.call_tool("learn", {}) with anyio.fail_after(5): await seen.wait() - assert received == snapshot([PromptListChangedNotification()]) + assert len(received) == 1 + assert isinstance(received[0], ServerNotification) + assert isinstance(received[0].root, PromptListChangedNotification) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index e2c2ba2612..ffc281ca27 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -8,17 +8,20 @@ import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.server import Server, ServerRequestContext +from mcp import McpError +from mcp.server import Server from mcp.types import ( INVALID_PARAMS, + ErrorData, + ListPromptsRequest, ListPromptsResult, + ListResourcesRequest, ListResourcesResult, - ListResourceTemplatesResult, + ListToolsRequest, ListToolsResult, + PaginatedRequestParams, Prompt, Resource, - ResourceTemplate, Tool, ) from tests.interaction._connect import Connect @@ -35,23 +38,24 @@ async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> N cursor = "page-2" seen_cursors: list[str | None] = [] - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - assert params is not None # the client always sends params, even without a cursor - seen_cursors.append(params.cursor) - if params.cursor is None: + server = Server("paginated") + + @server.list_tools() + async def list_tools(request: ListToolsRequest) -> ListToolsResult: + received = request.params.cursor if request.params is not None else None + seen_cursors.append(received) + if received is None: return ListToolsResult( tools=[Tool(name="alpha", inputSchema={"type": "object"})], nextCursor=cursor, ) return ListToolsResult(tools=[Tool(name="beta", inputSchema={"type": "object"})]) - server = Server("paginated", on_list_tools=list_tools) - async with connect(server) as client: first_page = await client.list_tools() - second_page = await client.list_tools(cursor=first_page.next_cursor) + second_page = await client.list_tools(params=PaginatedRequestParams(cursor=first_page.nextCursor)) - assert first_page.next_cursor == cursor + assert first_page.nextCursor == cursor assert seen_cursors == [None, cursor] assert [tool.name for tool in first_page.tools] == ["alpha"] assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", inputSchema={"type": "object"})])) @@ -67,25 +71,26 @@ async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: "page-3": ("gamma", None), } - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - assert params is not None - tool_name, next_cursor = pages[params.cursor] - return ListToolsResult(tools=[Tool(name=tool_name, inputSchema={"type": "object"})], nextCursor=next_cursor) + server = Server("paginated") - server = Server("paginated", on_list_tools=list_tools) + @server.list_tools() + async def list_tools(request: ListToolsRequest) -> ListToolsResult: + received = request.params.cursor if request.params is not None else None + tool_name, next_cursor = pages[received] + return ListToolsResult(tools=[Tool(name=tool_name, inputSchema={"type": "object"})], nextCursor=next_cursor) collected: list[str] = [] cursor: str | None = None requests_made = 0 async with connect(server) as client: while True: - result = await client.list_tools(cursor=cursor) + result = await client.list_tools(params=PaginatedRequestParams(cursor=cursor)) requests_made += 1 assert requests_made <= len(pages), "the server kept returning next_cursor past the last page" collected.extend(tool.name for tool in result.tools) - if result.next_cursor is None: + if result.nextCursor is None: break - cursor = result.next_cursor + cursor = result.nextCursor assert collected == snapshot(["alpha", "beta", "gamma"]) assert requests_made == len(pages) @@ -108,25 +113,26 @@ async def test_the_client_follows_opaque_cursors_through_pages_of_varying_sizes( } received_cursors: list[str | None] = [] - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - assert params is not None - received_cursors.append(params.cursor) - names, next_cursor = pages[params.cursor] + server = Server("paginated") + + @server.list_tools() + async def list_tools(request: ListToolsRequest) -> ListToolsResult: + received = request.params.cursor if request.params is not None else None + received_cursors.append(received) + names, next_cursor = pages[received] return ListToolsResult( tools=[Tool(name=name, inputSchema={"type": "object"}) for name in names], nextCursor=next_cursor ) - server = Server("paginated", on_list_tools=list_tools) - page_sizes: list[int] = [] cursor: str | None = None async with connect(server) as client: while True: - result = await client.list_tools(cursor=cursor) + result = await client.list_tools(params=PaginatedRequestParams(cursor=cursor)) page_sizes.append(len(result.tools)) - if result.next_cursor is None: + if result.nextCursor is None: break - cursor = result.next_cursor + cursor = result.nextCursor # Identity, not a snapshot: what arrived at the handler is exactly what the handler issued. assert received_cursors == [None, cursor_to_page_2, cursor_to_page_3] @@ -141,16 +147,17 @@ async def test_an_unrecognized_pagination_cursor_is_rejected_with_invalid_params unrecognized cursor is the handler's job, and this test pins the spec-recommended way to do it. """ - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - assert params is not None - assert params.cursor == "never-issued" - raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {params.cursor!r}") + server = Server("paginated") - server = Server("paginated", on_list_tools=list_tools) + @server.list_tools() + async def list_tools(request: ListToolsRequest) -> ListToolsResult: + assert request.params is not None + assert request.params.cursor == "never-issued" + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown cursor: {request.params.cursor!r}")) async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: - await client.list_tools(cursor="never-issued") + with pytest.raises(McpError) as exc_info: + await client.list_tools(params=PaginatedRequestParams(cursor="never-issued")) assert exc_info.value.error.code == INVALID_PARAMS @@ -161,59 +168,25 @@ async def test_resources_list_supports_cursor_pagination(connect: Connect) -> No cursor = "page-2" seen_cursors: list[str | None] = [] - async def list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> ListResourcesResult: - assert params is not None - seen_cursors.append(params.cursor) - if params.cursor is None: + server = Server("paginated") + + @server.list_resources() + async def list_resources(request: ListResourcesRequest) -> ListResourcesResult: + received = request.params.cursor if request.params is not None else None + seen_cursors.append(received) + if received is None: return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], nextCursor=cursor) return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) - server = Server("paginated", on_list_resources=list_resources) - async with connect(server) as client: first_page = await client.list_resources() - second_page = await client.list_resources(cursor=first_page.next_cursor) + second_page = await client.list_resources(params=PaginatedRequestParams(cursor=first_page.nextCursor)) - assert first_page.next_cursor == cursor + assert first_page.nextCursor == cursor assert seen_cursors == [None, cursor] assert [resource.name for resource in first_page.resources] == ["first"] assert [resource.name for resource in second_page.resources] == ["second"] - assert second_page.next_cursor is None - - -@requirement("resources:templates:pagination") -async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: - """resources/templates/list round-trips the cursor like every other list operation.""" - cursor = "page-2" - seen_cursors: list[str | None] = [] - - async def list_resource_templates( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> ListResourceTemplatesResult: - assert params is not None - seen_cursors.append(params.cursor) - if params.cursor is None: - return ListResourceTemplatesResult( - resourceTemplates=[ResourceTemplate(name="first", uriTemplate="users://{id}")], - nextCursor=cursor, - ) - return ListResourceTemplatesResult( - resourceTemplates=[ResourceTemplate(name="second", uriTemplate="teams://{id}")] - ) - - server = Server("paginated", on_list_resource_templates=list_resource_templates) - - async with connect(server) as client: - first_page = await client.list_resource_templates() - second_page = await client.list_resource_templates(cursor=first_page.next_cursor) - - assert first_page.next_cursor == cursor - assert seen_cursors == [None, cursor] - assert [template.name for template in first_page.resource_templates] == ["first"] - assert [template.name for template in second_page.resource_templates] == ["second"] - assert second_page.next_cursor is None + assert second_page.nextCursor is None @requirement("prompts:list:pagination") @@ -222,21 +195,22 @@ async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None cursor = "page-2" seen_cursors: list[str | None] = [] - async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: - assert params is not None - seen_cursors.append(params.cursor) - if params.cursor is None: + server = Server("paginated") + + @server.list_prompts() + async def list_prompts(request: ListPromptsRequest) -> ListPromptsResult: + received = request.params.cursor if request.params is not None else None + seen_cursors.append(received) + if received is None: return ListPromptsResult(prompts=[Prompt(name="first")], nextCursor=cursor) return ListPromptsResult(prompts=[Prompt(name="second")]) - server = Server("paginated", on_list_prompts=list_prompts) - async with connect(server) as client: first_page = await client.list_prompts() - second_page = await client.list_prompts(cursor=first_page.next_cursor) + second_page = await client.list_prompts(params=PaginatedRequestParams(cursor=first_page.nextCursor)) - assert first_page.next_cursor == cursor + assert first_page.nextCursor == cursor assert seen_cursors == [None, cursor] assert [prompt.name for prompt in first_page.prompts] == ["first"] assert [prompt.name for prompt in second_page.prompts] == ["second"] - assert second_page.next_cursor is None + assert second_page.nextCursor is None diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 8a525898d5..4202b7c6b7 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -1,28 +1,29 @@ """Resource interactions against the low-level Server, driven through the public Client API.""" -import base64 +from typing import Any import anyio import pytest from inline_snapshot import snapshot +from pydantic import AnyUrl -from mcp import MCPError, types -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.server.lowlevel import Server +from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.types import ( METHOD_NOT_FOUND, Annotations, BlobResourceContents, - CallToolResult, EmptyResult, ErrorData, Icon, - ListResourcesResult, ListResourceTemplatesResult, ReadResourceResult, Resource, ResourceTemplate, ResourceUpdatedNotification, ResourceUpdatedNotificationParams, + ServerNotification, TextContent, TextResourceContents, ) @@ -40,68 +41,66 @@ async def test_list_resources_returns_registered_resources(connect: Connect) -> The fully-populated entry includes annotations, so the snapshot also proves they round-trip. The SDK's Annotations model omits the schema's lastModified field (see the divergence on - resources:annotations); the input is built via model_validate with lastModified set so the - snapshot pins the drop and will fail once the SDK adds the field. + resources:annotations) but allows extra fields, so lastModified round-trips as an undeclared + extra; the snapshot compares the serialised dict so the extra key is visible and pinned. """ - - async def list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> ListResourcesResult: - return ListResourcesResult( - resources=[ - Resource(uri="memo://minimal", name="minimal"), - Resource( - uri="file:///project/README.md", - name="readme", - title="Project README", - description="The project's front page.", - mimeType="text/markdown", - size=1024, - annotations=Annotations.model_validate( - {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} - ), - icons=[Icon(src="https://example.com/readme.png", mimeType="image/png", sizes=["48x48"])], + server = Server("library") + + @server.list_resources() + async def list_resources() -> list[Resource]: + return [ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mimeType="text/markdown", + size=1024, + annotations=Annotations.model_validate( + {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} ), - ] - ) - - server = Server("library", on_list_resources=list_resources) + icons=[Icon(src="https://example.com/readme.png", mimeType="image/png", sizes=["48x48"])], + ), + ] async with connect(server) as client: result = await client.list_resources() - assert result == snapshot( - ListResourcesResult( - resources=[ - Resource(uri="memo://minimal", name="minimal"), - Resource( - uri="file:///project/README.md", - name="readme", - title="Project README", - description="The project's front page.", - mimeType="text/markdown", - size=1024, - annotations=Annotations(audience=["user", "assistant"], priority=0.8), - icons=[Icon(src="https://example.com/readme.png", mimeType="image/png", sizes=["48x48"])], - ), + assert result.model_dump(by_alias=True, exclude_none=True) == snapshot( + { + "resources": [ + {"name": "minimal", "uri": AnyUrl("memo://minimal")}, + { + "name": "readme", + "title": "Project README", + "uri": AnyUrl("file:///project/README.md"), + "description": "The project's front page.", + "mimeType": "text/markdown", + "size": 1024, + "icons": [{"src": "https://example.com/readme.png", "mimeType": "image/png", "sizes": ["48x48"]}], + "annotations": { + "audience": ["user", "assistant"], + "priority": 0.8, + "lastModified": "2025-01-01T00:00:00Z", + }, + }, ] - ) + } ) @requirement("resources:read:text") async def test_read_resource_text(connect: Connect) -> None: """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" + server = Server("library") - async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: - return ReadResourceResult( - contents=[TextResourceContents(uri=params.uri, mimeType="text/plain", text="Hello, world!")] - ) - - server = Server("library", on_read_resource=read_resource) + @server.read_resource() + async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: + return [ReadResourceContents(content="Hello, world!", mime_type="text/plain")] async with connect(server) as client: - result = await client.read_resource("file:///greeting.txt") + result = await client.read_resource(AnyUrl("file:///greeting.txt")) assert result == snapshot( ReadResourceResult( @@ -112,23 +111,18 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq @requirement("resources:read:blob") async def test_read_resource_binary(connect: Connect) -> None: - """Reading a binary resource returns its contents base64-encoded in the blob field.""" - - async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: - return ReadResourceResult( - contents=[ - BlobResourceContents( - uri=params.uri, - mimeType="image/png", - blob=base64.b64encode(b"\x89PNG").decode(), - ) - ] - ) + """Reading a binary resource returns its contents base64-encoded in the blob field. - server = Server("library", on_read_resource=read_resource) + The low-level decorator base64-encodes the bytes returned via ``ReadResourceContents``. + """ + server = Server("library") + + @server.read_resource() + async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: + return [ReadResourceContents(content=b"\x89PNG", mime_type="image/png")] async with connect(server) as client: - result = await client.read_resource("file:///pixel.png") + result = await client.read_resource(AnyUrl("file:///pixel.png")) assert result == snapshot( ReadResourceResult( @@ -139,20 +133,20 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq @requirement("resources:read:unknown-uri") async def test_read_resource_unknown_uri_is_protocol_error(connect: Connect) -> None: - """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. + """A handler that rejects an unrecognised URI with McpError produces a JSON-RPC error. The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches the client verbatim. """ + server = Server("library") - async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: - raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") - - server = Server("library", on_read_resource=read_resource) + @server.read_resource() + async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: + raise McpError(ErrorData(code=-32002, message=f"Resource not found: {uri}")) async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: - await client.read_resource("file:///missing.txt") + with pytest.raises(McpError) as exc_info: + await client.read_resource(AnyUrl("file:///missing.txt")) assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) @@ -160,25 +154,21 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq @requirement("resources:templates:list") async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" - - async def list_resource_templates( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> ListResourceTemplatesResult: - return ListResourceTemplatesResult( - resourceTemplates=[ - ResourceTemplate(uriTemplate="users://{user_id}", name="user"), - ResourceTemplate( - uriTemplate="logs://{service}/{date}", - name="service_logs", - title="Service logs", - description="One day of logs for one service.", - mimeType="text/plain", - icons=[Icon(src="https://example.com/logs.png", mimeType="image/png", sizes=["48x48"])], - ), - ] - ) - - server = Server("library", on_list_resource_templates=list_resource_templates) + server = Server("library") + + @server.list_resource_templates() + async def list_resource_templates() -> list[ResourceTemplate]: + return [ + ResourceTemplate(uriTemplate="users://{user_id}", name="user"), + ResourceTemplate( + uriTemplate="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mimeType="text/plain", + icons=[Icon(src="https://example.com/logs.png", mimeType="image/png", sizes=["48x48"])], + ), + ] async with connect(server) as client: result = await client.list_resource_templates() @@ -203,15 +193,14 @@ async def list_resource_templates( @requirement("resources:subscribe") async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + server = Server("library") - async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: - assert params.uri == "file:///watched.txt" - return EmptyResult() - - server = Server("library", on_subscribe_resource=subscribe_resource) + @server.subscribe_resource() + async def subscribe_resource(uri: AnyUrl) -> None: + assert uri == AnyUrl("file:///watched.txt") async with connect(server) as client: - result = await client.subscribe_resource("file:///watched.txt") + result = await client.subscribe_resource(AnyUrl("file:///watched.txt")) assert result == snapshot(EmptyResult()) @@ -224,17 +213,16 @@ async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect divergence on lifecycle:capability:server-not-advertised. """ - async def list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> ListResourcesResult: + server = Server("library") + + @server.list_resources() + async def list_resources() -> list[Resource]: """Registered only so the resources capability is advertised; never called.""" raise NotImplementedError - server = Server("library", on_list_resources=list_resources) - async with connect(server) as client: - with pytest.raises(MCPError) as exc_info: - await client.subscribe_resource("file:///watched.txt") + with pytest.raises(McpError) as exc_info: + await client.subscribe_resource(AnyUrl("file:///watched.txt")) assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) @@ -242,15 +230,14 @@ async def list_resources( @requirement("resources:unsubscribe") async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + server = Server("library") - async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: - assert params.uri == "file:///watched.txt" - return EmptyResult() - - server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + @server.unsubscribe_resource() + async def unsubscribe_resource(uri: AnyUrl) -> None: + assert uri == AnyUrl("file:///watched.txt") async with connect(server) as client: - result = await client.unsubscribe_resource("file:///watched.txt") + result = await client.unsubscribe_resource(AnyUrl("file:///watched.txt")) assert result == snapshot(EmptyResult()) @@ -271,39 +258,34 @@ async def collect(message: IncomingMessage) -> None: received.append(message) seen.set() - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="touch", inputSchema={"type": "object"})]) + server = Server("library") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "touch" - await ctx.session.send_resource_updated("file:///watched.txt") - return CallToolResult(content=[TextContent(type="text", text="touched")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="touch", inputSchema={"type": "object"})] - async def list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> ListResourcesResult: + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "touch" + await server.request_context.session.send_resource_updated(AnyUrl("file:///watched.txt")) + return [TextContent(type="text", text="touched")] + + @server.list_resources() + async def list_resources() -> list[Resource]: """Registered so the resources capability is advertised; the client never lists resources.""" raise NotImplementedError - async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + @server.subscribe_resource() + async def subscribe_resource(uri: AnyUrl) -> None: """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" raise NotImplementedError - server = Server( - "library", - on_list_tools=list_tools, - on_call_tool=call_tool, - on_list_resources=list_resources, - on_subscribe_resource=subscribe_resource, - ) - async with connect(server, message_handler=collect) as client: await client.call_tool("touch", {}) with anyio.fail_after(5): await seen.wait() - assert received == snapshot( - [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] - ) + assert len(received) == 1 + assert isinstance(received[0], ServerNotification) + assert isinstance(received[0].root, ResourceUpdatedNotification) + assert received[0].root.params == snapshot(ResourceUpdatedNotificationParams(uri=AnyUrl("file:///watched.txt"))) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index 5bebd9158d..14ac80d46e 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -1,13 +1,15 @@ -"""Roots interactions against the low-level Server, driven through the public Client API.""" +"""Roots interactions against the low-level Server, driven through the public ClientSession API.""" + +from typing import Any -import anyio import pytest from inline_snapshot import snapshot from pydantic import FileUrl -from mcp import MCPError, types -from mcp.client import ClientRequestContext -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.shared.context import RequestContext from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent from tests.interaction._connect import Connect from tests.interaction._requirements import requirement @@ -21,21 +23,20 @@ async def test_list_roots_round_trip(connect: Connect) -> None: The tool reports the URIs and names it received, proving the client's roots reached the server. """ + server = Server("rooted") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="show_roots", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="show_roots", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "show_roots" - result = await ctx.session.list_roots() + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "show_roots" + result = await server.request_context.session.list_roots() lines = [f"{root.uri} name={root.name}" for root in result.roots] - return CallToolResult(content=[TextContent(type="text", text="\n".join(lines))]) - - server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text="\n".join(lines))] - async def list_roots(context: ClientRequestContext) -> ListRootsResult: + async def list_roots(context: RequestContext[ClientSession, Any]) -> ListRootsResult | ErrorData: return ListRootsResult( roots=[ Root(uri=FileUrl("file:///home/alice/project"), name="project"), @@ -60,20 +61,19 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: @requirement("roots:list:empty") async def test_list_roots_empty(connect: Connect) -> None: """A client with no roots to offer answers roots/list with an empty list, not an error.""" + server = Server("rooted") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="count_roots", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="count_roots", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "count_roots" - result = await ctx.session.list_roots() - return CallToolResult(content=[TextContent(type="text", text=str(len(result.roots)))]) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "count_roots" + result = await server.request_context.session.list_roots() + return [TextContent(type="text", text=str(len(result.roots)))] - server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) - - async def list_roots(context: ClientRequestContext) -> ListRootsResult: + async def list_roots(context: RequestContext[ClientSession, Any]) -> ListRootsResult | ErrorData: return ListRootsResult(roots=[]) async with connect(server, list_roots_callback=list_roots) as client: @@ -89,22 +89,21 @@ async def test_list_roots_without_callback_is_error(connect: Connect) -> None: The client's default callback answers with INVALID_REQUEST rather than leaving the server hanging; the spec names -32601 for this case (see the divergence note on the requirement). """ + server = Server("rooted") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="show_roots", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="show_roots", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "show_roots" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "show_roots" try: - await ctx.session.list_roots() - except MCPError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) + await server.request_context.session.list_roots() + except McpError as exc: + return [TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")] raise NotImplementedError # list_roots cannot succeed without a client callback - server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("show_roots", {}) @@ -117,56 +116,27 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: """A roots callback that answers with an error fails the roots/list request with that exact error. - The callback's code and message reach the requesting handler verbatim as an MCPError. + The callback's code and message reach the requesting handler verbatim as a McpError. """ + server = Server("rooted") - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="show_roots", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="show_roots", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "show_roots" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "show_roots" try: - await ctx.session.list_roots() - except MCPError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) + await server.request_context.session.list_roots() + except McpError as exc: + return [TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")] raise NotImplementedError # the callback always answers with an error - server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) - - async def list_roots(context: ClientRequestContext) -> ErrorData: + async def list_roots(context: RequestContext[ClientSession, Any]) -> ListRootsResult | ErrorData: return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("show_roots", {}) assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="-32603: roots provider crashed")])) - - -@requirement("roots:list-changed") -async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: - """A roots/list_changed notification from the client is delivered to the server's handler. - - Unlike a request, a notification has no response to await: the handler sets an event and the - test waits on it, which is the only synchronisation point proving delivery. - """ - delivered = anyio.Event() - received: list[types.NotificationParams | None] = [] - - async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: - received.append(params) - delivered.set() - - server = Server("rooted", on_roots_list_changed=roots_list_changed) - - async def list_roots(context: ClientRequestContext) -> ListRootsResult: - """Registered so the client declares the roots capability; the server never asks for roots.""" - raise NotImplementedError - - async with connect(server, list_roots_callback=list_roots) as client: - await client.send_roots_list_changed() - with anyio.fail_after(5): - await delivered.wait() - - assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 3ee68f5d9e..60e09a4fd9 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -5,13 +5,16 @@ round-trips what it received back to the test through its tool result. """ +from typing import Any + import pydantic import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.client import ClientRequestContext -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.shared.context import RequestContext from mcp.types import ( AudioContent, CallToolResult, @@ -42,27 +45,25 @@ async def test_create_message_round_trip(connect: Connect) -> None: """ received: list[CreateMessageRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask_model", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask_model" - result = await ctx.session.create_message( + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "ask_model" + result = await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult( - content=[TextContent(type="text", text=f"{result.model}/{result.stop_reason}: {result.content.text}")] - ) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=f"{result.model}/{result.stopReason}: {result.content.text}")] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: received.append(params) return CreateMessageResult( role="assistant", @@ -80,7 +81,6 @@ async def sampling_callback( assert received == snapshot( [ CreateMessageRequestParams( - _meta={}, messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], maxTokens=100, ) @@ -101,14 +101,16 @@ async def test_create_message_params_reach_callback(connect: Connect) -> None: """ received: list[CreateMessageRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask_model", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask_model" - result = await ctx.session.create_message( + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "ask_model" + result = await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Pick a model."))], max_tokens=50, system_prompt="You are terse.", @@ -123,13 +125,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ), ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=result.content.text)] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: received.append(params) return CreateMessageResult(role="assistant", content=TextContent(type="text", text="ok"), model="mock-llm-1") @@ -140,7 +140,6 @@ async def sampling_callback( assert received == snapshot( [ CreateMessageRequestParams( - _meta={}, messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Pick a model."))], modelPreferences=ModelPreferences( hints=[ModelHint(name="claude"), ModelHint(name="gpt")], @@ -152,7 +151,13 @@ async def sampling_callback( includeContext="thisServer", temperature=0.7, maxTokens=50, - stopSequences=["\n\n", "END"], + stopSequences=[ + """\ + + +""", + "END", + ], ) ] ) @@ -167,33 +172,33 @@ async def test_create_message_request_with_image_content_reaches_callback(connec """ received: list[CreateMessageRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="describe_image", inputSchema={"type": "object"})]) + server = Server("sampler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "describe_image" - result = await ctx.session.create_message( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="describe_image", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "describe_image" + result = await server.request_context.session.create_message( messages=[ SamplingMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")) ], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=result.content.text)] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: received.append(params) image = params.messages[0].content assert isinstance(image, ImageContent) return CreateMessageResult( role="assistant", - content=TextContent(type="text", text=f"described {image.mime_type} ({image.data})"), + content=TextContent(type="text", text=f"described {image.mimeType} ({image.data})"), model="mock-vision-1", ) @@ -204,7 +209,6 @@ async def sampling_callback( assert received == snapshot( [ CreateMessageRequestParams( - _meta={}, messages=[ SamplingMessage(role="user", content=ImageContent(type="image", data="aW1n", mimeType="image/png")) ], @@ -221,28 +225,26 @@ async def test_create_message_result_with_image_content_returns_to_handler(conne This is the client-to-server direction: the model's response is an image rather than text. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="draw", inputSchema={"type": "object"})]) + server = Server("sampler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "draw" - result = await ctx.session.create_message( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="draw", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "draw" + result = await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Draw a cat."))], max_tokens=100, ) image = result.content assert isinstance(image, ImageContent) - return CallToolResult( - content=[TextContent(type="text", text=f"{result.model}: {image.mime_type} {image.data}")] - ) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=f"{result.model}: {image.mimeType} {image.data}")] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: return CreateMessageResult( role="assistant", content=ImageContent(type="image", data="Y2F0", mimeType="image/png"), @@ -257,31 +259,33 @@ async def sampling_callback( @requirement("sampling:error:user-rejected") async def test_create_message_callback_error(connect: Connect) -> None: - """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. + """A sampling callback that answers with an error surfaces to the requesting handler as a McpError. The error here is the spec's own example for a user rejecting a sampling request (code -1); the callback's code and message reach the handler verbatim, whatever they are. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask_model", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask_model" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "ask_model" try: - await ctx.session.create_message( + await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], max_tokens=100, ) - except MCPError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) + except McpError as exc: + return [TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")] raise NotImplementedError # the callback always answers with an error - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - - async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + async def sampling_callback( + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: return ErrorData(code=-1, message="User rejected sampling request") async with connect(server, sampling_callback=sampling_callback) as client: @@ -296,24 +300,24 @@ async def sampling_callback(context: ClientRequestContext, params: CreateMessage async def test_create_message_without_callback_is_error(connect: Connect) -> None: """A sampling request to a client with no sampling callback fails with the SDK's default error.""" - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) + server = Server("sampler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask_model" + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask_model", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "ask_model" try: - await ctx.session.create_message( + await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello."))], max_tokens=100, ) - except MCPError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) + except McpError as exc: + return [TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")] raise NotImplementedError # create_message cannot succeed without a client callback - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - async with connect(server) as client: result = await client.call_tool("ask_model", {}) @@ -324,32 +328,33 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. - The client supports plain sampling but cannot declare the tools sub-capability (Client does not - expose it), so the server-side validator rejects the request before anything reaches the wire. + The client supports plain sampling but has not declared the tools sub-capability (the connect + helper does not pass sampling_capabilities through), so the server-side validator rejects the + request before anything reaches the wire. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask_model", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask_model" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "ask_model" try: - await ctx.session.create_message( + await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="What is the weather?"))], max_tokens=100, tools=[types.Tool(name="get_weather", inputSchema={"type": "object"})], ) - except MCPError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")]) + except McpError as exc: + return [TextContent(type="text", text=f"{exc.error.code}: {exc.error.message}")] raise NotImplementedError # the validator rejects every tool-enabled request - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: """Declares the plain sampling capability; never invoked because the request is rejected first.""" raise NotImplementedError @@ -372,15 +377,17 @@ async def test_create_message_with_mixed_tool_result_content_is_rejected(connect ValueError directly. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="summarise_tools", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "summarise_tools" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "summarise_tools" try: - await ctx.session.create_message( + await server.request_context.session.create_message( messages=[ SamplingMessage( role="user", @@ -395,14 +402,12 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara max_tokens=100, ) except ValueError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{type(exc).__name__}: {exc}")]) + return [TextContent(type="text", text=f"{type(exc).__name__}: {exc}")] raise NotImplementedError # the validator rejects the malformed messages before sending - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: """Declares the sampling capability; never invoked because the request is rejected first.""" raise NotImplementedError @@ -425,27 +430,28 @@ async def sampling_callback( async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: """A client connecting with a sampling callback advertises the sampling capability to the server. - Client cannot declare any sub-capabilities (it does not expose ClientSession's - sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. + The connect helper does not pass sampling_capabilities through to ClientSession, so the + snapshot pins an empty SamplingCapability with no sub-capabilities. """ captured: list[SamplingCapability | None] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="capabilities", inputSchema={"type": "object"})]) + server = Server("introspector") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "capabilities" - assert ctx.session.client_params is not None - captured.append(ctx.session.client_params.capabilities.sampling) - return CallToolResult(content=[TextContent(type="text", text="ok")]) + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="capabilities", inputSchema={"type": "object"})] - server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "capabilities" + session = server.request_context.session + assert session.client_params is not None + captured.append(session.client_params.capabilities.sampling) + return [TextContent(type="text", text="ok")] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: """Registered only so the sampling capability is advertised; never called.""" raise NotImplementedError @@ -464,33 +470,33 @@ async def test_create_message_request_with_audio_content_reaches_callback(connec """ received: list[CreateMessageRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="transcribe", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="transcribe", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "transcribe" - result = await ctx.session.create_message( + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "transcribe" + result = await server.request_context.session.create_message( messages=[ SamplingMessage(role="user", content=AudioContent(type="audio", data="c25k", mimeType="audio/wav")) ], max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=result.content.text)] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: received.append(params) audio = params.messages[0].content assert isinstance(audio, AudioContent) return CreateMessageResult( role="assistant", - content=TextContent(type="text", text=f"transcribed {audio.mime_type} ({audio.data})"), + content=TextContent(type="text", text=f"transcribed {audio.mimeType} ({audio.data})"), model="mock-audio-1", ) @@ -501,7 +507,6 @@ async def sampling_callback( assert received == snapshot( [ CreateMessageRequestParams( - _meta={}, messages=[ SamplingMessage(role="user", content=AudioContent(type="audio", data="c25k", mimeType="audio/wav")) ], @@ -518,28 +523,26 @@ async def test_create_message_result_with_audio_content_returns_to_handler(conne This is the client-to-server direction: the model's response is audio rather than text. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="speak", inputSchema={"type": "object"})]) + server = Server("sampler") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="speak", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "speak" - result = await ctx.session.create_message( + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "speak" + result = await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello, aloud."))], max_tokens=100, ) audio = result.content assert isinstance(audio, AudioContent) - return CallToolResult( - content=[TextContent(type="text", text=f"{result.model}: {audio.mime_type} {audio.data}")] - ) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=f"{result.model}: {audio.mimeType} {audio.data}")] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: return CreateMessageResult( role="assistant", content=AudioContent(type="audio", data="aGVsbG8=", mimeType="audio/wav"), @@ -559,14 +562,16 @@ async def test_create_message_with_list_valued_message_content_reaches_callback( """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" received: list[CreateMessageRequestParams] = [] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="caption", inputSchema={"type": "object"})]) + server = Server("sampler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "caption" - result = await ctx.session.create_message( + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="caption", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "caption" + result = await server.request_context.session.create_message( messages=[ SamplingMessage( role="user", @@ -579,13 +584,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara max_tokens=100, ) assert isinstance(result.content, TextContent) - return CallToolResult(content=[TextContent(type="text", text=result.content.text)]) - - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + return [TextContent(type="text", text=result.content.text)] async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: received.append(params) content = params.messages[0].content assert isinstance(content, list) @@ -600,7 +603,6 @@ async def sampling_callback( assert received == snapshot( [ CreateMessageRequestParams( - _meta={}, messages=[ SamplingMessage( role="user", @@ -625,15 +627,17 @@ async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejecte client-side -32602 check is tracked separately at sampling:tool-use:result-balance. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="continue_tools", inputSchema={"type": "object"})]) + server = Server("sampler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "continue_tools" + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="continue_tools", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "continue_tools" try: - await ctx.session.create_message( + await server.request_context.session.create_message( messages=[ SamplingMessage( role="assistant", @@ -653,14 +657,12 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara max_tokens=100, ) except ValueError as exc: - return CallToolResult(content=[TextContent(type="text", text=f"{type(exc).__name__}: {exc}")]) + return [TextContent(type="text", text=f"{type(exc).__name__}: {exc}")] raise NotImplementedError # the validator rejects the malformed messages before sending - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResult: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: """Declares the sampling capability; never invoked because the request is rejected first.""" raise NotImplementedError @@ -688,27 +690,27 @@ async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_valida the result; instead the client accepts it and the server's response parsing raises. """ - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="ask_model", inputSchema={"type": "object"})]) + server = Server("sampler") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "ask_model" + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="ask_model", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "ask_model" try: - await ctx.session.create_message( + await server.request_context.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Two thoughts, please."))], max_tokens=100, ) except pydantic.ValidationError as exc: - return CallToolResult(content=[TextContent(type="text", text=type(exc).__name__)]) + return [TextContent(type="text", text=type(exc).__name__)] raise NotImplementedError # the array-content result fails server-side parsing every time - server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - async def sampling_callback( - context: ClientRequestContext, params: CreateMessageRequestParams - ) -> CreateMessageResultWithTools: + context: RequestContext[ClientSession, Any], params: CreateMessageRequestParams + ) -> CreateMessageResult | CreateMessageResultWithTools | ErrorData: return CreateMessageResultWithTools( role="assistant", content=[TextContent(type="text", text="First thought."), TextContent(type="text", text="Second thought.")], diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index c08fc02dd2..2a3b885a6d 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -1,4 +1,4 @@ -"""Request timeouts against the low-level Server, driven through the public Client API. +"""Request timeouts against the low-level Server, driven through the public client API. The handler blocks on an event that is never set, so the awaited response can never arrive and any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore @@ -6,14 +6,17 @@ cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) """ +from datetime import timedelta +from typing import Any + import anyio import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.client.client import Client -from mcp.server import Server, ServerRequestContext -from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from mcp import McpError, types +from mcp.server.lowlevel import Server +from mcp.types import CallToolResult, ErrorData, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -21,7 +24,7 @@ @requirement("protocol:timeout:basic") @requirement("protocol:timeout:sends-cancellation") -async def test_request_timeout_fails_the_pending_call() -> None: +async def test_request_timeout_fails_the_pending_call(connect: Connect) -> None: """A request whose response does not arrive within its read timeout fails with a timeout error. No cancellation is sent to the server (see the divergence note on the requirement): the handler @@ -29,56 +32,50 @@ async def test_request_timeout_fails_the_pending_call() -> None: handler to have started only after the timeout has fired, so the timeout itself races nothing. """ handler_started = anyio.Event() + server: Server[Any] = Server("blocker") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "block" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "block" handler_started.set() await anyio.Event().wait() # blocks until the session is torn down raise NotImplementedError # unreachable - server = Server("blocker", on_call_tool=call_tool) - - async with Client(server) as client: - with pytest.raises(MCPError) as exc_info: - await client.call_tool("block", {}, read_timeout_seconds=0.000001) + async with connect(server) as client: + with pytest.raises(McpError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=timedelta(seconds=0.000001)) # The request was already on the wire: the handler still runs even though the caller gave up. with anyio.fail_after(5): await handler_started.wait() assert exc_info.value.error == snapshot( - ErrorData( - code=REQUEST_TIMEOUT, - message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", - ) + ErrorData(code=408, message="Timed out while waiting for response to ClientRequest. Waited 1e-06 seconds.") ) @requirement("protocol:timeout:session-survives") -async def test_session_serves_requests_after_timeout() -> None: +async def test_session_serves_requests_after_timeout(connect: Connect) -> None: """A timed-out request does not poison the session: the next request succeeds.""" - - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool(name="block", inputSchema={"type": "object"}), - types.Tool(name="echo", inputSchema={"type": "object"}), - ] - ) - - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - if params.name == "echo": - return CallToolResult(content=[TextContent(type="text", text="still alive")]) + server: Server[Any] = Server("blocker") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool(name="block", inputSchema={"type": "object"}), + types.Tool(name="echo", inputSchema={"type": "object"}), + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + if name == "echo": + return [TextContent(type="text", text="still alive")] await anyio.Event().wait() # blocks until the session is torn down raise NotImplementedError # unreachable - server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) - - async with Client(server) as client: - with pytest.raises(MCPError): - await client.call_tool("block", {}, read_timeout_seconds=0.000001) + async with connect(server) as client: + with pytest.raises(McpError): + await client.call_tool("block", {}, read_timeout_seconds=timedelta(seconds=0.000001)) result = await client.call_tool("echo", {}) @@ -86,29 +83,26 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:timeout:session-default") -async def test_session_level_timeout_applies_to_every_request() -> None: +async def test_session_level_timeout_applies_to_every_request(connect: Connect) -> None: """A read timeout configured on the client applies to requests that do not set their own.""" + server: Server[Any] = Server("blocker") - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "block" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "block" await anyio.Event().wait() # blocks until the session is torn down raise NotImplementedError # unreachable - server = Server("blocker", on_call_tool=call_tool) - # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the # per-request timeouts: a session-level timeout also governs the initialize handshake, so the # value must be long enough for the in-process handshake to complete before the blocked tool # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual # latency; lowering it only erodes the margin against CI scheduler jitter without saving # anything perceptible. - async with Client(server, read_timeout_seconds=0.05) as client: - with pytest.raises(MCPError) as exc_info: + async with connect(server, read_timeout_seconds=timedelta(seconds=0.05)) as client: + with pytest.raises(McpError) as exc_info: await client.call_tool("block", {}) assert exc_info.value.error == snapshot( - ErrorData( - code=REQUEST_TIMEOUT, - message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", - ) + ErrorData(code=408, message="Timed out while waiting for response to ClientRequest. Waited 0.05 seconds.") ) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index f216884ecf..ff1bd9df6e 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -10,15 +10,18 @@ malformed JSON-RPC requests that the typed client API cannot produce. """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + import anyio import pytest from inline_snapshot import snapshot -from mcp import MCPError, types -from mcp.client import ClientRequestContext, ClientSession -from mcp.client._memory import InMemoryTransport -from mcp.client.client import Client -from mcp.server import Server, ServerRequestContext +from mcp import McpError, types +from mcp.client.session import ClientSession, ListRootsFnT +from mcp.server.lowlevel import Server +from mcp.shared.context import RequestContext from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( @@ -27,34 +30,60 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - EmptyResult, + ClientRequest, ErrorData, JSONRPCError, + JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, ListRootsResult, TextContent, ) -from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._helpers import Recording, _RecordingReadStream from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio -def _echo_server() -> Server: +def _echo_server() -> Server[Any]: """A server with one echo tool, used by every test in this module.""" + server: Server[Any] = Server("wire") + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool(name="echo", inputSchema={"type": "object"})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "echo" + return [TextContent(type="text", text="ok")] - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult(tools=[types.Tool(name="echo", inputSchema={"type": "object"})]) + return server - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "echo" - return CallToolResult(content=[TextContent(type="text", text="ok")]) - return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) +@asynccontextmanager +async def _record( + server: Server[Any], *, list_roots_callback: ListRootsFnT | None = None +) -> AsyncIterator[tuple[ClientSession, Recording]]: + """Connect a `ClientSession` to `server` over in-memory streams wrapped in a `Recording`. + + The yielded session is initialized; the recording captures the full handshake plus everything + the test does after. v1 has no `Transport` abstraction, so the recording is inserted between + the raw memory-stream pair and the `ClientSession`. + """ + async with create_client_server_memory_streams() as ((client_read, client_write), (server_read, server_write)): + recording = Recording(client_read, client_write) + async with anyio.create_task_group() as tg: + tg.start_soon(lambda: server.run(server_read, server_write, server.create_initialization_options())) + try: + async with ClientSession( + recording.read, recording.write, list_roots_callback=list_roots_callback + ) as client: + await client.initialize() + yield client, recording + finally: + tg.cancel_scope.cancel() @requirement("protocol:request-id:unique") @@ -63,20 +92,17 @@ async def test_request_ids_are_unique_and_never_null() -> None: The id sequence is pinned: sequential integers from zero, in send order. """ - recording = RecordingTransport(InMemoryTransport(_echo_server())) - - async with Client(recording) as client: + async with _record(_echo_server()) as (client, recording): await client.list_tools() await client.call_tool("echo", {}) await client.call_tool("echo", {}) await client.send_ping() - sent = [message.message for message in recording.sent] + sent = [message.message.root for message in recording.sent] request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] assert all(request_id is not None for request_id in request_ids) assert len(request_ids) == len(set(request_ids)) - # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a - # schema-cache refresh here because the explicit tools/list already populated the cache. + # initialize, tools/list, tools/call, tools/call, ping assert request_ids == snapshot([0, 1, 2, 3, 4]) @@ -89,20 +115,18 @@ async def test_notifications_are_never_answered() -> None: the id of the request it answers, and nothing else. """ - async def list_roots(context: ClientRequestContext) -> ListRootsResult: + async def list_roots(context: RequestContext[ClientSession, Any]) -> ListRootsResult | ErrorData: """Registered so the client declares the roots capability; the server never asks for roots.""" raise NotImplementedError - recording = RecordingTransport(InMemoryTransport(_echo_server())) - - async with Client(recording, list_roots_callback=list_roots) as client: + async with _record(_echo_server(), list_roots_callback=list_roots) as (client, recording): await client.send_roots_list_changed() await client.send_ping() - sent = [message.message for message in recording.sent] + sent = [message.message.root for message in recording.sent] sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] - received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received = [message.message.root for message in recording.received if isinstance(message, SessionMessage)] received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed @@ -132,15 +156,13 @@ async def test_exactly_one_initialized_notification_is_sent_after_the_handshake( The full method sequence the client puts on the wire is pinned in send order. """ - recording = RecordingTransport(InMemoryTransport(_echo_server())) - - async with Client(recording) as client: + async with _record(_echo_server()) as (client, recording): await client.list_tools() sent_methods = [ - message.message.method + message.message.root.method for message in recording.sent - if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + if isinstance(message.message.root, JSONRPCRequest | JSONRPCNotification) ] assert sent_methods.count("notifications/initialized") == 1 assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) @@ -157,14 +179,15 @@ async def test_closing_the_transport_fails_in_flight_requests_with_connection_cl """ handler_started = anyio.Event() - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "block" + server: Server[Any] = Server("blocker") + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + assert name == "block" handler_started.set() await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event raise NotImplementedError # unreachable: the wait above never completes normally - server = Server("blocker", on_call_tool=call_tool) - async with create_client_server_memory_streams() as (client_streams, server_streams): client_read, client_write = client_streams server_read, server_write = server_streams @@ -178,9 +201,10 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await session.initialize() async def call_and_capture_error() -> None: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await session.send_request( - CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ClientRequest(CallToolRequest(params=CallToolRequestParams(name="block"))), + CallToolResult, ) errors.append(exc_info.value.error) @@ -216,32 +240,38 @@ async def test_malformed_request_params_are_answered_with_invalid_params() -> No with anyio.fail_after(5): await client_write.send( SessionMessage( - JSONRPCRequest( - jsonrpc="2.0", - id=0, - method="initialize", - params={ - "protocolVersion": "2025-11-25", - "capabilities": {}, - "clientInfo": {"name": "raw", "version": "0.0.1"}, - }, + JSONRPCMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) ) ) ) init_response = await client_read.receive() assert isinstance(init_response, SessionMessage) - assert isinstance(init_response.message, JSONRPCResponse) + assert isinstance(init_response.message.root, JSONRPCResponse) await client_write.send( - SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + SessionMessage( + JSONRPCMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) ) await client_write.send( - SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + SessionMessage( + JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) ) error_response = await client_read.receive() assert isinstance(error_response, SessionMessage) - assert isinstance(error_response.message, JSONRPCError) - errors.append(error_response.message.error) + assert isinstance(error_response.message.root, JSONRPCError) + errors.append(error_response.message.root.error) server_task_group.cancel_scope.cancel() @@ -257,11 +287,13 @@ async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_par against a real Server. Reserve this pattern for behaviour the typed API cannot produce. """ - async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + server: Server[Any] = Server("logger") + + @server.set_logging_level() + async def set_logging_level(level: types.LoggingLevel) -> None: """Registered so the logging capability is advertised; never called -- params validation fails first.""" raise NotImplementedError - server = Server("logger", on_set_logging_level=set_logging_level) errors: list[ErrorData] = [] async with create_client_server_memory_streams() as (client_streams, server_streams): @@ -274,34 +306,40 @@ async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelReq with anyio.fail_after(5): await client_write.send( SessionMessage( - JSONRPCRequest( - jsonrpc="2.0", - id=0, - method="initialize", - params={ - "protocolVersion": "2025-11-25", - "capabilities": {}, - "clientInfo": {"name": "raw", "version": "0.0.1"}, - }, + JSONRPCMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) ) ) ) init_response = await client_read.receive() assert isinstance(init_response, SessionMessage) - assert isinstance(init_response.message, JSONRPCResponse) + assert isinstance(init_response.message.root, JSONRPCResponse) await client_write.send( - SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + SessionMessage( + JSONRPCMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) ) await client_write.send( SessionMessage( - JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + JSONRPCMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) ) ) error_response = await client_read.receive() assert isinstance(error_response, SessionMessage) - assert isinstance(error_response.message, JSONRPCError) - errors.append(error_response.message.error) + assert isinstance(error_response.message.root, JSONRPCError) + errors.append(error_response.message.root.error) server_task_group.cancel_scope.cancel() diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index 0ead93bf6c..f0d774cf07 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -1,13 +1,16 @@ -"""The Context convenience methods MCPServer injects into tool functions, observed from the client.""" +"""The Context convenience methods FastMCP injects into tool functions, observed from the client.""" + +from typing import Any import pytest from inline_snapshot import snapshot from pydantic import BaseModel -from mcp import MCPError -from mcp.client import ClientRequestContext +from mcp import McpError +from mcp.client.session import ClientSession from mcp.server.elicitation import AcceptedElicitation -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.context import RequestContext from mcp.types import ( METHOD_NOT_FOUND, CallToolResult, @@ -16,8 +19,8 @@ ElicitResult, ErrorData, Implementation, - LoggingMessageNotification, LoggingMessageNotificationParams, + ServerNotification, TextContent, ) from tests.interaction._connect import Connect @@ -37,7 +40,7 @@ async def test_context_logging_helpers_send_log_notifications(connect: Connect) advertising the logging capability (see the divergence note on logging:capability). """ received: list[LoggingMessageNotificationParams] = [] - mcp = MCPServer("chatty") + mcp = FastMCP("chatty") @mcp.tool() async def narrate(ctx: Context) -> str: @@ -52,7 +55,9 @@ async def collect(params: LoggingMessageNotificationParams) -> None: async with connect(mcp, logging_callback=collect) as client: result = await client.call_tool("narrate", {}) - advertised_logging = client.initialize_result.capabilities.logging + capabilities = client.get_server_capabilities() + assert capabilities is not None + advertised_logging = capabilities.logging assert result == snapshot( CallToolResult(content=[TextContent(type="text", text="done")], structuredContent={"result": "done"}) @@ -76,7 +81,7 @@ async def test_context_report_progress_sends_progress_notifications(connect: Con The caller's progress callback receives each report, in order, before the tool call returns. """ received: list[tuple[float, float | None, str | None]] = [] - mcp = MCPServer("worker") + mcp = FastMCP("worker") @mcp.tool() async def crunch(ctx: Context) -> str: @@ -104,13 +109,13 @@ async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Con test asserts the value the tool saw is the one returned, rather than pinning the literal); the client info reflects what the caller passed to `Client`. """ - mcp = MCPServer("introspector") + mcp = FastMCP("introspector") @mcp.tool() async def whoami(ctx: Context) -> str: - client_params = ctx.session.client_params + client_params = ctx.request_context.session.client_params assert client_params is not None - return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + return f"request {ctx.request_id} from {client_params.clientInfo.name} {client_params.clientInfo.version}" async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: result = await client.call_tool("whoami", {}) @@ -132,7 +137,7 @@ async def test_report_progress_without_a_progress_token_sends_nothing(connect: C token-less request. """ received: list[IncomingMessage] = [] - mcp = MCPServer("quiet") + mcp = FastMCP("quiet") @mcp.tool() async def mill(ctx: Context) -> str: @@ -149,9 +154,9 @@ async def collect(message: IncomingMessage) -> None: assert result == snapshot( CallToolResult(content=[TextContent(type="text", text="milled")], structuredContent={"result": "milled"}) ) - assert received == snapshot( - [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] - ) + notification_params = [msg.root.params for msg in received if isinstance(msg, ServerNotification)] + assert len(notification_params) == len(received) + assert notification_params == snapshot([LoggingMessageNotificationParams(level="info", data="milling done")]) @requirement("mcpserver:context:elicit") @@ -163,19 +168,20 @@ async def test_context_elicit_returns_typed_result(connect: Connect) -> None: back into the model and handed to the tool as result.data. """ received: list[ElicitRequestParams] = [] - mcp = MCPServer("travel") + mcp = FastMCP("travel") class TravelPreferences(BaseModel): destination: str window_seat: bool @mcp.tool() - async def book_flight(ctx: Context) -> str: + async def book_flight() -> str: + ctx = mcp.get_context() answer = await ctx.elicit("Where to?", TravelPreferences) assert isinstance(answer, AcceptedElicitation) return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" - async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + async def answer_form(context: RequestContext[ClientSession, Any], params: ElicitRequestParams) -> ElicitResult: received.append(params) return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) @@ -185,7 +191,6 @@ async def answer_form(context: ClientRequestContext, params: ElicitRequestParams assert received == snapshot( [ ElicitRequestFormParams( - _meta={}, message="Where to?", requestedSchema={ "properties": { @@ -214,7 +219,7 @@ async def test_context_read_resource_reads_registered_resource(connect: Connect) The tool reports the MIME type and content it read, proving the resource function ran and its return value came back through the context. """ - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("config://app") def app_config() -> str: @@ -239,15 +244,15 @@ async def show_config(ctx: Context) -> str: @requirement("logging:message:filtered") async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: - """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + """FastMCP does not support logging/setLevel, so log messages are never filtered by severity. - The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + The request is rejected with METHOD_NOT_FOUND because FastMCP registers no handler for it, and every message a tool emits is delivered regardless of level. The spec says the server should only send messages at or above the configured level; with no way to configure one, everything is sent. """ received: list[LoggingMessageNotificationParams] = [] - mcp = MCPServer("unfilterable") + mcp = FastMCP("unfilterable") @mcp.tool() async def chatter(ctx: Context) -> str: @@ -259,7 +264,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) async with connect(mcp, logging_callback=collect) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.set_logging_level("error") await client.call_tool("chatter", {}) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py index 5066095541..35dabee127 100644 --- a/tests/interaction/mcpserver/test_resources.py +++ b/tests/interaction/mcpserver/test_resources.py @@ -1,10 +1,11 @@ -"""Resource interactions against MCPServer, driven through the public Client API.""" +"""Resource interactions against FastMCP, driven through the public Client API.""" import pytest from inline_snapshot import snapshot +from pydantic import AnyUrl -from mcp import MCPError -from mcp.server.mcpserver import MCPServer +from mcp import McpError +from mcp.server.fastmcp import FastMCP from mcp.types import ( ErrorData, ListResourcesResult, @@ -23,7 +24,7 @@ @requirement("mcpserver:resource:static") async def test_read_static_resource(connect: Connect) -> None: """A function registered for a fixed URI is served at that URI with its return value as text.""" - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("config://app") def app_config() -> str: @@ -31,7 +32,7 @@ def app_config() -> str: return "theme = dark" async with connect(mcp) as client: - result = await client.read_resource("config://app") + result = await client.read_resource(AnyUrl("config://app")) assert result == snapshot( ReadResourceResult( @@ -47,7 +48,7 @@ async def test_list_static_and_templated_resources(connect: Connect) -> None: The name and description are derived from the function name and docstring; the MIME type defaults to text/plain. """ - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("config://app") def app_config() -> str: @@ -93,7 +94,7 @@ def user_profile(user_id: str) -> str: @requirement("resources:read:template-vars") async def test_read_templated_resource(connect: Connect) -> None: """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("users://{user_id}/profile") def user_profile(user_id: str) -> str: @@ -101,7 +102,7 @@ def user_profile(user_id: str) -> str: return f"profile for {user_id}" async with connect(mcp) as client: - result = await client.read_resource("users://42/profile") + result = await client.read_resource(AnyUrl("users://42/profile")) assert result == snapshot( ReadResourceResult( @@ -116,7 +117,7 @@ async def test_read_unknown_uri_is_error(connect: Connect) -> None: The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. """ - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("config://app") def app_config() -> str: @@ -124,8 +125,8 @@ def app_config() -> str: raise NotImplementedError async with connect(mcp) as client: - with pytest.raises(MCPError) as exc_info: - await client.read_resource("config://missing") + with pytest.raises(McpError) as exc_info: + await client.read_resource(AnyUrl("config://missing")) assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) @@ -134,33 +135,33 @@ def app_config() -> str: async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: """An exception raised by a resource function reaches the caller as a JSON-RPC error. - MCPServer wraps the failure in a generic error that names only the URI, so the original - exception text is not leaked to the client. The wrapped exception becomes error code 0 the - same way every other unhandled server-side exception does. + FastMCP wraps the failure in a ResourceError whose message names the URI and appends the + original exception text, so on v1 the underlying message does reach the client. The wrapped + exception becomes error code 0 the same way every other unhandled server-side exception does. """ - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("res://boom") def boom() -> str: raise RuntimeError("nope") async with connect(mcp) as client: - with pytest.raises(MCPError) as exc_info: - await client.read_resource("res://boom") + with pytest.raises(McpError) as exc_info: + await client.read_resource(AnyUrl("res://boom")) - assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom: nope")) @requirement("mcpserver:resource:duplicate-name") async def test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: """Registering a second static resource at an already-used URI keeps the first registration. - The intended behaviour is rejection at registration time; MCPServer instead logs a warning + The intended behaviour is rejection at registration time; FastMCP instead logs a warning and discards the second registration (see the divergence note on the requirement). The two registrations use different function names so the test does not redefine a name in this scope; the resource decorator keys on the URI, not the function name. """ - mcp = MCPServer("library") + mcp = FastMCP("library") @mcp.resource("config://app") def config_first() -> str: @@ -174,9 +175,9 @@ def config_second() -> str: async with connect(mcp) as client: listed = await client.list_resources() - result = await client.read_resource("config://app") + result = await client.read_resource(AnyUrl("config://app")) - assert [resource.uri for resource in listed.resources] == ["config://app"] + assert [resource.uri for resource in listed.resources] == [AnyUrl("config://app")] assert listed.resources[0].name == "config_first" assert result == snapshot( ReadResourceResult(contents=[TextResourceContents(uri="config://app", mimeType="text/plain", text="first")]) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index a4cf9ca348..060ef2d276 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -1,4 +1,4 @@ -"""Tool interactions against MCPServer, driven through the public Client API.""" +"""Tool interactions against FastMCP, driven through the public Client API.""" import logging from typing import Annotated, Literal @@ -7,17 +7,17 @@ from inline_snapshot import snapshot from pydantic import BaseModel, Field -from mcp import MCPError -from mcp.server.mcpserver import Context, MCPServer -from mcp.server.mcpserver.exceptions import ToolError +from mcp import McpError +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.exceptions import ToolError from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ( URL_ELICITATION_REQUIRED, CallToolResult, ElicitRequestURLParams, ErrorData, - LoggingMessageNotification, LoggingMessageNotificationParams, + ServerNotification, TextContent, ) from tests.interaction._connect import Connect @@ -31,10 +31,10 @@ async def test_call_tool_returns_text_content(connect: Connect) -> None: """Arguments reach the tool function; its return value comes back as text content. - MCPServer also derives an output schema from the return annotation and attaches the + FastMCP also derives an output schema from the return annotation and attaches the matching structuredContent to the result. """ - mcp = MCPServer("adder") + mcp = FastMCP("adder") @mcp.tool() def add(a: int, b: int) -> str: @@ -55,7 +55,7 @@ async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function body sees them, proving the generated input schema and validation pipeline cover non-trivial types. """ - mcp = MCPServer("typed") + mcp = FastMCP("typed") class Point(BaseModel): x: int @@ -86,7 +86,7 @@ async def test_call_tool_function_exception_becomes_error_result(connect: Connec result is built before any schema validation runs, so no validation failure is layered on top of the original exception. """ - mcp = MCPServer("errors") + mcp = FastMCP("errors") @mcp.tool() def explode() -> str: @@ -103,7 +103,7 @@ def explode() -> str: @requirement("mcpserver:tool:handler-throws") async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" - mcp = MCPServer("errors") + mcp = FastMCP("errors") @mcp.tool() def flux() -> str: @@ -126,7 +126,7 @@ async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> The spec classifies unknown tools as a protocol error; see the divergence note on the requirement. """ - mcp = MCPServer("errors") + mcp = FastMCP("errors") @mcp.tool() def add() -> None: @@ -146,7 +146,7 @@ async def test_call_tool_model_return_becomes_structured_content(connect: Connec """A tool returning a pydantic model advertises the model's schema as the tool's output schema and returns the model's fields as structured content alongside a serialised text block. """ - mcp = MCPServer("weather") + mcp = FastMCP("weather") class Weather(BaseModel): temperature: float @@ -160,7 +160,7 @@ def get_weather() -> Weather: listed = await client.list_tools() result = await client.call_tool("get_weather", {}) - assert listed.tools[0].output_schema == snapshot( + assert listed.tools[0].outputSchema == snapshot( { "properties": { "temperature": {"title": "Temperature", "type": "number"}, @@ -194,7 +194,7 @@ async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) """A tool returning a list wraps the value under a "result" key in both the generated output schema and the structured content. """ - mcp = MCPServer("primes") + mcp = FastMCP("primes") @mcp.tool() def primes() -> list[int]: @@ -204,7 +204,7 @@ def primes() -> list[int]: listed = await client.list_tools() result = await client.call_tool("primes", {}) - assert listed.tools[0].output_schema == snapshot( + assert listed.tools[0].outputSchema == snapshot( { "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, "required": ["result"], @@ -229,7 +229,7 @@ async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) """Arguments that fail validation against the tool's signature are reported as an is_error result describing the failure, not as a protocol error. """ - mcp = MCPServer("adder") + mcp = FastMCP("adder") @mcp.tool() def add(a: int, b: int) -> str: @@ -242,7 +242,7 @@ def add(a: int, b: int) -> str: # The description is raw pydantic output -- it embeds a pydantic-version-specific # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable # prefix is asserted; a full snapshot would break on every pydantic upgrade. - assert result.is_error is True + assert result.isError is True assert isinstance(result.content[0], TextContent) assert result.content[0].text.startswith("Error executing tool add: 1 validation error") @@ -255,12 +255,12 @@ async def test_tool_with_output_schema_returning_mismatched_structured_content_i """Structured content that fails the tool's own output schema is rejected on the server side. A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while - declaring `Model` as its output schema; MCPServer validates the supplied structured_content + declaring `Model` as its output schema; FastMCP validates the supplied structured_content against that schema before returning. The two cases -- a content shape that does not match, and no structured content at all -- both fail that validation and are reported as is_error results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. """ - mcp = MCPServer("forecaster") + mcp = FastMCP("forecaster") class Weather(BaseModel): temperature: float @@ -281,11 +281,11 @@ def missing() -> Annotated[CallToolResult, Weather]: # The body of each message is raw pydantic ValidationError output (model name, field paths, # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable # prefix is asserted. - assert mismatched_result.is_error is True + assert mismatched_result.isError is True assert isinstance(mismatched_result.content[0], TextContent) assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") - assert missing_result.is_error is True + assert missing_result.isError is True assert isinstance(missing_result.content[0], TextContent) assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") @@ -294,12 +294,12 @@ def missing() -> Annotated[CallToolResult, Weather]: async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: """Registering a second tool with an already-used name keeps the first registration. - The intended behaviour is rejection at registration time; MCPServer instead logs a warning + The intended behaviour is rejection at registration time; FastMCP instead logs a warning and discards the second registration (see the divergence note on the requirement). The second function is registered via add_tool with an explicit name so the test does not redefine the same function name in this scope. """ - mcp = MCPServer("duplicates") + mcp = FastMCP("duplicates") @mcp.tool() def echo() -> str: @@ -327,12 +327,12 @@ async def test_registering_a_tool_with_a_spec_invalid_name_warns_but_does_not_re ) -> None: """A tool name that violates the SEP-986 rules logs a warning at registration but is still registered. - The intended behaviour is rejection at registration time; MCPServer instead logs the + The intended behaviour is rejection at registration time; FastMCP instead logs the naming-rule violation and proceeds (see the divergence note on the requirement). The warning spans several SDK-authored log records, so only the stable prefix and inclusion of the offending name are asserted. """ - mcp = MCPServer("naming") + mcp = FastMCP("naming") with caplog.at_level(logging.WARNING, logger="mcp.shared.tool_name_validation"): @@ -361,11 +361,11 @@ def bad() -> str: async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. - MCPServer wraps every other tool exception as an is_error result; this error is special-cased + FastMCP wraps every other tool exception as an is_error result; this error is special-cased so it propagates as the JSON-RPC error the client needs in order to present the listed URL interactions and retry the call. """ - mcp = MCPServer("authorizer") + mcp = FastMCP("authorizer") @mcp.tool() def read_files() -> str: @@ -380,7 +380,7 @@ def read_files() -> str: ) async with connect(mcp) as client: - with pytest.raises(MCPError) as exc_info: + with pytest.raises(McpError) as exc_info: await client.call_tool("read_files", {}) assert exc_info.value.error.code == URL_ELICITATION_REQUIRED @@ -408,12 +408,12 @@ async def test_adding_and_removing_tools_does_not_notify_connected_clients(conne add_tool and remove_tool only update the registry: a connected client that listed the tools before the mutation has no way to learn it should list them again. The spec provides - notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + notifications/tools/list_changed for exactly this; FastMCP never sends it. The tool emits one log message as a sentinel so the test proves notifications do reach the collector -- the log message arrives, a list_changed does not. """ received: list[IncomingMessage] = [] - mcp = MCPServer("mutable") + mcp = FastMCP("mutable") def extra() -> str: """A tool registered at runtime; never called.""" @@ -441,6 +441,12 @@ async def collect(message: IncomingMessage) -> None: assert [tool.name for tool in before.tools] == ["doomed", "grow"] assert [tool.name for tool in after.tools] == ["grow", "extra"] - assert received == snapshot( - [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + # In v1 the message_handler receives ServerNotification RootModel envelopes whose inner + # Notification carries jsonrpc as an extra field, so the full object can't be compared to a + # constructed literal; unwrap and assert on (method, params) to prove only the log sentinel + # arrived and no tools/list_changed notification was sent. + notifications = [m.root for m in received if isinstance(m, ServerNotification)] + assert len(notifications) == len(received) + assert [(n.method, n.params) for n in notifications] == snapshot( + [("notifications/message", LoggingMessageNotificationParams(level="info", data="tool set changed"))] ) diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py index 96b70c37e6..f78826c5ed 100644 --- a/tests/interaction/transports/test_client_transport_http.py +++ b/tests/interaction/transports/test_client_transport_http.py @@ -7,6 +7,7 @@ """ from collections.abc import AsyncIterator +from typing import Any import anyio import httpx @@ -14,12 +15,18 @@ from inline_snapshot import snapshot from starlette.types import Receive, Scope, Send -from mcp import MCPError, types -from mcp.client.client import Client +from mcp import McpError +from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.server import Server, ServerRequestContext -from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool -from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from mcp.server.lowlevel import Server +from mcp.types import CallToolResult, ErrorData, TextContent, Tool +from tests.interaction._connect import ( + BASE_URL, + NO_DNS_REBINDING_PROTECTION, + build_streamable_http_app, + client_via_http, + mounted_app, +) from tests.interaction._requirements import requirement from tests.interaction.transports._bridge import StreamingASGITransport from tests.interaction.transports._event_store import SequencedEventStore @@ -29,16 +36,18 @@ def _tooled_server() -> Server: """A low-level server with one echo tool, used by every test in this file.""" + server = Server("echoer") - async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="echo", description="Echo text.", inputSchema={"type": "object"})] - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "echo" - assert params.arguments is not None - return CallToolResult(content=[TextContent(type="text", text=str(params.arguments["text"]))]) + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + assert name == "echo" + return [TextContent(type="text", text=str(arguments["text"]))] - return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + return server @pytest.fixture @@ -166,7 +175,7 @@ async def test_client_tolerates_405_on_get_and_delete() -> None: Neither surfaces to the caller. """ server = _tooled_server() - real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + real_app, manager = build_streamable_http_app(server, transport_security=NO_DNS_REBINDING_PROTECTION) async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): @@ -176,12 +185,15 @@ async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: await real_app(scope, receive, send) async with ( - server.session_manager.run(), + manager.run(), httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) as http_client, ): - transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) with anyio.fail_after(5): # pragma: no branch - async with Client(transport) as client: # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write, _), + ClientSession(read, write) as client, + ): + await client.initialize() result = await client.list_tools() assert [tool.name for tool in result.tools] == ["echo"] @@ -222,7 +234,7 @@ async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> Non client-transport:http:session-404-reinitialize; this test pins the SDK's current behaviour. """ server = _tooled_server() - real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + real_app, manager = build_streamable_http_app(server, transport_security=NO_DNS_REBINDING_PROTECTION) initialize_seen = anyio.Event() async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: @@ -235,13 +247,16 @@ async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> Non await real_app(scope, receive, send) async with ( - server.session_manager.run(), + manager.run(), httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) as http_client, ): - transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) with anyio.fail_after(5): # pragma: no branch - async with Client(transport) as client: # pragma: no branch - with pytest.raises(MCPError) as exc_info: # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read, write, _), + ClientSession(read, write) as client, + ): + await client.initialize() + with pytest.raises(McpError) as exc_info: # pragma: no branch await client.list_tools() - assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) + assert exc_info.value.error == snapshot(ErrorData(code=32600, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py index e5d75a9f7c..b390aa09f2 100644 --- a/tests/interaction/transports/test_flows.py +++ b/tests/interaction/transports/test_flows.py @@ -1,8 +1,8 @@ """Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. These scenarios are about how the transport layer holds together across more than one connection -or more than one transport, so they connect real `Client`s against one mounted server rather than -running over the matrix. +or more than one transport, so they connect real `ClientSession`s against one mounted server rather +than running over the matrix. """ import anyio @@ -11,7 +11,7 @@ from inline_snapshot import snapshot from mcp.client.session import LoggingFnT -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app from tests.interaction._requirements import requirement @@ -27,7 +27,7 @@ async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_ independence under termination) with the notification-isolation dimension: a notification emitted by one session's handler does not leak to another session's client. """ - mcp = MCPServer("multi") + mcp = FastMCP("multi") @mcp.tool() async def announce(label: str, ctx: Context) -> str: @@ -67,7 +67,7 @@ async def test_a_fresh_connection_after_termination_obtains_a_new_session_and_op (3) connect a second client to the same mounted app, (4) the second client's call_tool succeeds and the recorded session ids show two distinct sessions were issued. """ - mcp = MCPServer("reconnectable") + mcp = FastMCP("reconnectable") @mcp.tool() def echo(text: str) -> str: @@ -97,15 +97,15 @@ async def record(request: httpx.Request) -> None: @requirement("flow:compat:dual-transport-server") async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: - """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + """One FastMCP instance serves a streamable-HTTP client and a legacy-SSE client at the same time. The two transports have independent connection management (the streamable-HTTP session manager versus a per-connection SSE handler), but both dispatch into the same server's request handlers. The test connects one client over each transport against the same instance and - proves both reach the same tool. Uses MCPServer because the low-level Server has no SSE + proves both reach the same tool. Uses FastMCP because the low-level Server has no SSE convenience; the entry is about hosting composition, not the low-level API. """ - mcp = MCPServer("dual") + mcp = FastMCP("dual") @mcp.tool() def echo(text: str) -> str: diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index be0d7cbda2..3bd84f68e9 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -6,30 +6,30 @@ `connect`-fixture matrix. """ +from typing import Any + import anyio import pytest from anyio.lowlevel import checkpoint from httpx_sse import ServerSentEvent, aconnect_sse from inline_snapshot import snapshot +from pydantic import AnyUrl -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ( INVALID_PARAMS, PARSE_ERROR, CallToolRequestParams, - CallToolResult, - EmptyResult, + ContentBlock, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - ListResourcesResult, - ListToolsResult, - PaginatedRequestParams, - SetLevelRequestParams, - SubscribeRequestParams, + LoggingLevel, + Resource, TextContent, + Tool, ) from tests.interaction._connect import ( base_headers, @@ -45,37 +45,37 @@ def _server() -> Server: """A low-level server with one tool that emits a related and an unrelated notification.""" + server = Server("hosted") - async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - """Registered only so the tools capability is advertised; never called.""" - raise NotImplementedError + @server.list_tools() + async def list_tools() -> list[Tool]: + """Registered so the tools capability is advertised; v1 also calls it from call_tool for schema caching.""" + return [] - async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - assert params.name == "narrate" + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[ContentBlock]: + assert name == "narrate" + ctx = server.request_context await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) - await ctx.session.send_resource_updated("file:///watched.txt") - return CallToolResult(content=[TextContent(type="text", text="done")]) + await ctx.session.send_resource_updated(AnyUrl("file:///watched.txt")) + return [TextContent(type="text", text="done")] - async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + @server.set_logging_level() + async def set_logging_level(level: LoggingLevel) -> None: """Registered so the logging capability is advertised; the client never sets a level.""" raise NotImplementedError - async def list_resources(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListResourcesResult: + @server.list_resources() + async def list_resources() -> list[Resource]: """Registered so the resources capability is advertised; the client never lists resources.""" raise NotImplementedError - async def subscribe_resource(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + @server.subscribe_resource() + async def subscribe_resource(uri: AnyUrl) -> None: """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" raise NotImplementedError - return Server( - "hosted", - on_list_tools=list_tools, - on_call_tool=call_tool, - on_set_logging_level=set_logging_level, - on_list_resources=list_resources, - on_subscribe_resource=subscribe_resource, - ) + return server @requirement("hosting:http:method-405") @@ -289,8 +289,8 @@ async def read_standalone_stream() -> None: await seen_on_standalone.wait() tg.cancel_scope.cancel() - post_messages = parse_sse_messages(post_events) - get_messages = parse_sse_messages(get_events) + post_messages = [m.root for m in parse_sse_messages(post_events)] + get_messages = [m.root for m in parse_sse_messages(get_events)] # POST stream: the related log notification, then the response, then the iterator ends (close). assert [type(m).__name__ for m in post_messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) @@ -313,12 +313,12 @@ async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> No """A disallowed Origin returns 403 (and Host 421) with protection enabled; disabled lets both through. See the divergence on hosting:http:dns-rebinding: the spec's Origin validation is an - unconditional MUST, but the SDK enables it only when the host is localhost (or settings are - passed explicitly) and additionally checks the Host header (returning 421), which the spec - does not require. + unconditional MUST, but the v1 SDK enables it only when settings are passed explicitly + (no auto-enable on localhost) and additionally checks the Host header (returning 421), + which the spec does not require. """ - # transport_security=None triggers the localhost auto-enable behaviour. - async with mounted_app(Server("guarded"), transport_security=None) as (http, _): + guarded = TransportSecuritySettings(allowed_hosts=["127.0.0.1:8000"], allowed_origins=["http://127.0.0.1:8000"]) + async with mounted_app(Server("guarded"), transport_security=guarded) as (http, _): bad_origin = await http.post( "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} ) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index a23311b39c..2b13c60a10 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -16,22 +16,24 @@ import pytest from httpx_sse import EventSource, ServerSentEvent from inline_snapshot import snapshot +from pydantic import AnyUrl from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.shared.message import ClientMessageMetadata from mcp.types import ( LATEST_PROTOCOL_VERSION, CallToolRequest, CallToolRequestParams, CallToolResult, + ClientRequest, + JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, LoggingMessageNotificationParams, TextContent, - jsonrpc_message_adapter, ) from tests.interaction._connect import ( BASE_URL, @@ -47,16 +49,16 @@ pytestmark = pytest.mark.anyio -def _counting_server() -> MCPServer: +def _counting_server() -> FastMCP: """A server with one tool that emits related notifications and one unrelated notification.""" - mcp = MCPServer("resumable") + mcp = FastMCP("resumable") @mcp.tool() async def count(ctx: Context, n: int) -> str: """Emit n log notifications related to this call, plus one unrelated resource update.""" for i in range(1, n + 1): await ctx.info(f"tick {i}") - await ctx.session.send_resource_updated("file:///elsewhere.txt") + await ctx.session.send_resource_updated(AnyUrl("file:///elsewhere.txt")) return f"counted to {n}" return mcp @@ -100,7 +102,7 @@ async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_ev assert [json.loads(event.data)["method"] for event in (first, second)] == snapshot( ["notifications/message", "notifications/message"] ) - assert jsonrpc_message_adapter.validate_json(result.data) == snapshot( + assert JSONRPCMessage.model_validate_json(result.data).root == snapshot( JSONRPCResponse( jsonrpc="2.0", id=1, @@ -133,7 +135,7 @@ async def test_get_with_last_event_id_replays_only_that_streams_missed_events() release = anyio.Event() store = SequencedEventStore() - mcp = MCPServer("resumable") + mcp = FastMCP("resumable") @mcp.tool() async def count(ctx: Context) -> str: @@ -142,7 +144,7 @@ async def count(ctx: Context) -> str: await release.wait() await ctx.info("tick 2") await ctx.info("tick 3") - await ctx.session.send_resource_updated("file:///elsewhere.txt") + await ctx.session.send_resource_updated(AnyUrl("file:///elsewhere.txt")) return "counted" async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): @@ -166,7 +168,7 @@ async def count(ctx: Context) -> str: assert replay.status_code == 200 missed = await _read_events(replay, 3) - decoded = parse_sse_messages(missed) + decoded = [envelope.root for envelope in parse_sse_messages(missed)] # Exactly the two remaining related notifications and the response, with their original IDs. assert [event.id for event in missed] == snapshot(["5", "6", "8"]) assert [type(message).__name__ for message in decoded] == snapshot( @@ -212,7 +214,7 @@ async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() release = anyio.Event() finished = anyio.Event() - mcp = MCPServer("resumable") + mcp = FastMCP("resumable") @mcp.tool() async def hold(ctx: Context) -> str: @@ -258,7 +260,7 @@ async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() gate = anyio.Event() done = anyio.Event() - mcp = MCPServer("resumable") + mcp = FastMCP("resumable") @mcp.tool() async def interrupt(ctx: Context) -> str: @@ -304,9 +306,9 @@ async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_conn This is the explicit ClientMessageMetadata API, distinct from the automatic reconnection the previous test covers: the transport dispatches a resumption_token request as a GET with Last-Event-ID instead of POSTing the body, and remaps the replayed response onto the new - request's id. Client.call_tool does not expose ClientMessageMetadata, so the test drives a - bare ClientSession via session.send_request -- the sanctioned drop-down for behaviour Client - cannot express. The second connection carries the original session id but does not initialize + request's id. ClientSession.call_tool does not expose ClientMessageMetadata, so the test drives + session.send_request directly. The second connection carries the original session id but does + not initialize (the server-side session already is), modelling a caller that resumes after a process restart. """ captured: list[str] = [] @@ -316,7 +318,7 @@ async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_conn release = anyio.Event() store = SequencedEventStore() - mcp = MCPServer("resumable") + mcp = FastMCP("resumable") @mcp.tool() async def hold(ctx: Context) -> str: @@ -335,13 +337,17 @@ async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params.data) first_seen.set() - call = CallToolRequest(params=CallToolRequestParams(name="hold", arguments={})) + call = ClientRequest(CallToolRequest(params=CallToolRequestParams(name="hold", arguments={}))) capture = ClientMessageMetadata(on_resumption_token_update=on_token) - async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): with anyio.fail_after(5): # pragma: no branch async with ( # pragma: no branch - streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as ( + r1, + w1, + get_session_id, + ), ClientSession(r1, w1, logging_callback=collect) as first, anyio.create_task_group() as tg, ): @@ -351,8 +357,8 @@ async def collect(params: LoggingMessageNotificationParams) -> None: await token_seen.wait() assert captured == snapshot(["3", "4"]) assert received == snapshot(["first"]) - # The session id is only observable via the manager (the client transport does not expose it). - (session_id,) = manager._server_instances + session_id = get_session_id() + assert session_id is not None http.headers["mcp-session-id"] = session_id http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION tg.cancel_scope.cancel() @@ -362,7 +368,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: # init priming + init response + call priming + "first" + "second" + result = 6 stored events. await store.wait_until_stored(6) async with ( # pragma: no branch - streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2, _), ClientSession(r2, w2, logging_callback=collect) as second, ): result = await second.send_request( diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py index 1285d4d81e..87daf59ece 100644 --- a/tests/interaction/transports/test_hosting_session.py +++ b/tests/interaction/transports/test_hosting_session.py @@ -13,8 +13,8 @@ import pytest from inline_snapshot import snapshot -from mcp.server import Server, ServerRequestContext -from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from mcp.server import Server +from mcp.types import JSONRPCResponse, ListToolsResult, Tool from tests.interaction._connect import ( base_headers, client_via_http, @@ -30,11 +30,13 @@ def _server() -> Server: """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + server = Server("hosted") - async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", inputSchema={"type": "object"})]) + @server.list_tools() + async def list_tools() -> list[Tool]: + return [Tool(name="noop", description="Does nothing.", inputSchema={"type": "object"})] - return Server("hosted", on_list_tools=list_tools) + return server @requirement("hosting:session:create") @@ -49,8 +51,8 @@ async def test_initialize_issues_a_visible_ascii_session_id() -> None: assert session_id is not None # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). assert re.fullmatch(r"[\x21-\x7E]+", session_id) - assert isinstance(messages[0], JSONRPCResponse) - assert messages[0].id == 1 + assert isinstance(messages[0].root, JSONRPCResponse) + assert messages[0].root.id == 1 @requirement("hosting:session:reuse") @@ -78,7 +80,7 @@ async def test_requests_with_an_unknown_session_id_return_404() -> None: delete = await http.delete("/mcp", headers=base_headers(session_id="not-a-session")) assert (post.status_code, post.json()) == snapshot( - (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}}) + (404, {"jsonrpc": "2.0", "id": "server-error", "error": {"code": -32600, "message": "Session not found"}}) ) assert (get.status_code, delete.status_code) == (404, 404) @@ -93,7 +95,14 @@ async def test_non_initialize_post_without_a_session_id_returns_400() -> None: ) assert (response.status_code, response.json()) == snapshot( - (400, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}) + ( + 400, + { + "jsonrpc": "2.0", + "id": "server-error", + "error": {"code": -32600, "message": "Bad Request: Missing session ID"}, + }, + ) ) @@ -120,7 +129,7 @@ async def test_delete_terminates_the_session_and_subsequent_requests_return_404( 404, { "jsonrpc": "2.0", - "id": None, + "id": "server-error", "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, }, ) @@ -154,8 +163,8 @@ async def test_second_initialize_on_an_existing_session_is_accepted() -> None: assert len(manager._server_instances) == 1 assert response.status_code == snapshot(200) - assert isinstance(messages[0], JSONRPCResponse) - assert messages[0].id == 2 + assert isinstance(messages[0].root, JSONRPCResponse) + assert messages[0].root.id == 2 @requirement("hosting:stateless:no-session-id") From 6c7cbcc03ee09c80cb6d0f4603937f86b2ee9def Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:04:11 +0000 Subject: [PATCH 12/19] =?UTF-8?q?backport:=20phase-5=20cleanup=20=E2=80=94?= =?UTF-8?q?=20apply=20pending=20deferrals,=20test=5Fcoverage=20green?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the one N-risk deferral (transport:streamable-http:stateless-restrictions) that phase-4 wave 2 logged in cant-express-v1.md but did not apply to the manifest. The other three v1-API-gap deferrals (roots:list-changed, resources:templates:pagination, client-auth:authorize:offline-access-consent) were already applied by their Y-risk sequential-lane batches. test_deferral_reasons_cite_existing_paths was already green (no v2-only paths cited). Full suite: 524 pass / 0 fail. 67 deferred entries (4 v1-API-gap). --- tests/interaction/_requirements.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 629a05b9e3..c9bd60a1ac 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1798,6 +1798,12 @@ def __post_init__(self) -> None: "result, because there is no session to call back through." ), transports=("streamable-http",), + deferred=( + "Not expressible via the v1 public API: v1 stateless mode has no guard on server-initiated " + "requests; the request is sent, the response arrives at a fresh stateless transport instance " + "and is dropped as an unknown request id, and the handler waits forever. The only observable " + "outcome is a hang, which is not pinnable without a time-based wait." + ), ), "transport:streamable-http:unrelated-messages": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", From c59a57d9dbed064d5dcec56612db1aa282d30149 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:11:18 +0000 Subject: [PATCH 13/19] =?UTF-8?q?backport:=20phase-5=20fixup=20=E2=80=94?= =?UTF-8?q?=20port=20roots:list-changed=20(audit-1),=20scope=20warning=20f?= =?UTF-8?q?ilter,=20ASGIApp=20note?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/interaction/_connect.py | 19 ++++++++++++- tests/interaction/_requirements.py | 4 --- tests/interaction/auth/_harness.py | 3 +- tests/interaction/conftest.py | 10 +++++-- tests/interaction/lowlevel/test_roots.py | 36 ++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 10 deletions(-) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 0bcf6fb416..ca5d6fe7bc 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -21,6 +21,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client @@ -32,7 +33,6 @@ from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP -from mcp.server.fastmcp.server import StreamableHTTPASGIApp from mcp.server.sse import SseServerTransport from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -59,6 +59,23 @@ NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) +class StreamableHTTPASGIApp: + """Thin ASGI wrapper around `StreamableHTTPSessionManager.handle_request`. + + Starlette's `Route(path, endpoint=...)` treats a *class instance* as a raw ASGI callable + (matching all HTTP verbs), whereas a coroutine function is wrapped via `request_response` + and defaults to GET/HEAD only. v1's `FastMCP.streamable_http_app()` relies on this same + distinction; we inline the wrapper here rather than deep-importing the (non-`__all__`) + `mcp.server.fastmcp.server.StreamableHTTPASGIApp`. + """ + + def __init__(self, session_manager: StreamableHTTPSessionManager) -> None: + self.session_manager = session_manager + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.session_manager.handle_request(scope, receive, send) + + def _lowlevel(server: Server[Any] | FastMCP) -> Server[Any]: """Return the lowlevel `Server` for either flavour. diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index c9bd60a1ac..b8e2d57c97 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1450,10 +1450,6 @@ def __post_init__(self) -> None: "roots:list-changed": Requirement( source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", - deferred=( - "Not expressible via the v1 public API: the low-level Server exposes no decorator for " - "notifications/roots/list_changed, so a server handler cannot be registered to observe delivery." - ), ), "roots:list-changed:client-emits": Requirement( source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index 443f2f3f9d..410d49c515 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -34,10 +34,9 @@ from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions -from mcp.server.fastmcp.server import StreamableHTTPASGIApp from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken -from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, StreamableHTTPASGIApp from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider from tests.interaction.transports._bridge import StreamingASGITransport diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index 3f389cedba..ea93de29d5 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -14,9 +14,13 @@ def pytest_configure(config: pytest.Config) -> None: # session manager's per-session task-group cancel can race the per-request cleanup). v1's own # tests run the transport in a separate process and so never observe these `__del__`-time # ResourceWarnings; running in-process via the streaming bridge does. The fixes live in `src/` - # on `main` and are out of scope for this tests-only backport, so suppress here. - config.addinivalue_line("filterwarnings", "ignore::pytest.PytestUnraisableExceptionWarning") - config.addinivalue_line("filterwarnings", "ignore::ResourceWarning") + # on `main` and are out of scope for this tests-only backport — tracked in + # `notes/backport/issues.md`. The filters below are scoped to anyio's `MemoryObject*Stream` + # leak signature so an unrelated leak still fails the suite. + config.addinivalue_line( + "filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning" + ) + config.addinivalue_line("filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning") _FACTORIES: dict[str, Connect] = { diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index 14ac80d46e..86b9303a6e 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -2,6 +2,7 @@ from typing import Any +import anyio import pytest from inline_snapshot import snapshot from pydantic import FileUrl @@ -140,3 +141,38 @@ async def list_roots(context: RequestContext[ClientSession, Any]) -> ListRootsRe result = await client.call_tool("show_roots", {}) assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="-32603: roots provider crashed")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + v1's low-level `Server` exposes no decorator for this notification; the public path is direct + assignment into `Server.notification_handlers` (a public, typed dict that the server's dispatch + loop consults for every incoming client notification). The handler receives the notification + object itself. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. + """ + delivered = anyio.Event() + received: list[types.RootsListChangedNotification] = [] + + async def on_roots_changed(notify: types.RootsListChangedNotification) -> None: + received.append(notify) + delivered.set() + + server = Server("rooted") + server.notification_handlers[types.RootsListChangedNotification] = on_roots_changed + + async def list_roots(context: RequestContext[ClientSession, Any]) -> ListRootsResult | ErrorData: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + async with connect(server, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert len(received) == 1 + assert received[0].method == "notifications/roots/list_changed" From 1f981e5230d715aaf83bf5c8278f5205e5bf0eac Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:28:15 +0000 Subject: [PATCH 14/19] backport: scrub local-path reference from conftest comment; clarify pyproject exclude note --- pyproject.toml | 3 ++- tests/interaction/conftest.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3b836470ff..a20c9987a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,8 @@ packages = ["src/mcp"] [tool.pyright] typeCheckingMode = "strict" include = ["src/mcp", "tests", "examples/servers", "examples/snippets"] -# tests/interaction is mid-backport from main; type-checking is restored phase by phase. +# tests/interaction is backported from `main` and uses v1's runtime API; strict-mode type-checking +# of the suite is tracked separately from this tests-only backport. exclude = ["tests/interaction"] venvPath = "." venv = ".venv" diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index ea93de29d5..aa8dc92a87 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -14,9 +14,8 @@ def pytest_configure(config: pytest.Config) -> None: # session manager's per-session task-group cancel can race the per-request cleanup). v1's own # tests run the transport in a separate process and so never observe these `__del__`-time # ResourceWarnings; running in-process via the streaming bridge does. The fixes live in `src/` - # on `main` and are out of scope for this tests-only backport — tracked in - # `notes/backport/issues.md`. The filters below are scoped to anyio's `MemoryObject*Stream` - # leak signature so an unrelated leak still fails the suite. + # on `main` and are out of scope for this tests-only backport. The filters below are scoped to + # anyio's `MemoryObject*Stream` leak signature so an unrelated leak still fails the suite. config.addinivalue_line( "filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning" ) From 361bb9d3ebda0b9cbad230e46cf799a80c337323 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 15:40:16 +0000 Subject: [PATCH 15/19] =?UTF-8?q?backport:=20coverage=20=E2=80=94=20recogn?= =?UTF-8?q?ize=20'lax=20no=20cover';=20pragma=20harness=20scaffolding=20+?= =?UTF-8?q?=203.11=20dead-zone=20lines?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pyproject: add 'pragma: lax no cover' to exclude_lines (matches main) - _connect.py: Protocol stub body, verifier-gate no-branch, SSE return-Response no-cover - _harness.py: resource_server_url gate no-branch - test_initialize/test_wire: tg.cancel_scope.cancel() lax-no-cover (cpython#106749, 3.11 only) - test_sse: nested-async-with no-branch (3.11 arc) All harness scaffolding; no port bugs. Local ./scripts/test → 100.00%. --- pyproject.toml | 1 + tests/interaction/_connect.py | 7 ++++--- tests/interaction/auth/_harness.py | 2 +- tests/interaction/lowlevel/test_initialize.py | 4 ++-- tests/interaction/lowlevel/test_wire.py | 2 +- tests/interaction/transports/test_sse.py | 2 +- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a20c9987a4..a8865c6622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,6 +213,7 @@ ignore_errors = true precision = 2 exclude_lines = [ "pragma: no cover", + "pragma: lax no cover", "if TYPE_CHECKING:", "@overload", "raise NotImplementedError", diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index ca5d6fe7bc..53e6ffb0eb 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -103,7 +103,7 @@ def __call__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: Implementation | None = None, - ) -> AbstractAsyncContextManager[ClientSession]: ... + ) -> AbstractAsyncContextManager[ClientSession]: ... # pragma: no cover @asynccontextmanager @@ -188,7 +188,7 @@ def build_streamable_http_app( if auth is not None: required_scopes = auth.required_scopes or [] - if verifier is not None: + if verifier is not None: # pragma: no branch — every auth-bearing caller supplies a provider/verifier middleware = [ Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(verifier)), Middleware(AuthContextMiddleware), @@ -437,7 +437,8 @@ def build_sse_app(server: Server[Any] | FastMCP) -> tuple[Starlette, SseServerTr async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): # type: ignore[reportPrivateUsage] await lowlevel.run(read, write, lowlevel.create_initialization_options()) - return Response() + # under StreamingASGITransport the request is cancelled on close, so run() never returns + return Response() # pragma: no cover app = Starlette( routes=[ diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index 410d49c515..81d8927777 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -477,7 +477,7 @@ async def connect_with_oauth( else: routes.append(Route("/mcp", endpoint=asgi)) - if settings.resource_server_url: + if settings.resource_server_url: # pragma: no branch — auth_settings() always sets this routes.extend( create_protected_resource_routes( resource_url=settings.resource_server_url, diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 0260975258..e1451c68d0 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -70,7 +70,7 @@ async def _initialize(server: Server[Any]) -> InitializeResult: async with ClientSession(client_read, client_write) as session: with anyio.fail_after(5): initialize_result = await session.initialize() - tg.cancel_scope.cancel() + tg.cancel_scope.cancel() # pragma: lax no cover — python/cpython#106749 (3.11 tracer dead-zone) return initialize_result @@ -88,7 +88,7 @@ async def _bare_session(server: Server[Any]) -> AsyncIterator[ClientSession]: tg.start_soon(lambda: server.run(server_read, server_write, server.create_initialization_options())) async with ClientSession(client_read, client_write) as session: yield session - tg.cancel_scope.cancel() + tg.cancel_scope.cancel() # pragma: lax no cover — python/cpython#106749 (3.11 tracer dead-zone) @requirement("lifecycle:initialize:basic") diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index ff1bd9df6e..cd83963814 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -83,7 +83,7 @@ async def _record( await client.initialize() yield client, recording finally: - tg.cancel_scope.cancel() + tg.cancel_scope.cancel() # pragma: lax no cover — python/cpython#106749 (3.11 tracer dead-zone) @requirement("protocol:request-id:unique") diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index 5fac87a628..dd3a9594e9 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -54,7 +54,7 @@ def httpx_client_factory( async with sse_client( f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append ) as (read, write): - async with ClientSession(read, write) as client: + async with ClientSession(read, write) as client: # pragma: no branch await client.initialize() assert len(captured_session_id) == 1 assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers From 04980e32c7998050aa99450675d266c21593289d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 16:11:14 +0000 Subject: [PATCH 16/19] =?UTF-8?q?backport:=20coverage=20=E2=80=94=20raise?= =?UTF-8?q?=20dev-dep=20floors=20for=20tests/interaction=20(pytest=208.4,?= =?UTF-8?q?=20sse-starlette=202.1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit lowest-direct only; runtime floors unchanged. - pytest>=8.4.0: tests/interaction/auth uses pytest.RaisesGroup (8.4+); matches main. - sse-starlette>=2.1.0 (dev group): 1.x keeps a module-global anyio.Event (AppStatus.should_exit_event) bound to the first test's event loop, which breaks every subsequent in-process SSE response under the streaming bridge. v1's own tests run uvicorn in a subprocess so don't observe this. Runtime floor stays >=1.6.1 (the SDK works with 1.6.1 under uvicorn). --- pyproject.toml | 7 ++++++- uv.lock | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a8865c6622..b62a0057a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ required-version = ">=0.9.5" [dependency-groups] dev = [ "pyright>=1.1.400", - "pytest>=8.3.4", + "pytest>=8.4.0", "ruff>=0.8.5", "trio>=0.26.2", "pytest-flakefinder>=1.1.0", @@ -63,6 +63,11 @@ dev = [ "inline-snapshot>=0.23.0", "dirty-equals>=0.9.0", "coverage[toml]==7.10.7", + # tests/interaction runs the streamable-HTTP and SSE apps in-process; sse-starlette 1.x + # keeps a module-global anyio.Event (`AppStatus.should_exit_event`) bound to the first + # test's event loop, which breaks every subsequent in-process SSE response. The runtime + # floor stays >=1.6.1 (works under uvicorn); only the test environment needs >=2.x. + "sse-starlette>=2.1.0", ] docs = [ "mkdocs>=1.6.1", diff --git a/uv.lock b/uv.lock index 4b01712966..a3450707c3 100644 --- a/uv.lock +++ b/uv.lock @@ -807,6 +807,7 @@ dev = [ { name = "pytest-pretty" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "sse-starlette" }, { name = "trio" }, ] docs = [ @@ -845,12 +846,13 @@ dev = [ { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "pyright", specifier = ">=1.1.400" }, - { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, + { name = "sse-starlette", specifier = ">=2.1.0" }, { name = "trio", specifier = ">=0.26.2" }, ] docs = [ From 52a79c4ec06380d44441966b39280106660e0f36 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 16:38:02 +0000 Subject: [PATCH 17/19] =?UTF-8?q?backport:=20coverage=20=E2=80=94=20reset?= =?UTF-8?q?=20sse-starlette's=20module-global=20exit=20Event=20between=20t?= =?UTF-8?q?ests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sse-starlette <3.0 stores an anyio.Event on AppStatus the first time an EventSourceResponse runs, bound to that test's event loop. Under the in-process bridge (one process, per-test event loops) every subsequent SSE response — both the [sse] leg and connect_with_oauth's streamable-HTTP responses — fails with 'bound to a different event loop'. v1's own transport tests run uvicorn in a subprocess and so never share a process across event loops. The previous commit's >=2.1.0 dev floor was insufficient (2.1.0 still has the class attribute; 3.x switched to a ContextVar). An autouse fixture that resets the attribute after each test handles all versions including the 1.6.1 runtime floor, so the dev floor is reverted and lowest-direct CI again exercises the runtime constraint. Verified: 1598 tests / 100.00% coverage on both highest and lowest-direct. --- pyproject.toml | 5 ----- tests/interaction/conftest.py | 19 +++++++++++++++++++ uv.lock | 2 -- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b62a0057a5..0f7baeffbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,11 +63,6 @@ dev = [ "inline-snapshot>=0.23.0", "dirty-equals>=0.9.0", "coverage[toml]==7.10.7", - # tests/interaction runs the streamable-HTTP and SSE apps in-process; sse-starlette 1.x - # keeps a module-global anyio.Event (`AppStatus.should_exit_event`) bound to the first - # test's event loop, which breaks every subsequent in-process SSE response. The runtime - # floor stays >=1.6.1 (works under uvicorn); only the test environment needs >=2.x. - "sse-starlette>=2.1.0", ] docs = [ "mkdocs>=1.6.1", diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index aa8dc92a87..597a87082c 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -1,6 +1,9 @@ """Shared fixtures for the interaction suite.""" +from collections.abc import Iterator + import pytest +from sse_starlette.sse import AppStatus from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http @@ -29,6 +32,22 @@ def pytest_configure(config: pytest.Config) -> None: } +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event after each test. + + sse-starlette <3.0 stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response (RuntimeError "bound to a different event loop", surfacing + as 5-second timeouts in `connect_with_oauth` and "Child exited" on the [sse] leg). v1's own + transport tests run uvicorn in a subprocess and so never share a process across event loops. + sse-starlette 3.x switched to a ContextVar (`_exit_event_context`) and has no such attribute. + """ + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + AppStatus.should_exit_event = None # pragma: lax no cover + + @pytest.fixture(params=sorted(_FACTORIES)) def connect(request: pytest.FixtureRequest) -> Connect: """The transport-parametrized connection factory: a test using it runs once per transport. diff --git a/uv.lock b/uv.lock index a3450707c3..dbf687f863 100644 --- a/uv.lock +++ b/uv.lock @@ -807,7 +807,6 @@ dev = [ { name = "pytest-pretty" }, { name = "pytest-xdist" }, { name = "ruff" }, - { name = "sse-starlette" }, { name = "trio" }, ] docs = [ @@ -852,7 +851,6 @@ dev = [ { name = "pytest-pretty", specifier = ">=1.2.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, - { name = "sse-starlette", specifier = ">=2.1.0" }, { name = "trio", specifier = ">=0.26.2" }, ] docs = [ From 67b3b2c3875b043ef2e416e6c04bc4cd918ca91d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 16:48:47 +0000 Subject: [PATCH 18/19] =?UTF-8?q?backport:=20coverage=20=E2=80=94=20lax-no?= =?UTF-8?q?-cover=20post-unwind=20asserts=20on=203.11+lowest-direct?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 3.11 zero-cost-exception tracer dead-zone (python/cpython#106749) drops line events for sync statements between a cancel-on-exit __aexit__ and the next real await. On anyio 4.5.0 (lowest-direct) the unwind path differs from anyio 4.10 just enough that three additional post-async-with assertions fall in the dead-zone on 3.11 only. Same family as the markers added in 361bb9d. --- tests/interaction/transports/test_flows.py | 4 ++-- tests/interaction/transports/test_sse.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py index b390aa09f2..874d782bfa 100644 --- a/tests/interaction/transports/test_flows.py +++ b/tests/interaction/transports/test_flows.py @@ -121,9 +121,9 @@ def echo(text: str) -> str: shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) - assert shttp_result == snapshot( + assert shttp_result == snapshot( # pragma: lax no cover CallToolResult(content=[TextContent(type="text", text="via http")], structuredContent={"result": "via http"}) ) - assert sse_result == snapshot( + assert sse_result == snapshot( # pragma: lax no cover CallToolResult(content=[TextContent(type="text", text="via sse")], structuredContent={"result": "via sse"}) ) diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index dd3a9594e9..f7828b2f3a 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -60,7 +60,7 @@ def httpx_client_factory( assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers assert await client.send_ping() == snapshot(EmptyResult()) - assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers # pragma: lax no cover @requirement("transport:sse:post:session-routing") From 03fdaed27c2520fcd91a41ab6407e1c328c2cbdb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 17:17:14 +0000 Subject: [PATCH 19/19] tests: expect SSE session-entry cleanup on disconnect connect_sse now removes _read_stream_writers[session_id] in a finally once the GET request unwinds (#2719), so the endpoint-event test waits for that cleanup after the client disconnects instead of pinning the old retention behaviour. --- tests/interaction/transports/test_sse.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index f7828b2f3a..4abfffdebd 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -29,11 +29,8 @@ @requirement("transport:sse:endpoint-event") async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: """Connecting opens a GET stream whose first event names the POST endpoint and a fresh - session id; messages POSTed there are answered on that stream. - - On v1 the server's session entry is not removed on disconnect (`SseServerTransport` never - pops `_read_stream_writers[session_id]`); the final assertion pins that behaviour. - """ + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" app, sse = build_sse_app(Server("legacy")) captured_session_id: list[str] = [] @@ -60,7 +57,13 @@ def httpx_client_factory( assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers assert await client.send_ping() == snapshot(EmptyResult()) - assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers # pragma: lax no cover + # `connect_sse` drops the session entry in a `finally` once the GET request has unwound; the + # bridge lets that unwinding finish after the client has gone, so wait for the cleanup instead + # of racing it. How many iterations that takes is a scheduling accident (usually zero), and on + # 3.11 these post-unwind lines are invisible to the line tracer, hence the coverage exclusion. + with anyio.fail_after(5): # pragma: lax no cover + while sse._read_stream_writers: + await anyio.sleep(0.01) @requirement("transport:sse:post:session-routing")