diff --git a/README.md b/README.md index 8be588d..682493f 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ Those instruments can be bootstrapped for: - `fastapi`, - `litestar`, - or `faststream` service, +- or `fastmcp` service, - or even a service that doesn't use one of these frameworks. Interested? Let's dive right in ⚡ @@ -78,6 +79,7 @@ Also, you can specify extras during installation for concrete framework: - `fastapi` - `litestar` - `faststream` (ASGI app) +- `fastmcp` Also we have `granian` extra that is requires for `create_granian_server`. @@ -198,6 +200,28 @@ settings = YourSettings() application: AsgiFastStream = FastStreamBootstrapper(settings).bootstrap() ``` +### FastMCP + +```python +from fastmcp import FastMCP + +from microbootstrap import FastMcpSettings +from microbootstrap.bootstrappers.fastmcp import FastMcpBootstrapper + + +class YourSettings(FastMcpSettings): + service_debug: bool = False + service_name: str = "my-awesome-mcp-service" + service_description: str = "MCP server for internal tools" + + sentry_dsn: str = "your-sentry-dsn" + + +settings = YourSettings() + +application: FastMCP = FastMcpBootstrapper(settings).bootstrap() +``` + ## Settings The settings object is the core of microbootstrap. diff --git a/examples/fastmcp_app.py b/examples/fastmcp_app.py new file mode 100644 index 0000000..23ee6bc --- /dev/null +++ b/examples/fastmcp_app.py @@ -0,0 +1,17 @@ +from fastmcp import FastMCP + +from microbootstrap import FastMcpSettings +from microbootstrap.bootstrappers.fastmcp import FastMcpBootstrapper + + +class Settings(FastMcpSettings): + service_name: str = "example-mcp" + service_description: str = "Example FastMCP service" + + +application: FastMCP = FastMcpBootstrapper(Settings()).bootstrap() + + +@application.tool +def greet_person(person_name: str) -> str: + return f"Hello, {person_name}!" diff --git a/microbootstrap/__init__.py b/microbootstrap/__init__.py index ae33bef..7c8295a 100644 --- a/microbootstrap/__init__.py +++ b/microbootstrap/__init__.py @@ -8,6 +8,7 @@ ) from microbootstrap.instruments.prometheus_instrument import ( FastApiPrometheusConfig, + FastMcpPrometheusConfig, FastStreamPrometheusConfig, FastStreamPrometheusMiddlewareProtocol, LitestarPrometheusConfig, @@ -17,6 +18,7 @@ from microbootstrap.instruments.swagger_instrument import SwaggerConfig from microbootstrap.settings import ( FastApiSettings, + FastMcpSettings, FastStreamSettings, InstrumentsSetupperSettings, LitestarSettings, @@ -27,6 +29,8 @@ "CorsConfig", "FastApiPrometheusConfig", "FastApiSettings", + "FastMcpPrometheusConfig", + "FastMcpSettings", "FastStreamOpentelemetryConfig", "FastStreamPrometheusConfig", "FastStreamPrometheusMiddlewareProtocol", diff --git a/microbootstrap/bootstrappers/fastmcp.py b/microbootstrap/bootstrappers/fastmcp.py new file mode 100644 index 0000000..656c0d1 --- /dev/null +++ b/microbootstrap/bootstrappers/fastmcp.py @@ -0,0 +1,101 @@ +from __future__ import annotations +import typing + +import prometheus_client +import typing_extensions +from fastmcp import FastMCP +from starlette.responses import JSONResponse, Response + +from microbootstrap.bootstrappers.base import ApplicationBootstrapper +from microbootstrap.config.fastmcp import FastMcpConfig +from microbootstrap.instruments.health_checks_instrument import HealthChecksInstrument, HealthCheckTypedDict +from microbootstrap.instruments.logging_instrument import LoggingInstrument +from microbootstrap.instruments.prometheus_instrument import FastMcpPrometheusConfig, PrometheusInstrument +from microbootstrap.instruments.pyroscope_instrument import PyroscopeInstrument +from microbootstrap.instruments.sentry_instrument import SentryInstrument +from microbootstrap.middlewares.fastmcp import FastMcpLoggingMiddleware +from microbootstrap.settings import FastMcpSettings + + +if typing.TYPE_CHECKING: + from starlette.requests import Request + + +class KwargsFastMCP(FastMCP[typing.Any]): + def __init__(self, **kwargs: typing.Any) -> None: # noqa: ANN401 + super().__init__(**kwargs) + + +class FastMcpBootstrapper( + ApplicationBootstrapper[FastMcpSettings, FastMCP[typing.Any], FastMcpConfig], +): + application_config = FastMcpConfig() + application_type = KwargsFastMCP + + def bootstrap_before(self: typing_extensions.Self) -> dict[str, typing.Any]: + return { + "name": self.application_config.name or self.settings.service_name, + "instructions": self.application_config.instructions or self.settings.service_description, + "version": self.application_config.version or self.settings.service_version, + } + + def bootstrap_before_instruments_after_app_created( + self, + application: FastMCP[typing.Any], + ) -> FastMCP[typing.Any]: + self.console_writer.print_bootstrap_table() + return application + + +FastMcpBootstrapper.use_instrument()(SentryInstrument) +FastMcpBootstrapper.use_instrument()(PyroscopeInstrument) + + +@FastMcpBootstrapper.use_instrument() +class FastMcpLoggingInstrument(LoggingInstrument): + def bootstrap_after(self, application: FastMCP[typing.Any]) -> FastMCP[typing.Any]: # type: ignore[override] + if not self.instrument_config.logging_turn_off_middleware: + application.add_middleware(FastMcpLoggingMiddleware()) + return application + + +@FastMcpBootstrapper.use_instrument() +class FastMcpHealthChecksInstrument(HealthChecksInstrument): + def bootstrap_after(self, application: FastMCP[typing.Any]) -> FastMCP[typing.Any]: # type: ignore[override] + @application.custom_route( + self.instrument_config.health_checks_path, + methods=["GET"], + name="health_check", + include_in_schema=self.instrument_config.health_checks_include_in_schema, + ) + async def health_check_handler(request: Request) -> JSONResponse: # noqa: ARG001 + response_data: HealthCheckTypedDict = self.render_health_check_data() + return JSONResponse(response_data) + + return application + + +@FastMcpBootstrapper.use_instrument() +class FastMcpPrometheusInstrument(PrometheusInstrument[FastMcpPrometheusConfig]): + def bootstrap_after(self, application: FastMCP[typing.Any]) -> FastMCP[typing.Any]: # type: ignore[override] + if not self.instrument_config.prometheus_register_route: + return application + + @application.custom_route( + self.instrument_config.prometheus_metrics_path, + methods=["GET"], + name="metrics", + include_in_schema=self.instrument_config.prometheus_metrics_include_in_schema, + ) + async def metrics_handler(request: Request) -> Response: # noqa: ARG001 + registry: typing.Final = self.instrument_config.prometheus_registry or prometheus_client.REGISTRY + return Response( + prometheus_client.generate_latest(registry), + headers={"content-type": prometheus_client.CONTENT_TYPE_LATEST}, + ) + + return application + + @classmethod + def get_config_type(cls) -> type[FastMcpPrometheusConfig]: + return FastMcpPrometheusConfig diff --git a/microbootstrap/bootstrappers/litestar.py b/microbootstrap/bootstrappers/litestar.py index 736bedc..bf827a8 100644 --- a/microbootstrap/bootstrappers/litestar.py +++ b/microbootstrap/bootstrappers/litestar.py @@ -160,8 +160,8 @@ def __init__(self, config: OpenTelemetryConfig) -> None: def create_open_telemetry_middleware(self, app: ASGIApp) -> OpenTelemetryMiddleware: return OpenTelemetryMiddleware( app=app, - client_request_hook=self.config.client_request_hook_handler, # type: ignore[arg-type] - client_response_hook=self.config.client_response_hook_handler, # type: ignore[arg-type] + client_request_hook=self.config.client_request_hook_handler, + client_response_hook=self.config.client_response_hook_handler, default_span_details=build_litestar_route_details_from_scope, excluded_urls=get_excluded_urls(self.config.exclude_urls_env_key), meter=self.config.meter, diff --git a/microbootstrap/config/fastmcp.py b/microbootstrap/config/fastmcp.py new file mode 100644 index 0000000..e51c16a --- /dev/null +++ b/microbootstrap/config/fastmcp.py @@ -0,0 +1,42 @@ +from __future__ import annotations +import dataclasses +import typing + + +if typing.TYPE_CHECKING: + import mcp.types + from fastmcp.client.sampling import SamplingHandler + from fastmcp.server.auth import AuthProvider + from fastmcp.server.lifespan import Lifespan + from fastmcp.server.middleware import Middleware as FastMcpMiddleware + from fastmcp.server.providers import Provider + from fastmcp.server.server import DuplicateBehavior, LifespanCallable + from fastmcp.server.transforms import Transform + from fastmcp.tools.base import Tool + from key_value.aio.protocols import AsyncKeyValue + + +@dataclasses.dataclass +class FastMcpConfig: + name: str | None = None + instructions: str | None = None + version: str | int | float | None = None + website_url: str | None = None + icons: list[mcp.types.Icon] | None = None + auth: AuthProvider | None = None + middleware: typing.Sequence[FastMcpMiddleware] | None = None + providers: typing.Sequence[Provider] | None = None + transforms: typing.Sequence[Transform] | None = None + lifespan: LifespanCallable | Lifespan | None = None + tools: typing.Sequence[Tool | typing.Callable[..., typing.Any]] | None = None + on_duplicate: DuplicateBehavior | None = None + mask_error_details: bool | None = None + dereference_schemas: bool = True + strict_input_validation: bool | None = None + list_page_size: int | None = None + tasks: bool | None = None + session_state_store: AsyncKeyValue | None = None + sampling_handler: SamplingHandler[typing.Any, typing.Any] | None = None + sampling_handler_behavior: typing.Literal["always", "fallback"] | None = None + client_log_level: mcp.types.LoggingLevel | None = None + experimental_capabilities: dict[str, dict[str, typing.Any]] | None = None diff --git a/microbootstrap/instruments/prometheus_instrument.py b/microbootstrap/instruments/prometheus_instrument.py index 9936657..7d1f290 100644 --- a/microbootstrap/instruments/prometheus_instrument.py +++ b/microbootstrap/instruments/prometheus_instrument.py @@ -32,6 +32,11 @@ class FastApiPrometheusConfig(BasePrometheusConfig): prometheus_custom_labels: dict[str, typing.Any] = pydantic.Field(default_factory=dict) +class FastMcpPrometheusConfig(BasePrometheusConfig): + prometheus_registry: typing.Any | None = None + prometheus_register_route: bool = True + + @typing.runtime_checkable class FastStreamPrometheusMiddlewareProtocol(typing.Protocol): def __init__( diff --git a/microbootstrap/middlewares/fastmcp.py b/microbootstrap/middlewares/fastmcp.py new file mode 100644 index 0000000..488aa40 --- /dev/null +++ b/microbootstrap/middlewares/fastmcp.py @@ -0,0 +1,46 @@ +from __future__ import annotations +import time +import typing + +import structlog +from fastmcp.server.middleware import Middleware, MiddlewareContext + + +if typing.TYPE_CHECKING: + from fastmcp.server.middleware import CallNext + + +fastmcp_access_logger: typing.Final = structlog.get_logger("mcp.access") + + +class FastMcpLoggingMiddleware(Middleware): + async def on_message( + self, + context: MiddlewareContext[typing.Any], + call_next: CallNext[typing.Any, typing.Any], + ) -> typing.Any: # noqa: ANN401 + start_time: typing.Final = time.perf_counter_ns() + try: + result: typing.Final = await call_next(context) + except Exception: + fastmcp_access_logger.exception( + context.method or "unknown", + mcp={ + "method": context.method, + "source": context.source, + "type": context.type, + }, + duration=time.perf_counter_ns() - start_time, + ) + raise + + fastmcp_access_logger.info( + context.method or "unknown", + mcp={ + "method": context.method, + "source": context.source, + "type": context.type, + }, + duration=time.perf_counter_ns() - start_time, + ) + return result diff --git a/microbootstrap/settings.py b/microbootstrap/settings.py index a6a033a..5f9bce8 100644 --- a/microbootstrap/settings.py +++ b/microbootstrap/settings.py @@ -8,6 +8,7 @@ from microbootstrap import ( CorsConfig, FastApiPrometheusConfig, + FastMcpPrometheusConfig, FastStreamOpentelemetryConfig, FastStreamPrometheusConfig, HealthChecksConfig, @@ -102,6 +103,18 @@ class FastStreamSettings( # type: ignore[misc] asyncapi_path: str | None = "/asyncapi" +class FastMcpSettings( # type: ignore[misc] + BaseServiceSettings, + ServerConfig, + LoggingConfig, + SentryConfig, + FastMcpPrometheusConfig, + HealthChecksConfig, + PyroscopeConfig, +): + """Settings for a fastmcp bootstrap.""" + + class InstrumentsSetupperSettings( # type: ignore[misc] BaseServiceSettings, LoggingConfig, diff --git a/pyproject.toml b/pyproject.toml index 65b7dbf..362fd1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ litestar = [ ] granian = ["granian[reload]>=1"] faststream = ["faststream~=0.6.2", "prometheus-client>=0.20"] +fastmcp = ["fastmcp>=2,<4", "prometheus-client>=0.20"] [dependency-groups] dev = [ diff --git a/tests/bootstrappers/test_fastmcp.py b/tests/bootstrappers/test_fastmcp.py new file mode 100644 index 0000000..22c6af3 --- /dev/null +++ b/tests/bootstrappers/test_fastmcp.py @@ -0,0 +1,150 @@ +import typing + +import prometheus_client +from fastmcp import FastMCP +from starlette import status +from starlette.testclient import TestClient + +from microbootstrap.bootstrappers.fastmcp import FastMcpBootstrapper +from microbootstrap.config.fastmcp import FastMcpConfig +from microbootstrap.instruments.health_checks_instrument import HealthChecksConfig +from microbootstrap.instruments.logging_instrument import LoggingConfig +from microbootstrap.instruments.prometheus_instrument import FastMcpPrometheusConfig +from microbootstrap.middlewares.fastmcp import FastMcpLoggingMiddleware +from microbootstrap.settings import FastMcpSettings + + +def test_fastmcp_bootstrap_uses_service_metadata() -> None: + test_settings: typing.Final = FastMcpSettings( + service_name="test-mcp", + service_description="Test MCP service", + service_version="2.0.0", + ) + + application: typing.Final = FastMcpBootstrapper(test_settings).bootstrap() + + assert isinstance(application, FastMCP) + assert application.name == test_settings.service_name + assert application.instructions == test_settings.service_description + assert application.version == test_settings.service_version + + +def test_fastmcp_configure_application_overrides_defaults() -> None: + test_instructions: typing.Final = "Configured instructions" + + application: typing.Final = ( + FastMcpBootstrapper(FastMcpSettings()) + .configure_application(FastMcpConfig(instructions=test_instructions)) + .bootstrap() + ) + + assert application.instructions == test_instructions + + +def test_fastmcp_configure_instrument() -> None: + bootstrapper: typing.Final = FastMcpBootstrapper(FastMcpSettings()).configure_instrument( + LoggingConfig(logging_enabled=False), + ) + + application: typing.Final = bootstrapper.bootstrap() + + assert isinstance(application, FastMCP) + + +def test_fastmcp_logging_adds_mcp_middleware() -> None: + application: typing.Final = FastMcpBootstrapper(FastMcpSettings()).bootstrap() + + assert any(isinstance(middleware, FastMcpLoggingMiddleware) for middleware in application.middleware) + + +def test_fastmcp_logging_middleware_can_be_disabled() -> None: + application: typing.Final = ( + FastMcpBootstrapper(FastMcpSettings()) + .configure_instrument(LoggingConfig(logging_turn_off_middleware=True)) + .bootstrap() + ) + + assert not any(isinstance(middleware, FastMcpLoggingMiddleware) for middleware in application.middleware) + + +def test_fastmcp_http_app_is_configured_through_fastmcp_interface() -> None: + application: typing.Final = FastMcpBootstrapper(FastMcpSettings()).bootstrap() + + http_application: typing.Final = application.http_app(path="/api/mcp/", transport="http") + + assert any(getattr(route, "path", None) == "/api/mcp/" for route in http_application.routes) + + +def test_fastmcp_health_checks() -> None: + test_health_path: typing.Final = "/test-health/" + application: typing.Final = ( + FastMcpBootstrapper(FastMcpSettings()) + .configure_instrument(HealthChecksConfig(health_checks_path=test_health_path)) + .bootstrap() + ) + + response: typing.Final = TestClient(application.http_app()).get(test_health_path) + + assert response.status_code == status.HTTP_200_OK + assert response.json()["health_status"] is True + + +def test_fastmcp_health_checks_route_can_be_disabled_with_existing_enabled_flag() -> None: + test_health_path: typing.Final = "/test-health/" + application: typing.Final = ( + FastMcpBootstrapper(FastMcpSettings()) + .configure_instrument( + HealthChecksConfig( + health_checks_path=test_health_path, + health_checks_enabled=False, + ), + ) + .bootstrap() + ) + + response: typing.Final = TestClient(application.http_app()).get(test_health_path) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +def test_fastmcp_prometheus() -> None: + test_metrics_path: typing.Final = "/test-metrics" + metrics_registry: typing.Final = prometheus_client.CollectorRegistry() + prometheus_client.Counter( + "fastmcp_test_requests_total", + "FastMCP test requests.", + registry=metrics_registry, + ).inc() + application: typing.Final = ( + FastMcpBootstrapper(FastMcpSettings()) + .configure_instrument( + FastMcpPrometheusConfig( + prometheus_metrics_path=test_metrics_path, + prometheus_registry=metrics_registry, + ), + ) + .bootstrap() + ) + + response: typing.Final = TestClient(application.http_app()).get(test_metrics_path) + + assert response.status_code == status.HTTP_200_OK + assert b"fastmcp_test_requests_total 1.0" in response.content + + +def test_fastmcp_prometheus_route_can_be_disabled() -> None: + test_metrics_path: typing.Final = "/test-metrics" + application: typing.Final = ( + FastMcpBootstrapper(FastMcpSettings()) + .configure_instrument( + FastMcpPrometheusConfig( + prometheus_metrics_path=test_metrics_path, + prometheus_register_route=False, + ), + ) + .bootstrap() + ) + + response: typing.Final = TestClient(application.http_app()).get(test_metrics_path) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/middlewares/__init__.py b/tests/middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/middlewares/test_fastmcp.py b/tests/middlewares/test_fastmcp.py new file mode 100644 index 0000000..7cca9d1 --- /dev/null +++ b/tests/middlewares/test_fastmcp.py @@ -0,0 +1,52 @@ +import typing +from unittest.mock import MagicMock + +import pytest +from fastmcp.server.middleware import MiddlewareContext + +from microbootstrap.middlewares.fastmcp import FastMcpLoggingMiddleware + + +async def test_fastmcp_logging_middleware_logs_success(monkeypatch: pytest.MonkeyPatch) -> None: + fake_logger: typing.Final = MagicMock() + middleware: typing.Final = FastMcpLoggingMiddleware() + middleware_context: typing.Final = MiddlewareContext( + message={"payload": "test"}, + method="tools/list", + source="client", + type="request", + ) + + async def call_next(context: MiddlewareContext[typing.Any]) -> dict[str, str]: + assert context is middleware_context + return {"status": "ok"} + + monkeypatch.setattr("microbootstrap.middlewares.fastmcp.fastmcp_access_logger", fake_logger) + + result: typing.Final = await middleware.on_message(middleware_context, call_next) + + assert result == {"status": "ok"} + fake_logger.info.assert_called_once() + + +async def test_fastmcp_logging_middleware_logs_exception(monkeypatch: pytest.MonkeyPatch) -> None: + fake_logger: typing.Final = MagicMock() + middleware: typing.Final = FastMcpLoggingMiddleware() + middleware_context: typing.Final = MiddlewareContext( + message={"payload": "test"}, + method="tools/call", + source="client", + type="request", + ) + + async def call_next(context: MiddlewareContext[typing.Any]) -> dict[str, str]: + assert context is middleware_context + msg = "MCP call failed" + raise RuntimeError(msg) + + monkeypatch.setattr("microbootstrap.middlewares.fastmcp.fastmcp_access_logger", fake_logger) + + with pytest.raises(RuntimeError, match="MCP call failed"): + await middleware.on_message(middleware_context, call_next) + + fake_logger.exception.assert_called_once()