Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 15 additions & 30 deletions tests/client/test_http_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
(server→client and client→server) using the streamable HTTP transport.
"""

import multiprocessing
import socket
from collections.abc import AsyncGenerator, Generator
from contextlib import asynccontextmanager
from multiprocessing.connection import Connection

import pytest
from starlette.applications import Starlette
Expand All @@ -19,7 +19,7 @@
from mcp.server import Server, ServerRequestContext
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.types import TextContent, Tool
from tests.test_helpers import wait_for_server
from tests.test_helpers import running_server

# Test constants with various Unicode characters
UNICODE_TEST_STRINGS = {
Expand All @@ -41,7 +41,7 @@
}


def run_unicode_server(port: int) -> None: # pragma: no cover
def run_unicode_server(port_writer: Connection) -> None: # pragma: no cover
"""Run the Unicode test server in a separate process."""
import uvicorn

Expand Down Expand Up @@ -137,43 +137,28 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
lifespan=lifespan,
)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", 0))
sock.listen()
port = sock.getsockname()[1]
port_writer.send(port)
port_writer.close()

# Run the server
config = uvicorn.Config(
app=app,
host="127.0.0.1",
port=port,
log_level="error",
)
uvicorn_server = uvicorn.Server(config)
uvicorn_server.run()


@pytest.fixture
def unicode_server_port() -> int:
"""Find an available port for the Unicode test server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
uvicorn_server.run(sockets=[sock])


@pytest.fixture
def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]:
def running_unicode_server() -> Generator[str, None, None]:
"""Start a Unicode test server in a separate process."""
proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True)
proc.start()

# Wait for server to be ready
wait_for_server(unicode_server_port)

try:
yield f"http://127.0.0.1:{unicode_server_port}"
finally:
# Clean up - try graceful termination first
proc.terminate()
proc.join(timeout=2)
if proc.is_alive(): # pragma: no cover
proc.kill()
proc.join(timeout=1)
with running_server(run_unicode_server) as url:
yield url


@pytest.mark.anyio
Expand Down
102 changes: 55 additions & 47 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import multiprocessing
import socket
from multiprocessing.connection import Connection

import httpx
import pytest
Expand All @@ -16,24 +17,11 @@
from mcp.server.sse import SseServerTransport
from mcp.server.transport_security import TransportSecuritySettings
from mcp.types import Tool
from tests.test_helpers import wait_for_server

logger = logging.getLogger(__name__)
SERVER_NAME = "test_sse_security_server"


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


@pytest.fixture
def server_url(server_port: int) -> str: # pragma: no cover
return f"http://127.0.0.1:{server_port}"


class SecurityTestServer(Server): # pragma: no cover
def __init__(self):
super().__init__(SERVER_NAME)
Expand All @@ -42,7 +30,9 @@ async def on_list_tools(self) -> list[Tool]:
return []


def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover
def run_server_with_settings(
port_writer: Connection, security_settings: TransportSecuritySettings | None = None
): # pragma: no cover
"""Run the SSE server with specified security settings."""
app = SecurityTestServer()
sse_transport = SseServerTransport("/messages/", security_settings)
Expand All @@ -63,47 +53,65 @@ async def handle_sse(request: Request):
]

starlette_app = Starlette(routes=routes)
uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", 0))
sock.listen()
port = sock.getsockname()[1]
port_writer.send(port)
port_writer.close()

server = uvicorn.Server(config=uvicorn.Config(app=starlette_app, log_level="error"))
server.run(sockets=[sock])

def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):

def start_server_process(
security_settings: TransportSecuritySettings | None = None,
) -> tuple[multiprocessing.Process, int]:
"""Start server in a separate process."""
process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
reader, writer = multiprocessing.Pipe(duplex=False)
process = multiprocessing.Process(
target=run_server_with_settings,
kwargs={"port_writer": writer, "security_settings": security_settings},
)
process.start()
# Wait for server to be ready to accept connections
wait_for_server(port)
return process
writer.close()
try:
port = reader.recv()
finally:
reader.close()
return process, port


@pytest.mark.anyio
async def test_sse_security_default_settings(server_port: int):
async def test_sse_security_default_settings():
"""Test SSE with default security settings (protection disabled)."""
process = start_server_process(server_port)
process, port = start_server_process()

try:
headers = {"Host": "evil.com", "Origin": "http://evil.com"}

async with httpx.AsyncClient(timeout=5.0) as client:
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
assert response.status_code == 200
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_sse_security_invalid_host_header(server_port: int):
async def test_sse_security_invalid_host_header():
"""Test SSE with invalid Host header."""
# Enable security by providing settings with an empty allowed_hosts list
security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
process = start_server_process(server_port, security_settings)
process, port = start_server_process(security_settings)

try:
# Test with invalid host header
headers = {"Host": "evil.com"}

async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"

Expand All @@ -113,20 +121,20 @@ async def test_sse_security_invalid_host_header(server_port: int):


@pytest.mark.anyio
async def test_sse_security_invalid_origin_header(server_port: int):
async def test_sse_security_invalid_origin_header():
"""Test SSE with invalid Origin header."""
# Configure security to allow the host but restrict origins
security_settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
)
process = start_server_process(server_port, security_settings)
process, port = start_server_process(security_settings)

try:
# Test with invalid origin header
headers = {"Origin": "http://evil.com"}

async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
assert response.status_code == 403
assert response.text == "Invalid Origin header"

Expand All @@ -136,20 +144,20 @@ async def test_sse_security_invalid_origin_header(server_port: int):


@pytest.mark.anyio
async def test_sse_security_post_invalid_content_type(server_port: int):
async def test_sse_security_post_invalid_content_type():
"""Test POST endpoint with invalid Content-Type header."""
# Configure security to allow the host
security_settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
)
process = start_server_process(server_port, security_settings)
process, port = start_server_process(security_settings)

try:
async with httpx.AsyncClient(timeout=5.0) as client:
# Test POST with invalid content type
fake_session_id = "12345678123456781234567812345678"
response = await client.post(
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}",
headers={"Content-Type": "text/plain"},
content="test",
)
Expand All @@ -158,7 +166,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):

# Test POST with missing content type
response = await client.post(
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test"
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}", content="test"
)
assert response.status_code == 400
assert response.text == "Invalid Content-Type header"
Expand All @@ -169,18 +177,18 @@ async def test_sse_security_post_invalid_content_type(server_port: int):


@pytest.mark.anyio
async def test_sse_security_disabled(server_port: int):
async def test_sse_security_disabled():
"""Test SSE with security disabled."""
settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
process = start_server_process(server_port, settings)
process, port = start_server_process(settings)

try:
# Test with invalid host header - should still work
headers = {"Host": "evil.com"}

async with httpx.AsyncClient(timeout=5.0) as client:
# For SSE endpoints, we need to use stream to avoid timeout
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
# Should connect successfully even with invalid host
assert response.status_code == 200

Expand All @@ -190,30 +198,30 @@ async def test_sse_security_disabled(server_port: int):


@pytest.mark.anyio
async def test_sse_security_custom_allowed_hosts(server_port: int):
async def test_sse_security_custom_allowed_hosts():
"""Test SSE with custom allowed hosts."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["localhost", "127.0.0.1", "custom.host"],
allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"],
)
process = start_server_process(server_port, settings)
process, port = start_server_process(settings)

try:
# Test with custom allowed host
headers = {"Host": "custom.host"}

async with httpx.AsyncClient(timeout=5.0) as client:
# For SSE endpoints, we need to use stream to avoid timeout
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
# Should connect successfully with custom host
assert response.status_code == 200

# Test with non-allowed host
headers = {"Host": "evil.com"}

async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"

Expand All @@ -223,14 +231,14 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):


@pytest.mark.anyio
async def test_sse_security_wildcard_ports(server_port: int):
async def test_sse_security_wildcard_ports():
"""Test SSE with wildcard port patterns."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["localhost:*", "127.0.0.1:*"],
allowed_origins=["http://localhost:*", "http://127.0.0.1:*"],
)
process = start_server_process(server_port, settings)
process, port = start_server_process(settings)

try:
# Test with various port numbers
Expand All @@ -239,15 +247,15 @@ async def test_sse_security_wildcard_ports(server_port: int):

async with httpx.AsyncClient(timeout=5.0) as client:
# For SSE endpoints, we need to use stream to avoid timeout
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
# Should connect successfully with any port
assert response.status_code == 200

headers = {"Origin": f"http://localhost:{test_port}"}

async with httpx.AsyncClient(timeout=5.0) as client:
# For SSE endpoints, we need to use stream to avoid timeout
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
# Should connect successfully with any port
assert response.status_code == 200

Expand All @@ -257,13 +265,13 @@ async def test_sse_security_wildcard_ports(server_port: int):


@pytest.mark.anyio
async def test_sse_security_post_valid_content_type(server_port: int):
async def test_sse_security_post_valid_content_type():
"""Test POST endpoint with valid Content-Type headers."""
# Configure security to allow the host
security_settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
)
process = start_server_process(server_port, security_settings)
process, port = start_server_process(security_settings)

try:
async with httpx.AsyncClient() as client:
Expand All @@ -279,7 +287,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
# Use a valid UUID format (even though session won't exist)
fake_session_id = "12345678123456781234567812345678"
response = await client.post(
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}",
headers={"Content-Type": content_type},
json={"test": "data"},
)
Expand Down
Loading
Loading