Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/adcp/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ async def get_products(params, context=None):
from __future__ import annotations

from adcp.capabilities import validate_capabilities
from adcp.server._hooks import (
PreValidationHook,
PreValidationHookChain,
PreValidationHookError,
PreValidationHooks,
compose_pre_validation_hooks,
)
from adcp.server.a2a_server import (
ADCPAgentExecutor,
MessageParser,
Expand Down Expand Up @@ -153,10 +160,6 @@ async def get_products(params, context=None):
)
from adcp.server.spec_compat import (
CANONICAL_CREATIVE_AGENT_URL,
PreValidationHook,
PreValidationHookChain,
PreValidationHooks,
compose_pre_validation_hooks,
spec_compat_hooks,
)
from adcp.server.sponsored_intelligence import SponsoredIntelligenceHandler
Expand Down Expand Up @@ -273,6 +276,7 @@ async def get_products(params, context=None):
"CANONICAL_CREATIVE_AGENT_URL",
"PreValidationHook",
"PreValidationHookChain",
"PreValidationHookError",
"PreValidationHooks",
"compose_pre_validation_hooks",
"spec_compat_hooks",
Expand Down
121 changes: 121 additions & 0 deletions src/adcp/server/_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Shared pre-validation hook types, composition, and application.

Single source of truth for the per-tool pre-validation hook machinery used by
both the MCP and A2A transports. :mod:`adcp.server.spec_compat` re-exports the
public names (``PreValidationHook``, ``PreValidationHookChain``,
``PreValidationHooks``, ``compose_pre_validation_hooks``) for backward
compatibility, so adopters can keep importing them from either module.

This module has no transport or framework dependencies — it depends only on the
standard library — so it sits at the bottom of the server import layering and
every other ``adcp.server`` module may import from it freely.
"""

from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from typing import Any, TypeAlias

PreValidationHook: TypeAlias = Callable[[str, dict[str, Any]], dict[str, Any]]
"""Callable shape for a pre-validation hook."""

PreValidationHookChain: TypeAlias = PreValidationHook | Sequence[PreValidationHook]
"""One hook or an ordered sequence of hooks for a single tool."""

PreValidationHooks: TypeAlias = dict[str, PreValidationHookChain]
"""Type alias for the ``pre_validation_hooks`` parameter of ``serve()``."""

__all__ = [
"PreValidationHook",
"PreValidationHookChain",
"PreValidationHookError",
"PreValidationHooks",
"compose_pre_validation_hooks",
]


class PreValidationHookError(Exception):
"""A single hook in an ordered pre-validation chain failed.

Carries the zero-based ``index`` of the failing hook within its chain and
the resolved ``hook_name`` so the dispatcher can surface both in the
``INVALID_REQUEST`` message — naming the exact callable instead of a
generic "pre_validation_hook raised ...".
"""

def __init__(self, *, index: int, hook_name: str, message: str) -> None:
self.index = index
self.hook_name = hook_name
super().__init__(f"pre_validation_hook[{index}] {hook_name} {message}")


def _hook_name(hook: PreValidationHook) -> str:
"""Best-effort human name for a hook (function ``__name__`` or class name)."""
name = getattr(hook, "__name__", None)
if isinstance(name, str) and name:
return name
return hook.__class__.__name__


def _flatten_pre_validation_hooks(
hooks: PreValidationHookChain | None,
) -> tuple[PreValidationHook, ...]:
"""Normalize ``None`` / one hook / a sequence into a flat hook tuple."""
if hooks is None:
return ()
if callable(hooks):
return (hooks,)
flattened = tuple(hooks)
for hook in flattened:
if not callable(hook):
raise TypeError("pre-validation hook chains must contain callables")
return flattened


def _apply_pre_validation_hooks(
hooks: tuple[PreValidationHook, ...],
tool_name: str,
params: dict[str, Any],
) -> dict[str, Any]:
"""Run an ordered hook chain, threading each hook's output into the next.

Each hook receives a shallow copy of the running params dict. A hook that
raises, or returns a non-dict, is wrapped in :class:`PreValidationHookError`
naming its chain index and callable.
"""
next_params = params
for index, hook in enumerate(hooks):
hook_name = _hook_name(hook)
try:
next_params = hook(tool_name, dict(next_params))
except Exception as exc:
raise PreValidationHookError(
index=index,
hook_name=hook_name,
message=f"raised {type(exc).__name__}: {exc}",
) from exc
if not isinstance(next_params, dict):
raise PreValidationHookError(
index=index,
hook_name=hook_name,
message=f"returned {type(next_params).__name__}, expected dict",
)
return next_params


def compose_pre_validation_hooks(
*hook_maps: Mapping[str, PreValidationHookChain] | None,
) -> dict[str, tuple[PreValidationHook, ...]]:
"""Compose ordered pre-validation hook maps.

Later maps append to earlier maps for overlapping tool names. Each
tool's hooks run left-to-right, feeding the returned args from one hook
into the next.
"""
composed: dict[str, list[PreValidationHook]] = {}
for hook_map in hook_maps:
if hook_map is None:
continue
for tool_name, chain in hook_map.items():
composed.setdefault(tool_name, []).extend(_flatten_pre_validation_hooks(chain))
return {tool_name: tuple(hooks) for tool_name, hooks in composed.items()}
2 changes: 1 addition & 1 deletion src/adcp/server/a2a_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
from starlette.applications import Starlette

from adcp.exceptions import ADCPError
from adcp.server._hooks import PreValidationHooks
from adcp.server.base import ADCPHandler, ToolContext
from adcp.server.helpers import ResponseEnhancer, _apply_response_enhancer
from adcp.server.spec_compat import PreValidationHooks

# Decisioning-layer ``AdcpError`` (from ``adcp.decisioning.types``) is the
# wire-shaped structured error platform methods raise. It is NOT a subclass
Expand Down
43 changes: 11 additions & 32 deletions src/adcp/server/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@
from collections.abc import Callable, Iterable
from typing import Any

from adcp.server._hooks import (
PreValidationHookChain,
PreValidationHookError,
PreValidationHooks,
_apply_pre_validation_hooks,
_flatten_pre_validation_hooks,
)
from adcp.server.base import ADCPHandler, ToolContext
from adcp.server.helpers import ResponseEnhancer, _apply_response_enhancer
from adcp.server.spec_compat import PreValidationHook, PreValidationHookChain
from adcp.server.test_controller import SCENARIOS as _CONTROLLER_SCENARIOS
from adcp.types import (
MEDIA_BUY_LEGACY_STATUS_VALUES,
Expand Down Expand Up @@ -2097,33 +2103,6 @@ class when the annotation is:
return None


def _flatten_pre_validation_hooks(
hooks: PreValidationHookChain | None,
) -> tuple[PreValidationHook, ...]:
if hooks is None:
return ()
if callable(hooks):
return (hooks,)
flattened = tuple(hooks)
for hook in flattened:
if not callable(hook):
raise TypeError("pre-validation hook chains must contain callables")
return flattened


def _apply_pre_validation_hooks(
hooks: tuple[PreValidationHook, ...],
method_name: str,
params: dict[str, Any],
) -> dict[str, Any]:
next_params = params
for hook in hooks:
next_params = hook(method_name, dict(next_params))
if not isinstance(next_params, dict):
raise TypeError("pre-validation hooks must return dict arguments")
return next_params


def _normalize_unknown_field_policy(
policy: UnknownFieldPolicy | str | None,
) -> UnknownFieldPolicy:
Expand Down Expand Up @@ -2336,13 +2315,13 @@ async def call_tool(params: dict[str, Any], context: ToolContext | None = None)
params = _apply_pre_validation_hooks(
pre_validation_hooks, method_name, dict(params)
)
except Exception as exc:
except PreValidationHookError as exc:
raise ADCPTaskError(
operation=method_name,
errors=[
Error(
code="INVALID_REQUEST",
message=f"pre_validation_hook raised {type(exc).__name__}: {exc}",
message=str(exc),
)
],
) from exc
Expand Down Expand Up @@ -2707,7 +2686,7 @@ def __init__(
*,
advertise_all: bool = False,
validation: ValidationHookConfig | None = None,
pre_validation_hooks: dict[str, PreValidationHookChain] | None = None,
pre_validation_hooks: PreValidationHooks | None = None,
response_enhancer: ResponseEnhancer | None = None,
):
"""Create tool set from handler.
Expand Down Expand Up @@ -2777,7 +2756,7 @@ def create_mcp_tools(
*,
advertise_all: bool = False,
validation: ValidationHookConfig | None = None,
pre_validation_hooks: dict[str, PreValidationHookChain] | None = None,
pre_validation_hooks: PreValidationHooks | None = None,
response_enhancer: ResponseEnhancer | None = None,
) -> MCPToolSet:
"""Create MCP tools from an ADCP handler.
Expand Down
2 changes: 1 addition & 1 deletion src/adcp/server/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def get_adcp_capabilities(self, params, context=None):

logger = logging.getLogger("adcp.server")

from adcp.server._hooks import PreValidationHooks
from adcp.server.base import ADCPHandler, ToolContext
from adcp.server.helpers import ResponseEnhancer
from adcp.server.mcp_sessions import ADCPStreamableHTTPSessionManager
Expand All @@ -36,7 +37,6 @@ async def get_adcp_capabilities(self, params, context=None):
create_tool_caller,
get_tools_for_handler,
)
from adcp.server.spec_compat import PreValidationHooks
from adcp.validation.client_hooks import (
SERVER_DEFAULT_VALIDATION as DEFAULT_VALIDATION,
)
Expand Down
62 changes: 21 additions & 41 deletions src/adcp/server/spec_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,26 @@
from __future__ import annotations

import warnings
from collections.abc import Callable, Collection, Mapping, Sequence
from typing import Any, TypeAlias

PreValidationHook: TypeAlias = Callable[[str, dict[str, Any]], dict[str, Any]]
"""Callable shape for a pre-validation hook."""

PreValidationHookChain: TypeAlias = PreValidationHook | Sequence[PreValidationHook]
"""One hook or an ordered sequence of hooks for a single tool."""
from collections.abc import Callable, Collection
from typing import Any

from adcp.server._hooks import (
PreValidationHook,
PreValidationHookChain,
PreValidationHookError,
PreValidationHooks,
compose_pre_validation_hooks,
)

PreValidationHooks: TypeAlias = dict[str, PreValidationHookChain]
"""Type alias for the ``pre_validation_hooks`` parameter of ``serve()``."""
__all__ = [
"CANONICAL_CREATIVE_AGENT_URL",
"PreValidationHook",
"PreValidationHookChain",
"PreValidationHookError",
"PreValidationHooks",
"compose_pre_validation_hooks",
"spec_compat_hooks",
]

CANONICAL_CREATIVE_AGENT_URL = "https://creative.adcontextprotocol.org"
"""Canonical ``agent_url`` for the AdCP standard creative-format registry.
Expand Down Expand Up @@ -80,35 +89,6 @@
)


def _flatten_hook_chain(chain: PreValidationHookChain) -> tuple[PreValidationHook, ...]:
if callable(chain):
return (chain,)
hooks = tuple(chain)
for hook in hooks:
if not callable(hook):
raise TypeError("pre-validation hook chains must contain callables")
return hooks


def compose_pre_validation_hooks(
*hook_maps: Mapping[str, PreValidationHookChain] | None,
) -> dict[str, tuple[PreValidationHook, ...]]:
"""Compose ordered pre-validation hook maps.

Later maps append to earlier maps for overlapping tool names. Each
tool's hooks run left-to-right, feeding the returned args from one hook
into the next.
"""

composed: dict[str, list[PreValidationHook]] = {}
for hook_map in hook_maps:
if hook_map is None:
continue
for tool_name, chain in hook_map.items():
composed.setdefault(tool_name, []).extend(_flatten_hook_chain(chain))
return {tool_name: tuple(hooks) for tool_name, hooks in composed.items()}


def _hook_get_products(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: # noqa: ARG001
"""Default ``buying_mode`` to ``'brief'`` when omitted.

Expand Down Expand Up @@ -297,8 +277,8 @@ def _spec_compat_hooks_impl(
Adopters who need granular control over the three sub-behaviors
should copy the relevant logic from
``adcp.server.spec_compat._coerce_asset`` / ``_hook_get_products``
rather than trying to layer hooks — ``pre_validation_hooks`` allows
only one callable per tool name.
or compose an ordered hook chain with
:func:`adcp.server.compose_pre_validation_hooks`.

Args:
exclude: Tool names to exclude from the returned dict. Names not in
Expand Down
Loading
Loading