diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e469..a4ecea4c4a 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,10 +4,10 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager +from multiprocessing.connection import Connection import pytest from starlette.applications import Starlette @@ -19,7 +19,7 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import running_server # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,7 +41,7 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover +def run_unicode_server(port_writer: Connection) -> None: # pragma: no cover """Run the Unicode test server in a separate process.""" import uvicorn @@ -137,43 +137,28 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: lifespan=lifespan, ) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port = sock.getsockname()[1] + port_writer.send(port) + port_writer.close() + # Run the server config = uvicorn.Config( app=app, - host="127.0.0.1", - port=port, log_level="error", ) uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] + uvicorn_server.run(sockets=[sock]) @pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: +def running_unicode_server() -> Generator[str, None, None]: """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) + with running_server(run_unicode_server) as url: + yield url @pytest.mark.anyio diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..b7e33bd401 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -3,6 +3,7 @@ import logging import multiprocessing import socket +from multiprocessing.connection import Connection import httpx import pytest @@ -16,24 +17,11 @@ from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - class SecurityTestServer(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) @@ -42,7 +30,9 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover +def run_server_with_settings( + port_writer: Connection, security_settings: TransportSecuritySettings | None = None +): # pragma: no cover """Run the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) @@ -63,28 +53,46 @@ async def handle_sse(request: Request): ] starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port = sock.getsockname()[1] + port_writer.send(port) + port_writer.close() + server = uvicorn.Server(config=uvicorn.Config(app=starlette_app, log_level="error")) + server.run(sockets=[sock]) -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): + +def start_server_process( + security_settings: TransportSecuritySettings | None = None, +) -> tuple[multiprocessing.Process, int]: """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) + reader, writer = multiprocessing.Pipe(duplex=False) + process = multiprocessing.Process( + target=run_server_with_settings, + kwargs={"port_writer": writer, "security_settings": security_settings}, + ) process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + writer.close() + try: + port = reader.recv() + finally: + reader.close() + return process, port @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): +async def test_sse_security_default_settings(): """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) + process, port = start_server_process() try: headers = {"Host": "evil.com", "Origin": "http://evil.com"} async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response: assert response.status_code == 200 finally: process.terminate() @@ -92,18 +100,18 @@ async def test_sse_security_default_settings(server_port: int): @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: # Test with invalid host header headers = {"Host": "evil.com"} async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers) assert response.status_code == 421 assert response.text == "Invalid Host header" @@ -113,20 +121,20 @@ async def test_sse_security_invalid_host_header(server_port: int): @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(): """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: # Test with invalid origin header headers = {"Origin": "http://evil.com"} async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers) assert response.status_code == 403 assert response.text == "Invalid Origin header" @@ -136,20 +144,20 @@ async def test_sse_security_invalid_origin_header(server_port: int): @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(): """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type fake_session_id = "12345678123456781234567812345678" response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}", headers={"Content-Type": "text/plain"}, content="test", ) @@ -158,7 +166,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int): # Test POST with missing content type response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" + f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}", content="test" ) assert response.status_code == 400 assert response.text == "Invalid Content-Type header" @@ -169,10 +177,10 @@ async def test_sse_security_post_invalid_content_type(server_port: int): @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) + process, port = start_server_process(settings) try: # Test with invalid host header - should still work @@ -180,7 +188,7 @@ async def test_sse_security_disabled(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response: # Should connect successfully even with invalid host assert response.status_code == 200 @@ -190,14 +198,14 @@ async def test_sse_security_disabled(server_port: int): @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) + process, port = start_server_process(settings) try: # Test with custom allowed host @@ -205,7 +213,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response: # Should connect successfully with custom host assert response.status_code == 200 @@ -213,7 +221,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): headers = {"Host": "evil.com"} async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers) assert response.status_code == 421 assert response.text == "Invalid Host header" @@ -223,14 +231,14 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) + process, port = start_server_process(settings) try: # Test with various port numbers @@ -239,7 +247,7 @@ async def test_sse_security_wildcard_ports(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response: # Should connect successfully with any port assert response.status_code == 200 @@ -247,7 +255,7 @@ async def test_sse_security_wildcard_ports(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response: # Should connect successfully with any port assert response.status_code == 200 @@ -257,13 +265,13 @@ async def test_sse_security_wildcard_ports(server_port: int): @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(): """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: async with httpx.AsyncClient() as client: @@ -279,7 +287,7 @@ async def test_sse_security_post_valid_content_type(server_port: int): # Use a valid UUID format (even though session won't exist) fake_session_id = "12345678123456781234567812345678" response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}", headers={"Content-Type": content_type}, json={"test": "data"}, ) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353e..cc30c1fca9 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -4,6 +4,7 @@ import socket from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from multiprocessing.connection import Connection import httpx import pytest @@ -16,23 +17,10 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server SERVER_NAME = "test_streamable_http_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - class SecurityTestServer(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) @@ -41,7 +29,9 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover +def run_server_with_settings( + port_writer: Connection, security_settings: TransportSecuritySettings | None = None +): # pragma: no cover """Run the StreamableHTTP server with specified security settings.""" app = SecurityTestServer() @@ -68,29 +58,47 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: ] starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port = sock.getsockname()[1] + port_writer.send(port) + port_writer.close() + server = uvicorn.Server(config=uvicorn.Config(app=starlette_app, log_level="error")) + server.run(sockets=[sock]) -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): + +def start_server_process( + security_settings: TransportSecuritySettings | None = None, +) -> tuple[multiprocessing.Process, int]: """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) + reader, writer = multiprocessing.Pipe(duplex=False) + process = multiprocessing.Process( + target=run_server_with_settings, + kwargs={"port_writer": writer, "security_settings": security_settings}, + ) process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + writer.close() + try: + port = reader.recv() + finally: + reader.close() + return process, port @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): +async def test_streamable_http_security_default_settings(): """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) + process, port = start_server_process() try: # Test with valid localhost headers async with httpx.AsyncClient(timeout=5.0) as client: # POST request to initialize session response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers={ "Accept": "application/json, text/event-stream", @@ -106,10 +114,10 @@ async def test_streamable_http_security_default_settings(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): +async def test_streamable_http_security_invalid_host_header(): """Test StreamableHTTP with invalid Host header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: # Test with invalid host header @@ -121,7 +129,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) @@ -134,10 +142,10 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): +async def test_streamable_http_security_invalid_origin_header(): """Test StreamableHTTP with invalid Origin header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: # Test with invalid origin header @@ -149,7 +157,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) @@ -162,15 +170,15 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): +async def test_streamable_http_security_invalid_content_type(): """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) + process, port = start_server_process() try: async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", headers={ "Content-Type": "text/plain", "Accept": "application/json, text/event-stream", @@ -182,7 +190,7 @@ async def test_streamable_http_security_invalid_content_type(server_port: int): # Test POST with missing content type response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", headers={"Accept": "application/json, text/event-stream"}, content="test", ) @@ -195,10 +203,10 @@ async def test_streamable_http_security_invalid_content_type(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): +async def test_streamable_http_security_disabled(): """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) + process, port = start_server_process(settings) try: # Test with invalid host header - should still work @@ -210,7 +218,7 @@ async def test_streamable_http_security_disabled(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) @@ -223,14 +231,14 @@ async def test_streamable_http_security_disabled(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): +async def test_streamable_http_security_custom_allowed_hosts(): """Test StreamableHTTP with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) + process, port = start_server_process(settings) try: # Test with custom allowed host @@ -242,7 +250,7 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"http://127.0.0.1:{port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) @@ -254,10 +262,10 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): +async def test_streamable_http_security_get_request(): """Test StreamableHTTP GET request with security.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) + process, port = start_server_process(security_settings) try: # Test GET request with invalid host header @@ -267,7 +275,7 @@ async def test_streamable_http_security_get_request(server_port: int): } async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) + response = await client.get(f"http://127.0.0.1:{port}/", headers=headers) assert response.status_code == 421 assert response.text == "Invalid Host header" @@ -280,7 +288,7 @@ async def test_streamable_http_security_get_request(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # GET requests need a session ID in StreamableHTTP # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) + response = await client.get(f"http://127.0.0.1:{port}/", headers=headers) # This should pass security but fail on session validation assert response.status_code == 400 body = response.json() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707b..2712bbfdc9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,7 @@ import json -import multiprocessing import socket from collections.abc import AsyncGenerator, Generator +from multiprocessing.connection import Connection from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse @@ -41,21 +41,20 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import running_server SERVER_NAME = "test_server_for_SSE" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +@pytest.fixture() +def server() -> Generator[str, None, None]: + with running_server(run_server) as url: + yield url -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" +@pytest.fixture() +def server_url(server: str) -> str: + return server async def _handle_read_resource( # pragma: no cover @@ -127,35 +126,23 @@ async def handle_sse(request: Request) -> Response: return app -def run_server(server_port: int) -> None: # pragma: no cover +def run_server(port_writer: Connection) -> None: # pragma: no cover app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port = sock.getsockname()[1] + port_writer.send(port) + port_writer.close() - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") + server = uvicorn.Server(config=uvicorn.Config(app=app, log_level="error")) + print(f"starting server on {port}") + server.run(sockets=[sock]) @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" async with httpx.AsyncClient(base_url=server_url) as client: yield client @@ -297,37 +284,31 @@ async def test_sse_client_timeout( # pragma: no cover pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: # pragma: no cover +def run_mounted_server(port_writer: Connection) -> None: # pragma: no cover app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port = sock.getsockname()[1] + port_writer.send(port) + port_writer.close() + server = uvicorn.Server(config=uvicorn.Config(app=main_app, log_level="error")) + print(f"starting server on {port}") + server.run(sockets=[sock]) -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") +@pytest.fixture() +def mounted_server() -> Generator[str, None, None]: + with running_server(run_mounted_server) as url: + yield url @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: +async def test_sse_client_basic_connection_mounted_app(mounted_server: str) -> None: + async with sse_client(mounted_server + "/mounted_app/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -381,7 +362,7 @@ async def _handle_context_list_tools( # pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover +def run_context_server(port_writer: Connection) -> None: # pragma: no cover """Run a server that captures request context""" # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( @@ -406,33 +387,28 @@ async def handle_sse(request: Request) -> Response: ] ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port = sock.getsockname()[1] + port_writer.send(port) + port_writer.close() + + server = uvicorn.Server(config=uvicorn.Config(app=app, log_level="error")) + print(f"starting context server on {port}") + server.run(sockets=[sock]) @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: +def context_server() -> Generator[str, None, None]: """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") + with running_server(run_context_server) as url: + yield url @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(context_server: str) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -441,7 +417,7 @@ async def test_request_context_propagation(context_server: None, server_url: str "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( + async with sse_client(context_server + "/sse", headers=custom_headers) as ( read_stream, write_stream, ): @@ -465,7 +441,7 @@ async def test_request_context_propagation(context_server: None, server_url: str @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(context_server: str) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] @@ -473,7 +449,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( + async with sse_client(context_server + "/sse", headers=headers) as ( read_stream, write_stream, ): diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..8ea323a0e3 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -6,13 +6,13 @@ from __future__ import annotations as _annotations import json -import multiprocessing import socket import time import traceback from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field +from multiprocessing.connection import Connection from typing import Any from unittest.mock import MagicMock from urllib.parse import urlparse @@ -66,7 +66,7 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import running_server # Test constants SERVER_NAME = "test_streamable_http_server" @@ -433,73 +433,62 @@ def create_app( def run_server( - port: int, + port_writer: Connection, is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, ) -> None: # pragma: no cover - """Run the test server. + """Run the test server in a subprocess. + + Binds an ephemeral listening socket, reports the chosen port back to the + parent through `port_writer`, then serves on that same socket. Binding in + the child (instead of picking a port in the parent, closing it, and + rebinding here) closes the window where another pytest-xdist worker could + grab the port in between. Args: - port: Port to listen on. + port_writer: Pipe end used to send the bound port back to the parent. is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. """ - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port_writer.send(sock.getsockname()[1]) + port_writer.close() + config = uvicorn.Config( app=app, - host="127.0.0.1", - port=port, log_level="info", limit_concurrency=10, timeout_keep_alive=5, access_log=False, ) - - # Start the server server = uvicorn.Server(config=config) # This is important to catch exceptions and prevent test hangs try: - server.run() + server.run(sockets=[sock]) except Exception: traceback.print_exc() -# Test fixtures - using same approach as SSE tests +# Test fixtures - each server runs in a subprocess that binds its own port +# (see `running_server`), so parallel test workers can't collide on a port. @pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +def basic_server() -> Generator[str, None, None]: + """Start a basic server; yields its base URL.""" + with running_server(run_server) as url: + yield url @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +def basic_server_url(basic_server: str) -> str: + """Get the URL for the basic test server.""" + return basic_server @pytest.fixture @@ -509,65 +498,23 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store and retry_interval enabled; yields (store, url).""" + with running_server(run_server, event_store=event_store, retry_interval=500) as url: + yield event_store, url @pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +def json_response_server() -> Generator[str, None, None]: + """Start a server with JSON responses enabled; yields its base URL.""" + with running_server(run_server, is_json_response_enabled=True) as url: + yield url @pytest.fixture -def json_server_url(json_server_port: int) -> str: +def json_server_url(json_response_server: str) -> str: """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" + return json_response_server # Basic request validation tests @@ -1518,8 +1465,8 @@ async def _handle_context_call_tool( # pragma: no cover # Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +def run_context_aware_server(port_writer: Connection) -> None: # pragma: no cover + """Run the context-aware test server in a subprocess (see `run_server`).""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, @@ -1540,36 +1487,25 @@ def run_context_aware_server(port: int): # pragma: no cover lifespan=lambda app: session_manager.run(), ) - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen() + port_writer.send(sock.getsockname()[1]) + port_writer.close() -@pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() + uvicorn.Server(config=uvicorn.Config(app=app, log_level="error")).run(sockets=[sock]) - # Wait for server to be running - wait_for_server(basic_server_port) - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") +@pytest.fixture +def context_aware_server() -> Generator[str, None, None]: + """Start the context-aware server in a subprocess; yields its base URL.""" + with running_server(run_context_aware_server) as url: + yield url @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_propagation(context_aware_server: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1578,7 +1514,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: } async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1602,7 +1538,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_isolation(context_aware_server: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts: list[dict[str, Any]] = [] @@ -1615,7 +1551,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No } async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1640,9 +1576,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def test_client_includes_protocol_version_header_after_init(context_aware_server: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # Initialize and get the negotiated version init_result = await session.initialize() @@ -2254,9 +2190,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_mcp_headers_override_defaults(context_aware_server: str) -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests @@ -2265,7 +2199,10 @@ async def test_streamable_http_client_mcp_headers_override_defaults( # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2285,9 +2222,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_aware_server: str) -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", @@ -2296,7 +2231,10 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( } async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 810c72820b..bde94fa08a 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,9 +1,7 @@ -"""Common test utilities for MCP server tests.""" - +import multiprocessing import socket import threading -import time -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Any @@ -58,28 +56,27 @@ def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) -def wait_for_server(port: int, timeout: float = 20.0) -> None: - """Wait for server to be ready to accept connections. - - Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. - - Args: - port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) +@contextmanager +def running_server(target: Callable[..., None], **server_kwargs: Any) -> Generator[str, None, None]: + """Start `target` in a subprocess and yield the running server's base URL. - Raises: - TimeoutError: If server doesn't start within the timeout period + The child binds its own listening socket and reports the actual port back + through a pipe, so the parent never has to pick (and momentarily free) a + port — eliminating the cross-worker port race under `pytest -n auto`. """ - start_time = time.time() - while time.time() - start_time < timeout: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(0.1) - s.connect(("127.0.0.1", port)) - # Server is ready - return - except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly - time.sleep(0.01) - raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover + reader, writer = multiprocessing.Pipe(duplex=False) + proc = multiprocessing.Process(target=target, kwargs={"port_writer": writer, **server_kwargs}, daemon=True) + proc.start() + # Drop the parent's writer copy so reader.recv() raises EOFError (instead of + # blocking forever) if the child dies before reporting its port. + writer.close() + try: + port = reader.recv() + finally: + reader.close() + + try: + yield f"http://127.0.0.1:{port}" + finally: + proc.kill() + proc.join(timeout=2)