diff --git a/.gitignore b/.gitignore index 245f859..f71a18f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .env .claude +.kiro .sessions .converter diff --git a/README.md b/README.md index 3a4e528..1f16de2 100644 --- a/README.md +++ b/README.md @@ -387,7 +387,7 @@ async def main(): asyncio.run(main()) ``` -Event types: `AGENT_START`, `TOKEN`, `REASONING`, `TOOL_START`, `TOOL_END`, `INTERRUPT`, `NODE_START`, `NODE_STOP`, `HANDOFF`, `COMPLETE`, `MULTIAGENT_START`, `MULTIAGENT_COMPLETE`, `ERROR` — each carrying `{type, agent_name, timestamp, data}`. Enough for a real-time frontend, a log aggregator, or a debugging dashboard. The `AnsiRenderer` gives you coloured terminal output out of the box — agent names, tool calls, reasoning traces, all streaming live. +Event types: `AGENT_START`, `TOKEN`, `REASONING`, `TOOL_START`, `TOOL_END`, `INTERRUPT`, `NODE_START`, `NODE_STOP`, `HANDOFF`, `COMPLETE`, `MULTIAGENT_START`, `MULTIAGENT_COMPLETE`, `ERROR`, `SESSION_START`, `SESSION_END` — each carrying `{type, agent_name, timestamp, data}`. Enough for a real-time frontend, a log aggregator, or a debugging dashboard. The `AnsiRenderer` gives you coloured terminal output out of the box — agent names, tool calls, reasoning traces, all streaming live. --- diff --git a/docs/configuration/Chapter_15.md b/docs/configuration/Chapter_15.md index baffeef..4dac43c 100644 --- a/docs/configuration/Chapter_15.md +++ b/docs/configuration/Chapter_15.md @@ -32,6 +32,21 @@ asyncio.run(main()) ## Event Types +Every event is a `StreamEvent` dataclass with four fields: `type`, `agent_name`, `timestamp`, and `data`. + +### Session lifecycle events + +These two events bracket every invocation. They are produced by the queue layer, not by individual agents. + +| Event Type | Description | `data` payload | +|------------|-------------|----------------| +| `SESSION_START` | First event on the queue — emitted before any agent activity | Serialised `SessionManifest`: agents, orchestrations, entry point, model info, session manager locations | +| `SESSION_END` | Last typed event before the stream closes | `{"session_id": ""}` | + +The `SESSION_START` payload is the full wired topology at invocation time. Use it to restore conversation history, render an architecture diagram, or audit which models are in use — before any agent has run. + +### Per-agent events + | Event Type | Description | |------------|-------------| | `AGENT_START` | Agent begins processing | @@ -42,6 +57,11 @@ asyncio.run(main()) | `INTERRUPT` | Agent pauses for human input | | `COMPLETE` | Agent finishes (with usage metrics) | | `ERROR` | Model or execution error | + +### Multi-agent events + +| Event Type | Description | +|------------|-------------| | `NODE_START` | Graph/swarm node begins | | `NODE_STOP` | Graph/swarm node completes | | `HANDOFF` | Swarm agent hands off to another | @@ -62,15 +82,30 @@ while (event := await queue.get()) is not None: # Send to websocket, log to file, push to metrics system... ``` +A typical consumer pattern that handles the session lifecycle: + +```python +while (event := await queue.get()) is not None: + if event.type == "session_start": + manifest = event.data # full topology snapshot + entry = manifest["entry"] # {"name": "...", "kind": "agent|orchestration"} + elif event.type == "session_end": + session_id = event.data.get("session_id") + else: + # per-agent or multi-agent event + process(event) +``` + ## Configuring the Queue in YAML Event streaming is configured in Python, not YAML — it's a runtime concern. But the **hooks** it installs (`EventPublisher`) listen to the same lifecycle events as your YAML-defined hooks. They coexist peacefully. > **Tips & Tricks** > -> - Call `wire_event_queue()` only **once** per `ResolvedConfig` — it mutates the agents by adding hooks. -> - Call `queue.flush()` between requests to clear stale events from a previous invocation. +> - Call `wire_event_queue()` only **once** per `ResolvedConfig` — it mutates agents and orchestrators by adding hooks. Calling it twice would double-attach publishers. +> - Call `queue.flush()` between requests to clear stale events from a previous invocation. This also resets the `SESSION_START` / `SESSION_END` guards so the next cycle can re-emit them. > - The queue has a max size of 10,000. If your agent generates more events than the consumer processes, events are dropped with a warning. +> - `SESSION_START` is emitted synchronously by `wire_event_queue()` before any agent runs. `SESSION_END` is emitted by `queue.close()` — always call it in a `finally` block. --- diff --git a/examples/12_streaming/README.md b/examples/12_streaming/README.md index ab77bf6..81a33af 100644 --- a/examples/12_streaming/README.md +++ b/examples/12_streaming/README.md @@ -4,7 +4,7 @@ ## What this shows -- `wire_event_queue()` — wire all agents to a single async queue that emits `StreamEvent`s +- `wire_event_queue()` — wire all agents and orchestrators to a single async queue that emits `StreamEvent`s - `AnsiRenderer` — built-in terminal renderer that prints events with colours as they arrive - How strands-compose turns agent lifecycle events into a consumable stream — the same mechanism that powers SSE endpoints, WebSocket feeds, and audit logs @@ -18,9 +18,9 @@ resolved = load("config.yaml") queue = resolved.wire_event_queue() ``` -`resolved.wire_event_queue()` installs an `EventPublisher` hook on every agent. As the agent runs, -the hook converts lifecycle events (tokens, tool calls, completions) into `StreamEvent` -objects and pushes them to the queue. Your consumer loop is simple: +`resolved.wire_event_queue()` installs an `EventPublisher` hook on every agent and orchestrator. +As the session runs, hooks convert lifecycle events (tokens, tool calls, completions) into +`StreamEvent` objects and push them to the queue. Your consumer loop is simple: ```python renderer = AnsiRenderer() @@ -31,17 +31,25 @@ renderer.flush() ### Event types -| Type | When it fires | -|------|---------------| -| `agent_start` | Agent begins processing | -| `token` | Streaming text chunk | -| `reasoning` | Streaming reasoning chunk | -| `tool_start` | Tool call begins | -| `tool_end` | Tool call finished | -| `interrupt` | Agent pauses for human input | -| `complete` | Agent finished (includes token usage) | -| `node_start` / `node_stop` | Swarm / Graph enters/leaves a node | -| `handoff` | Swarm transfers control | +Every invocation produces a `SESSION_START` as the first event and `SESSION_END` as the last, +bracketing all per-agent activity. + +| Type | When it fires | `data` | +|------|---------------|--------| +| `session_start` | Before any agent runs — first event on the queue | Serialised `SessionManifest` (agents, orchestrations, entry, model info) | +| `agent_start` | Agent begins processing | — | +| `token` | Streaming text chunk | `{"text": "..."}` | +| `reasoning` | Streaming reasoning chunk | `{"text": "..."}` | +| `tool_start` | Tool call begins | tool name, input | +| `tool_end` | Tool call finished | tool name, status, result | +| `interrupt` | Agent pauses for human input | interrupt id, reason | +| `complete` | Agent finished (includes token usage) | usage metrics | +| `error` | Model or execution error | exception type, message | +| `node_start` / `node_stop` | Swarm / Graph enters/leaves a node | node id | +| `handoff` | Swarm transfers control | from/to node ids | +| `multiagent_start` | Multi-agent orchestration begins | — | +| `multiagent_complete` | Multi-agent orchestration completes | — | +| `session_end` | After all agent events — last typed event | `{"session_id": ""}` | ## Good to know @@ -53,7 +61,10 @@ consume the queue and convert events to SSE chunks (see `OpenAIStreamConverter`) NDJSON (`RawStreamConverter`). **`queue.flush()`** resets the queue between turns so events from one invocation -don't leak into the next. +don't leak into the next. It also resets the `session_start` / `session_end` guards. + +**`queue.close()`** emits `session_end` then signals end-of-stream. Always call it in +a `finally` block so `session_end` is guaranteed even when an exception occurs. ## Prerequisites diff --git a/examples/12_streaming/main.py b/examples/12_streaming/main.py index a64bae5..ab1ce6c 100644 --- a/examples/12_streaming/main.py +++ b/examples/12_streaming/main.py @@ -21,7 +21,6 @@ async def _stream(prompt: str, entry, queue): """Invoke the entry agent and render the event stream.""" - queue.flush() result = None async def _invoke() -> None: @@ -46,7 +45,6 @@ async def _main() -> None: resolved = load(CONFIG) entry = resolved.entry queue = resolved.wire_event_queue() - print(f"\n{52 * '-'}") print(f"Try: {STARTER}\n") print("researcher -> analyst -> coordinator (with live streaming)") diff --git a/src/strands_compose/config/resolvers/config.py b/src/strands_compose/config/resolvers/config.py index 4b6248d..cf74693 100644 --- a/src/strands_compose/config/resolvers/config.py +++ b/src/strands_compose/config/resolvers/config.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from ...manifest import build_manifest, first_session_id from ...mcp.lifecycle import MCPLifecycle from ...wire import make_event_queue from .mcp import resolve_mcp_client, resolve_mcp_server @@ -47,9 +48,19 @@ def wire_event_queue( ) -> EventQueue: """Wire all agents and orchestrators for event streaming. - This is the recommended way to set up event streaming. It calls - :func:`~strands_compose.wire.make_event_queue` with this - config's agents and orchestrators. + This is the recommended way to set up event streaming. It: + + 1. Builds a :class:`~strands_compose.types.SessionManifest` from the + resolved runtime objects. + 2. Wires every agent (and orchestrator) with an + :class:`~strands_compose.hooks.EventPublisher` via + :func:`~strands_compose.wire.make_event_queue`. + 3. Emits a SESSION_START event carrying the manifest as the first + event on the queue. + + The effective session id is the first non-``None`` ``session_id`` + found in the manifest (agents first, then orchestrations); it is + included in the SESSION_END event payload. .. warning:: @@ -58,16 +69,25 @@ def wire_event_queue( Call it only once per ``ResolvedConfig`` instance. Args: - tool_labels: Optional tool name -> display label mapping. + tool_labels: Optional tool name → display label mapping. Returns: - A ready-to-use :class:`~strands_compose.wire.EventQueue`. + A ready-to-use :class:`~strands_compose.wire.EventQueue` with + SESSION_START already on it. + + Raises: + ValueError: If the entry node cannot be resolved by object identity. """ - return make_event_queue( + manifest = build_manifest(self.agents, self.orchestrators, self.entry) + event_queue = make_event_queue( self.agents, orchestrators=self.orchestrators, tool_labels=tool_labels, + entry_name=manifest.entry.name, + session_id=first_session_id(manifest), ) + event_queue.emit_session_start(manifest) + return event_queue @dataclass diff --git a/src/strands_compose/manifest.py b/src/strands_compose/manifest.py new file mode 100644 index 0000000..49aebcd --- /dev/null +++ b/src/strands_compose/manifest.py @@ -0,0 +1,310 @@ +"""Build session manifests from resolved runtime objects. + +A :class:`~strands_compose.types.SessionManifest` describes the wired session +topology, model/provider info, and storage locations. It is constructed from +runtime ``strands.Agent``, ``Swarm``, ``Graph``, and ``SessionManager`` +instances at invocation time and serialised into the SESSION_START event +payload. + +Both public functions in this module are pure: no I/O, no network calls, no +mutation of inputs. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from strands import Agent +from strands.multiagent import Swarm +from strands.multiagent.graph import Graph +from strands.session import FileSessionManager, S3SessionManager + +from .types import ( + AgentCoreProviderDescriptor, + AgentDescriptor, + CustomProviderDescriptor, + EdgeRef, + EntryDescriptor, + FileProviderDescriptor, + ModelDescriptor, + NodeRef, + OrchestrationDescriptor, + S3ProviderDescriptor, + SessionManagerDescriptor, + SessionManifest, +) + +if TYPE_CHECKING: + from strands.session import SessionManager + + from .types import Node + + +logger = logging.getLogger(__name__) + + +# ── Session manager descriptors ────────────────────────────────────────────── + + +def build_session_manager_descriptor(manager: SessionManager) -> SessionManagerDescriptor: + """Build a :class:`SessionManagerDescriptor` from a SessionManager instance. + + Selection order: + + 1. :class:`strands.session.FileSessionManager` → ``FileProviderDescriptor`` + 2. :class:`strands.session.S3SessionManager` → ``S3ProviderDescriptor`` + 3. Duck-typed AgentCore (``manager.config`` carries ``memory_id``, + ``actor_id``, ``session_id``) → ``AgentCoreProviderDescriptor`` + 4. Fallback → ``CustomProviderDescriptor`` + + The function is pure: no I/O, no mutation. It reads only public attributes + on the manager and the class's ``__module__``/``__qualname__`` for the + custom fallback. + + Args: + manager: A strands ``SessionManager`` instance. + + Returns: + A :class:`SessionManagerDescriptor` (one of the four concrete subtypes). + + Raises: + AttributeError: If a required attribute is missing on a ``File`` or + ``S3`` manager (indicates a broken strands install). + """ + if isinstance(manager, FileSessionManager): + return FileProviderDescriptor( + provider="file", + session_id=manager.session_id, + storage_dir=manager.storage_dir, + ) + + if isinstance(manager, S3SessionManager): + return S3ProviderDescriptor( + provider="s3", + session_id=manager.session_id, + bucket=manager.bucket, + prefix=manager.prefix, + ) + + config = getattr(manager, "config", None) + if config is not None and all( + hasattr(config, attr) for attr in ("memory_id", "actor_id", "session_id") + ): + return AgentCoreProviderDescriptor( + provider="agentcore", + session_id=str(config.session_id), + memory_id=str(config.memory_id), + actor_id=str(config.actor_id), + ) + + return CustomProviderDescriptor( + provider="custom", + session_id=getattr(manager, "session_id", None), + class_name=f"{type(manager).__module__}.{type(manager).__qualname__}", + ) + + +def _descriptor_or_none( + manager: SessionManager | None, +) -> SessionManagerDescriptor | None: + """Return the descriptor for *manager*, or ``None`` if no manager is set.""" + if manager is None: + return None + return build_session_manager_descriptor(manager) + + +# ── Agent descriptor ───────────────────────────────────────────────────────── + + +def _model_descriptor(agent: Agent) -> ModelDescriptor: + """Extract a :class:`ModelDescriptor` from an agent's model.""" + config = agent.model.get_config() + if isinstance(config, dict): + model_id = config.get("model_id") + else: + model_id = getattr(config, "model_id", None) + return ModelDescriptor( + model_id=model_id, + provider=f"{type(agent.model).__module__}.{type(agent.model).__qualname__}", + ) + + +def _agent_descriptor(name: str, agent: Agent) -> AgentDescriptor: + """Build an :class:`AgentDescriptor` for a single resolved agent.""" + return AgentDescriptor( + name=name, + description=agent.description, + model=_model_descriptor(agent), + session_manager=_descriptor_or_none(agent._session_manager), + ) + + +# ── Orchestration descriptor ───────────────────────────────────────────────── + + +def _swarm_topology(swarm: Swarm) -> tuple[list[NodeRef], None, str | None]: + """Extract (nodes, edges, entry_node_id) from a Swarm. + + Swarm handoffs are dynamic, so ``edges`` is always ``None``. + """ + nodes = [NodeRef(id=node.node_id, kind="agent") for node in swarm.nodes.values()] + + entry_id: str | None = None + if swarm.entry_point is not None: + for node in swarm.nodes.values(): + if node.executor is swarm.entry_point: + entry_id = node.node_id + break + elif swarm.nodes: + entry_id = next(iter(swarm.nodes.values())).node_id + + return nodes, None, entry_id + + +def _graph_topology(graph: Graph) -> tuple[list[NodeRef], list[EdgeRef], str | None]: + """Extract (nodes, edges, entry_node_id) from a Graph.""" + nodes = [ + NodeRef( + id=node.node_id, + kind="agent" if isinstance(node.executor, Agent) else "orchestration", + ) + for node in graph.nodes.values() + ] + edges = [ + EdgeRef(from_id=edge.from_node.node_id, to_id=edge.to_node.node_id) for edge in graph.edges + ] + + entry_points = list(graph.entry_points) + if len(entry_points) == 1: + entry_id: str | None = entry_points[0].node_id + elif len(entry_points) > 1: + entry_id = ",".join(node.node_id for node in entry_points) + else: + entry_id = None + + return nodes, edges, entry_id + + +def _orchestration_descriptor(name: str, orch: Node) -> OrchestrationDescriptor: + """Build an :class:`OrchestrationDescriptor` for a single orchestration. + + Dispatches by runtime type: + + * :class:`strands.Agent` (delegate) → ``kind="delegate"``, empty topology + * :class:`strands.multiagent.Swarm` → ``kind="swarm"``, swarm topology + * :class:`strands.multiagent.graph.Graph` → ``kind="graph"``, graph topology + * any other type → ``kind="unknown"``, empty topology + """ + if isinstance(orch, Agent): + return OrchestrationDescriptor( + name=name, + kind="delegate", + session_manager=_descriptor_or_none(orch._session_manager), + ) + + if isinstance(orch, Swarm): + nodes, edges, entry_id = _swarm_topology(orch) + return OrchestrationDescriptor( + name=name, + kind="swarm", + session_manager=_descriptor_or_none(getattr(orch, "session_manager", None)), + nodes=nodes, + edges=edges, + entry_node_id=entry_id, + ) + + if isinstance(orch, Graph): + nodes, edges, entry_id = _graph_topology(orch) + return OrchestrationDescriptor( + name=name, + kind="graph", + session_manager=_descriptor_or_none(getattr(orch, "session_manager", None)), + nodes=nodes, + edges=edges, + entry_node_id=entry_id, + ) + + return OrchestrationDescriptor( + name=name, + kind="unknown", + session_manager=_descriptor_or_none(getattr(orch, "session_manager", None)), + ) + + +# ── Entry resolution ───────────────────────────────────────────────────────── + + +def _resolve_entry( + entry: Node, + agents: dict[str, Agent], + orchestrators: dict[str, Node], +) -> EntryDescriptor: + """Reverse-lookup *entry* in *agents* then *orchestrators* by object identity. + + Raises: + ValueError: If *entry* is not found in either dict. + """ + for name, agent in agents.items(): + if agent is entry: + return EntryDescriptor(name=name, kind="agent") + for name, orch in orchestrators.items(): + if orch is entry: + return EntryDescriptor(name=name, kind="orchestration") + raise ValueError( + f"entry_type=<{type(entry).__name__}> | entry node not found in agents or orchestrators" + ) + + +# ── Public API ─────────────────────────────────────────────────────────────── + + +def build_manifest( + agents: dict[str, Agent], + orchestrators: dict[str, Node], + entry: Node, +) -> SessionManifest: + """Build a :class:`SessionManifest` from resolved runtime objects. + + Pure function: no I/O, no network calls, no mutation of inputs. + + Args: + agents: Resolved agents keyed by name. + orchestrators: Resolved orchestrations keyed by name. + entry: The entry point node (must be one of the values in *agents* or + *orchestrators* by object identity). + + Returns: + A complete :class:`SessionManifest` with all fields populated. + + Raises: + ValueError: If *entry* cannot be resolved by object identity. + """ + return SessionManifest( + agents=[_agent_descriptor(name, agent) for name, agent in agents.items()], + orchestrations=[ + _orchestration_descriptor(name, orch) for name, orch in orchestrators.items() + ], + entry=_resolve_entry(entry, agents, orchestrators), + ) + + +def first_session_id(manifest: SessionManifest) -> str | None: + """Return the first non-None ``session_id`` found in the manifest. + + Iterates ``manifest.agents`` first, then ``manifest.orchestrations``, + returning the first ``session_manager.session_id`` that is set. Used by + :meth:`ResolvedConfig.wire_event_queue` to determine the effective session + id for the SESSION_END event payload. + + Args: + manifest: The session manifest. + + Returns: + The first session id found, or ``None`` when no descriptor in the + manifest has a session manager set. + """ + for descriptor in (*manifest.agents, *manifest.orchestrations): + if descriptor.session_manager is not None: + return descriptor.session_manager.session_id + return None diff --git a/src/strands_compose/renderers/ansi.py b/src/strands_compose/renderers/ansi.py index 4959779..2eaf340 100644 --- a/src/strands_compose/renderers/ansi.py +++ b/src/strands_compose/renderers/ansi.py @@ -26,7 +26,7 @@ from collections.abc import Callable from typing import Any -from ..types import EventType +from ..types import EventType, SessionManifest from ..wire import StreamEvent from .base import EventRenderer @@ -82,6 +82,8 @@ def __init__( self._separator_width = shutil.get_terminal_size((70, 24)).columns self._handlers: dict[str, Callable[[StreamEvent], None]] = { + EventType.SESSION_START: self._handle_session_start, + EventType.SESSION_END: self._handle_session_end, EventType.TOKEN: self._handle_token, EventType.AGENT_START: self._handle_agent_start, EventType.TOOL_START: self._handle_tool_start, @@ -111,6 +113,29 @@ def flush(self) -> None: # noqa: D102 # -- Per-event-type handlers ------------------------------------------- + def _handle_session_start(self, event: StreamEvent) -> None: + self._break() + manifest = SessionManifest.model_validate(event.data) + agent_names = ", ".join(a.name for a in manifest.agents) or "—" + orch_names = ", ".join(o.name for o in manifest.orchestrations) or "—" + self._out.write(self._separator(manifest.entry.name, "SESSION START", color=self._cyan)) + self._out.write( + f" {self._dim}entry: {manifest.entry.name} ({manifest.entry.kind})\n" + f" agents: {agent_names}\n" + ) + if manifest.orchestrations: + self._out.write(f" orchestrations: {orch_names}\n") + self._out.write(self._reset) + self._out.flush() + + def _handle_session_end(self, event: StreamEvent) -> None: + self._break() + session_id = event.data.get("session_id") + sid_str = session_id if session_id else "—" + self._out.write(self._separator(event.agent_name, "SESSION END", color=self._cyan)) + self._out.write(f" {self._dim}session_id: {sid_str}{self._reset}\n") + self._out.flush() + def _handle_token(self, event: StreamEvent) -> None: self._ensure_mode(event.agent_name, "responding") text = event.data.get("text", "") diff --git a/src/strands_compose/types.py b/src/strands_compose/types.py index 504babb..c179ac7 100644 --- a/src/strands_compose/types.py +++ b/src/strands_compose/types.py @@ -1,8 +1,14 @@ """Shared types for the core package. -Centralises the ``Node`` union so it is defined exactly once and -imported everywhere else, rather than being duplicated in -``orchestration`` and ``config/resolvers``. +This module is the single canonical home for cross-package types: + +- ``Node`` — alias for ``Agent | MultiAgentBase``, the kinds of nodes that + participate in a session. +- ``EventType`` and ``StreamEvent`` — the typed-event protocol used by the + event queue and external consumers. +- The ``SessionManifest`` family of Pydantic models — the schema describing a + wired session at invocation time. The manifest schema is intentionally + decoupled from the YAML config schema in ``config.schema``. """ from __future__ import annotations @@ -10,8 +16,9 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import StrEnum -from typing import Any +from typing import Annotated, Any, Literal +from pydantic import BaseModel, Field from strands import Agent from strands.multiagent.base import MultiAgentBase @@ -44,19 +51,27 @@ class EventType(StrEnum): MULTIAGENT_START = "multiagent_start" MULTIAGENT_COMPLETE = "multiagent_complete" + # Session-level events + SESSION_START = "session_start" + SESSION_END = "session_end" + @dataclass class StreamEvent: """A typed event from agent or multi-agent execution. - Produced by :class:`~strands_compose.hooks.EventPublisher` for - all agent activity: ``TOKEN``, ``REASONING``, ``TOOL_START``, ``TOOL_END``, - ``INTERRUPT``, ``COMPLETE``, ``ERROR``, ``NODE_START``, ``NODE_STOP``, - ``HANDOFF``, ``MULTIAGENT_COMPLETE``. + Per-agent activity (``AGENT_START``, ``TOKEN``, ``REASONING``, + ``TOOL_START``, ``TOOL_END``, ``INTERRUPT``, ``COMPLETE``, ``ERROR``, + ``NODE_START``, ``NODE_STOP``, ``HANDOFF``, ``MULTIAGENT_START``, + ``MULTIAGENT_COMPLETE``) is produced by + :class:`~strands_compose.hooks.EventPublisher`. Session-level events + (``SESSION_START``, ``SESSION_END``) are produced by the queue/wiring + layer in :mod:`strands_compose.wire`. Attributes: type: Event type identifier (one of the :class:`EventType` values). - agent_name: Name of the agent that produced this event. + agent_name: Name of the agent or session entry point that produced + this event. timestamp: When the event occurred. data: Event-specific payload. """ @@ -109,3 +124,175 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: """Hash based on type and agent_name (data is unhashable).""" return hash((self.type, self.agent_name)) + + +# ── Session Manifest Models ────────────────────────────────────────────────── + + +class NodeRef(BaseModel): + """Reference to a node in an orchestration topology. + + Attributes: + id: The node identifier (node_id for swarm/graph nodes). + kind: The node kind ("agent" or "orchestration"). + """ + + id: str + kind: str + + +class EdgeRef(BaseModel): + """Reference to a directed edge in a graph orchestration. + + Attributes: + from_id: The source node identifier. + to_id: The target node identifier. + """ + + from_id: str + to_id: str + + +class ModelDescriptor(BaseModel): + """Descriptor for an agent's model and provider. + + Attributes: + model_id: The model identifier (e.g. "us.anthropic.claude-sonnet-4-6"), + or None if not available. + provider: The fully-qualified class name of the model provider. + """ + + model_id: str | None + provider: str + + +class FileProviderDescriptor(BaseModel): + """Session manager descriptor for file-based storage. + + Attributes: + provider: Literal "file". + session_id: The session identifier. + storage_dir: The filesystem directory where sessions are stored. + """ + + provider: Literal["file"] + session_id: str + storage_dir: str + + +class S3ProviderDescriptor(BaseModel): + """Session manager descriptor for S3-based storage. + + Attributes: + provider: Literal "s3". + session_id: The session identifier. + bucket: The S3 bucket name. + prefix: The S3 key prefix (empty string if no prefix). + """ + + provider: Literal["s3"] + session_id: str + bucket: str + prefix: str + + +class AgentCoreProviderDescriptor(BaseModel): + """Session manager descriptor for AgentCore Memory storage. + + Attributes: + provider: Literal "agentcore". + session_id: The session identifier. + memory_id: The AgentCore memory identifier. + actor_id: The AgentCore actor identifier. + """ + + provider: Literal["agentcore"] + session_id: str + memory_id: str + actor_id: str + + +class CustomProviderDescriptor(BaseModel): + """Session manager descriptor for custom session manager implementations. + + Attributes: + provider: Literal "custom". + session_id: The session identifier, or None if not available. + class_name: The fully-qualified class name of the session manager. + """ + + provider: Literal["custom"] + session_id: str | None + class_name: str + + +SessionManagerDescriptor = Annotated[ + FileProviderDescriptor + | S3ProviderDescriptor + | AgentCoreProviderDescriptor + | CustomProviderDescriptor, + Field(discriminator="provider"), +] +"""Discriminated union of session manager descriptors by provider type.""" + + +class AgentDescriptor(BaseModel): + """Descriptor for a configured agent in the session. + + Attributes: + name: The agent's configured name. + description: The agent's description, or None. + model: The agent's model descriptor. + session_manager: The agent's session manager descriptor, or None. + """ + + name: str + description: str | None + model: ModelDescriptor + session_manager: SessionManagerDescriptor | None + + +class OrchestrationDescriptor(BaseModel): + """Descriptor for a configured orchestration in the session. + + Attributes: + name: The orchestration's configured name. + kind: The orchestration kind ("delegate", "swarm", "graph", or "unknown"). + session_manager: The orchestration's session manager descriptor, or None. + nodes: List of nodes in the orchestration topology. + edges: List of edges in the orchestration topology (None for swarm/delegate). + entry_node_id: The entry node identifier(s), or None. + """ + + name: str + kind: str + session_manager: SessionManagerDescriptor | None + nodes: list[NodeRef] = Field(default_factory=list) + edges: list[EdgeRef] | None = None + entry_node_id: str | None = None + + +class EntryDescriptor(BaseModel): + """Descriptor identifying the session entry point. + + Attributes: + name: The entry point's configured name. + kind: The entry point kind ("agent" or "orchestration"). + """ + + name: str + kind: str + + +class SessionManifest(BaseModel): + """Manifest describing the wired session topology and configuration. + + Attributes: + agents: List of agent descriptors. + orchestrations: List of orchestration descriptors. + entry: The entry point descriptor. + """ + + agents: list[AgentDescriptor] = Field(default_factory=list) + orchestrations: list[OrchestrationDescriptor] = Field(default_factory=list) + entry: EntryDescriptor diff --git a/src/strands_compose/wire.py b/src/strands_compose/wire.py index 7e2385d..38175ed 100644 --- a/src/strands_compose/wire.py +++ b/src/strands_compose/wire.py @@ -4,19 +4,17 @@ :class:`~strands_compose.hooks.EventPublisher` for all agent activity. :class:`EventQueue` is a thin async queue wrapper that hides the sentinel -pattern from callers. :func:`make_event_queue` attaches -:class:`~strands_compose.hooks.EventPublisher` hooks to every agent -so all events (TOKEN, REASONING, TOOL_START, TOOL_END, INTERRUPT, COMPLETE, -and — for Swarm/Graph — NODE_START, NODE_STOP, HANDOFF, MULTIAGENT_COMPLETE) -flow into the shared queue. - -Hooks are wired **once per session**. Between requests on the same session, -call :meth:`EventQueue.flush` to discard stale events. - -Key Features: - - Async queue with hidden end-of-stream sentinel pattern - - Thread-safe event injection for cross-thread publishing - - Automatic EventPublisher wiring for agents and orchestrators +pattern from callers and brackets every invocation with a SESSION_START +event (carrying the session manifest) and a SESSION_END event. + +:func:`make_event_queue` attaches :class:`~strands_compose.hooks.EventPublisher` +hooks to every agent so all per-agent events (TOKEN, REASONING, TOOL_START, +TOOL_END, INTERRUPT, COMPLETE, and — for Swarm/Graph — NODE_START, NODE_STOP, +HANDOFF, MULTIAGENT_COMPLETE) flow into the shared queue. + +Hooks are wired **once per session**. Between requests on the same session, +call :meth:`EventQueue.flush` to discard stale events and reset the +SESSION_START / SESSION_END guards. """ from __future__ import annotations @@ -30,7 +28,7 @@ from strands.multiagent.graph import Graph from .hooks import EventPublisher -from .types import StreamEvent +from .types import EventType, SessionManifest, StreamEvent if TYPE_CHECKING: from .types import Node @@ -46,15 +44,25 @@ class EventQueue: - """Async event queue with a hidden end-of-stream sentinel.""" - - def __init__(self, queue: asyncio.Queue) -> None: + """Async event queue with hidden end-of-stream sentinel and session lifecycle.""" + + def __init__( + self, + queue: asyncio.Queue, + *, + entry_name: str | None = None, + session_id: str | None = None, + ) -> None: """Initialize the EventQueue. Callers consume events via :meth:`get` (which returns ``None`` when the stream is closed) and signal completion via :meth:`close`. The sentinel is an implementation detail — user code never sees or owns it. + *entry_name* and *session_id* parameterise the SESSION_START and + SESSION_END events emitted by :meth:`emit_session_start` and + :meth:`close` respectively. + Example:: events = make_event_queue(config.agents) @@ -73,8 +81,17 @@ async def _run(): Args: queue: The underlying asyncio.Queue to wrap. + entry_name: The configured name of the entry node. Used as + ``agent_name`` on SESSION_START and SESSION_END events. + Defaults to an empty string when not provided. + session_id: The effective session id. Included in the + SESSION_END event payload as ``data["session_id"]``. """ self._queue = queue + self._entry_name = entry_name or "" + self._session_id = session_id + self._session_start_emitted = False + self._session_end_emitted = False # -- Internal ---------------------------------------------------------- @@ -95,13 +112,40 @@ def _put(self, event: StreamEvent | object) -> None: # -- Public API -------------------------------------------------------- def flush(self) -> None: - """Discard all currently queued events. + """Discard all currently queued events and reset lifecycle guards. - Call this at the start of each request to clear any stale events - left over from a previous invocation. + Call this at the start of each request to clear any stale events left + over from a previous invocation. Resets the SESSION_START and + SESSION_END guards so the next invocation cycle can re-emit them. """ while not self._queue.empty(): self._queue.get_nowait() + self._session_start_emitted = False + self._session_end_emitted = False + + def emit_session_start(self, manifest: SessionManifest) -> None: + """Emit a SESSION_START event with the session manifest. + + Places a :class:`StreamEvent` with ``type=EventType.SESSION_START`` on + the queue. A guard prevents double-emission within the same + invocation cycle (reset by :meth:`flush`). + + Args: + manifest: The :class:`SessionManifest` describing the wired + session. Serialised via ``.model_dump()`` into the event + payload. + """ + if self._session_start_emitted: + return + self._session_start_emitted = True + self._put( + StreamEvent( + type=EventType.SESSION_START, + agent_name=self._entry_name, + data=manifest.model_dump(), + ) + ) + logger.debug("entry=<%s> | session_start emitted", self._entry_name) async def get(self) -> StreamEvent | None: """Wait for the next event. @@ -123,74 +167,106 @@ def put_event(self, event: StreamEvent) -> None: async def close(self) -> None: """Signal end-of-stream. - Places the sentinel on the queue so that the consumer loop - terminates cleanly. Typically called in a ``finally`` block after - the agent invocation finishes. + Emits a SESSION_END event before placing the sentinel on the queue. + A guard prevents double-emission within the same invocation cycle + (reset by :meth:`flush`). Subsequent ``close()`` calls are no-ops + for the SESSION_END emission but still place the sentinel — the + method remains idempotent. + + Typically called in a ``finally`` block after the agent invocation + finishes. """ + if not self._session_end_emitted: + self._session_end_emitted = True + self._put( + StreamEvent( + type=EventType.SESSION_END, + agent_name=self._entry_name, + data={"session_id": self._session_id}, + ) + ) + logger.debug( + "entry=<%s>, session_id=<%s> | session_end emitted", + self._entry_name, + self._session_id, + ) await self._queue.put(_SENTINEL) +# ── Wiring ─────────────────────────────────────────────────────────────────── + + def make_event_queue( agents: dict[str, Agent], *, orchestrators: dict[str, Node] | None = None, tool_labels: dict[str, str] | None = None, + entry_name: str | None = None, + session_id: str | None = None, ) -> EventQueue: """Attach :class:`~strands_compose.hooks.EventPublisher` hooks to agents. Every agent in *agents* receives an :class:`.EventPublisher` hook and a - matching ``callback_handler`` so that all event types flow into the - returned :class:`EventQueue`. + matching ``callback_handler`` so all per-agent event types flow into the + returned :class:`EventQueue`. Orchestrators (Swarm / Graph / delegate + Agent) in *orchestrators* also get a publisher for NODE_START, NODE_STOP, + HANDOFF, and MULTIAGENT_COMPLETE events. - Orchestrators (Swarm / Graph) in *orchestrators* also get a publisher - for NODE_START, NODE_STOP, HANDOFF, and MULTIAGENT_COMPLETE events. + This function does **not** emit SESSION_START. Callers that own a + :class:`~strands_compose.types.SessionManifest` should call + :meth:`EventQueue.emit_session_start` themselves; the common + :class:`ResolvedConfig` workflow does this for you via + :meth:`ResolvedConfig.wire_event_queue`. .. warning:: - This function **mutates** the passed-in agents and orchestrators - by adding hooks and overwriting ``callback_handler``. Call it - only once per set of agents. For the common ``ResolvedConfig`` - workflow, prefer :meth:`ResolvedConfig.wire_event_queue` which - makes the mutation explicit. + This function **mutates** the passed-in agents and orchestrators by + adding hooks and overwriting ``callback_handler``. Call it only once + per set of agents. Args: agents: Agents to wire, keyed by name. orchestrators: Built orchestrations keyed by name. - tool_labels: Tool name -> display label mapping forwarded to each + tool_labels: Tool name → display label mapping forwarded to each :class:`.EventPublisher`. Defaults to ``{name: "Delegating work to agent: "}`` for every agent. + entry_name: The configured name of the entry node. Stored on the + EventQueue and used as ``agent_name`` on SESSION_START / + SESSION_END events. + session_id: The effective session id. Stored on the EventQueue and + included in the SESSION_END event payload. Returns: A ready-to-use :class:`EventQueue`. """ - event_queue = EventQueue(asyncio.Queue(maxsize=10000)) + event_queue = EventQueue( + asyncio.Queue(maxsize=10000), + entry_name=entry_name, + session_id=session_id, + ) labels = { **{name: f"Delegating work to agent: {name.title()}" for name in agents}, **(tool_labels or {}), } - # Wire every agent with a publisher. for name, agent in agents.items(): pub = EventPublisher(callback=event_queue._put, agent_name=name, tool_labels=labels) agent.hooks.add_hook(pub) agent.callback_handler = pub.as_callback_handler() logger.debug("agent=<%s> | wired EventPublisher", name) - # Wire orchestrators (Swarm / Graph instances). for orch_name, orch in (orchestrators or {}).items(): - if isinstance(orch, (Swarm, Graph, Agent)): - orch_pub = EventPublisher( - callback=event_queue._put, - agent_name=orch_name, - tool_labels=labels, - ) - orch.hooks.add_hook(orch_pub) - - # If orch is an Agent, it needs the callback_handler set like any other agent. - if isinstance(orch, Agent): - orch.callback_handler = orch_pub.as_callback_handler() - - logger.debug("orchestrator=<%s> | wired EventPublisher", orch_name) + if not isinstance(orch, (Swarm, Graph, Agent)): + continue + orch_pub = EventPublisher( + callback=event_queue._put, + agent_name=orch_name, + tool_labels=labels, + ) + orch.hooks.add_hook(orch_pub) + if isinstance(orch, Agent): + orch.callback_handler = orch_pub.as_callback_handler() + logger.debug("orchestrator=<%s> | wired EventPublisher", orch_name) return event_queue diff --git a/tasks/test.just b/tasks/test.just index 7a26441..5985027 100644 --- a/tasks/test.just +++ b/tasks/test.just @@ -4,7 +4,7 @@ test: test-coverage # check code coverage [group('test')] -test-coverage numprocesses="auto" cov_fail_under="70": +test-coverage numprocesses="auto" cov_fail_under="90": uv run python -m pytest --numprocesses={{numprocesses}} --cov={{SOURCES}} --cov-fail-under={{cov_fail_under}} {{TESTS}} # run mutation testing (requires: pip install mutmut) diff --git a/tests/integration/test_session_lifecycle_events.py b/tests/integration/test_session_lifecycle_events.py new file mode 100644 index 0000000..7813048 --- /dev/null +++ b/tests/integration/test_session_lifecycle_events.py @@ -0,0 +1,344 @@ +"""Integration tests for session lifecycle events (SESSION_START and SESSION_END). + +Tests verify that SESSION_START is emitted as the first event and SESSION_END +as the last typed event before the stream sentinel, across various +orchestration topologies and invocation cycles. +""" + +from __future__ import annotations + +import pytest + +from strands_compose.config import load +from strands_compose.types import EventType, StreamEvent + + +@pytest.mark.integration +class TestSessionLifecycleEventsSingleAgent: + """Session lifecycle events with a single agent.""" + + @pytest.mark.asyncio + async def test_session_start_first_session_end_last_single_agent(self, fixture_path): + """Verify SESSION_START is first event and SESSION_END is last for single agent.""" + resolved = load(fixture_path("minimal.yaml")) + eq = resolved.wire_event_queue() + + events: list[StreamEvent] = [] + + try: + # Simulate a simple invocation by just closing the queue + # (no actual agent call, just testing event ordering) + pass + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events.append(event) + + assert len(events) >= 2, "Expected at least SESSION_START and SESSION_END" + assert events[0].type == EventType.SESSION_START, "First event should be SESSION_START" + assert events[0].agent_name == "greeter", "SESSION_START agent_name should be entry point" + + assert events[-1].type == EventType.SESSION_END, "Last event should be SESSION_END" + assert events[-1].agent_name == "greeter", "SESSION_END agent_name should be entry point" + assert events[-1].data == {"session_id": None}, "SESSION_END data should have session_id" + + @pytest.mark.asyncio + async def test_session_start_payload_contains_manifest(self, fixture_path): + """Verify SESSION_START payload contains valid manifest.""" + resolved = load(fixture_path("minimal.yaml")) + eq = resolved.wire_event_queue() + + try: + pass + finally: + await eq.close() + + event = await eq.get() + assert event is not None + assert event.type == EventType.SESSION_START + + manifest = event.data + assert isinstance(manifest, dict) + assert "agents" in manifest + assert "orchestrations" in manifest + assert "entry" in manifest + + assert isinstance(manifest["agents"], list) + assert len(manifest["agents"]) > 0 + agent = manifest["agents"][0] + assert "name" in agent + assert "description" in agent + assert "model" in agent + assert "session_manager" in agent + + entry = manifest["entry"] + assert "name" in entry + assert "kind" in entry + assert entry["kind"] in ("agent", "orchestration") + + +@pytest.mark.integration +class TestSessionLifecycleEventsMultipleAgents: + """Session lifecycle events with multiple agents.""" + + @pytest.mark.asyncio + async def test_session_start_session_end_multiple_agents(self, fixture_path): + """Verify SESSION_START and SESSION_END with multiple agents.""" + resolved = load(fixture_path("multi_agent_delegate.yaml")) + eq = resolved.wire_event_queue() + + events: list[StreamEvent] = [] + try: + pass + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events.append(event) + + assert events[0].type == EventType.SESSION_START + assert events[0].agent_name == "coordinator" + + assert events[-1].type == EventType.SESSION_END + assert events[-1].agent_name == "coordinator" + + manifest = events[0].data + agent_names = {agent["name"] for agent in manifest["agents"]} + assert "researcher" in agent_names + assert "writer" in agent_names + + +@pytest.mark.integration +class TestSessionLifecycleEventsSwarmOrchestration: + """Session lifecycle events with swarm orchestration.""" + + @pytest.mark.asyncio + async def test_session_start_session_end_swarm(self, fixture_path): + """Verify SESSION_START and SESSION_END with swarm orchestration.""" + resolved = load(fixture_path("swarm.yaml")) + eq = resolved.wire_event_queue() + + events: list[StreamEvent] = [] + try: + pass + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events.append(event) + + assert events[0].type == EventType.SESSION_START + assert events[0].agent_name == "team" + + assert events[-1].type == EventType.SESSION_END + assert events[-1].agent_name == "team" + + manifest = events[0].data + assert len(manifest["orchestrations"]) > 0 + swarm = manifest["orchestrations"][0] + assert swarm["kind"] == "swarm" + assert "nodes" in swarm + assert len(swarm["nodes"]) > 0 + + +@pytest.mark.integration +class TestSessionLifecycleEventsGraphOrchestration: + """Session lifecycle events with graph orchestration.""" + + @pytest.mark.asyncio + async def test_session_start_session_end_graph(self, fixture_path): + """Verify SESSION_START and SESSION_END with graph orchestration.""" + resolved = load(fixture_path("graph.yaml")) + eq = resolved.wire_event_queue() + + events: list[StreamEvent] = [] + try: + pass + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events.append(event) + + assert events[0].type == EventType.SESSION_START + assert events[0].agent_name == "pipeline" + + assert events[-1].type == EventType.SESSION_END + assert events[-1].agent_name == "pipeline" + + manifest = events[0].data + assert len(manifest["orchestrations"]) > 0 + graph = manifest["orchestrations"][0] + assert graph["kind"] == "graph" + assert "nodes" in graph + assert "edges" in graph + assert isinstance(graph["edges"], list) + + +@pytest.mark.integration +class TestSessionLifecycleEventsMultipleInvocations: + """Session lifecycle events across multiple invocation cycles.""" + + @pytest.mark.asyncio + async def test_multiple_invocations_separate_session_end(self, fixture_path): + """Verify separate SESSION_END for each invocation cycle. + + Note: SESSION_START is only emitted once when wire_event_queue() is called. + For multiple invocations, the caller must manually call emit_session_start() + after flush() if they want a new SESSION_START for the next invocation. + """ + resolved = load(fixture_path("minimal.yaml")) + eq = resolved.wire_event_queue() + + events1: list[StreamEvent] = [] + try: + pass + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events1.append(event) + + assert events1[0].type == EventType.SESSION_START + assert events1[-1].type == EventType.SESSION_END + + eq.flush() + events2: list[StreamEvent] = [] + try: + pass + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events2.append(event) + + assert events2[-1].type == EventType.SESSION_END + + assert events1[-1] is not events2[-1] + + +@pytest.mark.integration +class TestSessionLifecycleEventsExceptionHandling: + """Session lifecycle events when exceptions occur during invocation.""" + + @pytest.mark.asyncio + async def test_session_end_emitted_on_exception(self, fixture_path): + """Verify SESSION_END is emitted even when exception occurs.""" + resolved = load(fixture_path("minimal.yaml")) + eq = resolved.wire_event_queue() + + events: list[StreamEvent] = [] + exception_raised = False + + try: + raise RuntimeError("Simulated invocation error") + except RuntimeError: + exception_raised = True + finally: + await eq.close() + + while True: + event = await eq.get() + if event is None: + break + events.append(event) + + assert exception_raised + + assert len(events) >= 2 + assert events[0].type == EventType.SESSION_START + assert events[-1].type == EventType.SESSION_END + + +@pytest.mark.integration +class TestSessionLifecycleEventsIdempotency: + """Session lifecycle events idempotency and guard behavior.""" + + @pytest.mark.asyncio + async def test_close_called_multiple_times_session_end_once(self, fixture_path): + """Verify SESSION_END is emitted only once even if close() called multiple times.""" + resolved = load(fixture_path("minimal.yaml")) + eq = resolved.wire_event_queue() + + try: + pass + finally: + await eq.close() + await eq.close() + await eq.close() + + events: list[StreamEvent] = [] + while True: + event = await eq.get() + if event is None: + break + events.append(event) + + session_end_count = sum(1 for e in events if e.type == EventType.SESSION_END) + assert session_end_count == 1, "SESSION_END should be emitted exactly once" + + @pytest.mark.asyncio + async def test_flush_resets_guards_allows_reemission(self, fixture_path): + """Verify flush() resets guards and allows re-emission in next cycle. + + Note: SESSION_START is only emitted once when wire_event_queue() is called. + For multiple invocations, the caller must manually call emit_session_start() + after flush() if they want a new SESSION_START for the next invocation. + """ + resolved = load(fixture_path("minimal.yaml")) + eq = resolved.wire_event_queue() + + try: + pass + finally: + await eq.close() + + events1: list[StreamEvent] = [] + while True: + event = await eq.get() + if event is None: + break + events1.append(event) + + session_start_count_1 = sum(1 for e in events1 if e.type == EventType.SESSION_START) + session_end_count_1 = sum(1 for e in events1 if e.type == EventType.SESSION_END) + + eq.flush() + try: + pass + finally: + await eq.close() + + events2: list[StreamEvent] = [] + while True: + event = await eq.get() + if event is None: + break + events2.append(event) + + session_end_count_2 = sum(1 for e in events2 if e.type == EventType.SESSION_END) + + assert session_start_count_1 == 1 + assert session_end_count_1 == 1 + + # SESSION_START is not re-emitted unless emit_session_start() is called manually + assert session_end_count_2 == 1 diff --git a/tests/unit/config/resolvers/test_wire_event_queue.py b/tests/unit/config/resolvers/test_wire_event_queue.py index f197621..3129d59 100644 --- a/tests/unit/config/resolvers/test_wire_event_queue.py +++ b/tests/unit/config/resolvers/test_wire_event_queue.py @@ -8,13 +8,24 @@ from strands_compose.wire import EventQueue +def _mock_manifest(entry_name: str = "a") -> MagicMock: + """Build a mock SessionManifest with the minimum surface used by wire_event_queue.""" + manifest = MagicMock() + manifest.entry.name = entry_name + manifest.agents = [] + manifest.orchestrations = [] + return manifest + + class TestWireEventQueue: """Unit tests for ResolvedConfig.wire_event_queue().""" + @patch("strands_compose.config.resolvers.config.build_manifest") @patch("strands_compose.config.resolvers.config.make_event_queue") - def test_returns_event_queue(self, mock_make_eq): + def test_returns_event_queue(self, mock_make_eq, mock_build_manifest): mock_eq = MagicMock(spec=EventQueue) mock_make_eq.return_value = mock_eq + mock_build_manifest.return_value = _mock_manifest() agent = MagicMock() agent.agent_id = "a" @@ -26,10 +37,14 @@ def test_returns_event_queue(self, mock_make_eq): assert result is mock_eq mock_make_eq.assert_called_once() + mock_build_manifest.assert_called_once() + mock_eq.emit_session_start.assert_called_once() + @patch("strands_compose.config.resolvers.config.build_manifest") @patch("strands_compose.config.resolvers.config.make_event_queue") - def test_passes_agents_and_orchestrators(self, mock_make_eq): + def test_passes_agents_and_orchestrators(self, mock_make_eq, mock_build_manifest): mock_make_eq.return_value = MagicMock(spec=EventQueue) + mock_build_manifest.return_value = _mock_manifest() agent = MagicMock() agent.agent_id = "a" @@ -48,9 +63,11 @@ def test_passes_agents_and_orchestrators(self, mock_make_eq): assert call_args[0][0] == {"a": agent} assert call_args[1]["orchestrators"] == {"o": orch} + @patch("strands_compose.config.resolvers.config.build_manifest") @patch("strands_compose.config.resolvers.config.make_event_queue") - def test_forwards_tool_labels(self, mock_make_eq): + def test_forwards_tool_labels(self, mock_make_eq, mock_build_manifest): mock_make_eq.return_value = MagicMock(spec=EventQueue) + mock_build_manifest.return_value = _mock_manifest() agent = MagicMock() agent.agent_id = "a" @@ -63,9 +80,11 @@ def test_forwards_tool_labels(self, mock_make_eq): assert mock_make_eq.call_args[1]["tool_labels"] == labels + @patch("strands_compose.config.resolvers.config.build_manifest") @patch("strands_compose.config.resolvers.config.make_event_queue") - def test_none_tool_labels_by_default(self, mock_make_eq): + def test_none_tool_labels_by_default(self, mock_make_eq, mock_build_manifest): mock_make_eq.return_value = MagicMock(spec=EventQueue) + mock_build_manifest.return_value = _mock_manifest() agent = MagicMock() agent.agent_id = "a" diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index 85ced06..5b7d407 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -12,6 +12,7 @@ import pytest from strands_compose.mcp.lifecycle import MCPLifecycle +from strands_compose.types import EventType from strands_compose.wire import EventQueue, StreamEvent # --------------------------------------------------------------------------- @@ -61,7 +62,10 @@ def _producer(start: int, count: int) -> None: break received.append(event) - assert len(received) == num_events + # Should have 50 producer events + 1 SESSION_END event + assert len(received) == num_events + 1 + # Last event should be SESSION_END + assert received[-1].type == EventType.SESSION_END @pytest.mark.asyncio async def test_put_event_thread_safe(self) -> None: @@ -91,10 +95,15 @@ async def test_queue_full_drops_event(self) -> None: @pytest.mark.asyncio async def test_close_then_get_returns_none(self) -> None: - """Closing the queue causes get to return None.""" + """Closing the queue causes get to return None after SESSION_END.""" queue = asyncio.Queue() eq = EventQueue(queue) await eq.close() + # First get() returns SESSION_END event + session_end = await eq.get() + assert session_end is not None + assert session_end.type == EventType.SESSION_END + # Second get() returns None (sentinel) result = await eq.get() assert result is None diff --git a/tests/unit/test_event_queue.py b/tests/unit/test_event_queue.py index 7c5a06a..3c3c05a 100644 --- a/tests/unit/test_event_queue.py +++ b/tests/unit/test_event_queue.py @@ -9,6 +9,7 @@ from strands.multiagent import Swarm from strands_compose.hooks import EventPublisher +from strands_compose.types import EventType from strands_compose.wire import EventQueue, StreamEvent, make_event_queue # --------------------------------------------------------------------------- @@ -32,6 +33,11 @@ async def _run(): queue = asyncio.Queue() eq = EventQueue(queue) await eq.close() + # First get() returns SESSION_END event + session_end = await eq.get() + assert session_end is not None + assert session_end.type == EventType.SESSION_END + # Second get() returns None (sentinel) return await eq.get() assert asyncio.run(_run()) is None diff --git a/tests/unit/test_manifest.py b/tests/unit/test_manifest.py new file mode 100644 index 0000000..cf571dd --- /dev/null +++ b/tests/unit/test_manifest.py @@ -0,0 +1,616 @@ +"""Tests for strands_compose.manifest — pure manifest builders.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest +from strands import Agent +from strands.multiagent import Swarm +from strands.multiagent.graph import Graph +from strands.session import FileSessionManager, S3SessionManager + +from strands_compose.manifest import ( + build_manifest, + build_session_manager_descriptor, + first_session_id, +) +from strands_compose.types import ( + AgentCoreProviderDescriptor, + AgentDescriptor, + CustomProviderDescriptor, + EntryDescriptor, + FileProviderDescriptor, + ModelDescriptor, + S3ProviderDescriptor, + SessionManagerDescriptor, + SessionManifest, +) + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _mock_agent( + description: str | None = None, + model_id: str | None = None, + session_manager: object | None = None, +) -> Agent: + """Create a mock Agent suitable for manifest building.""" + agent = Mock(spec=Agent) + agent.description = description + agent.model = Mock() + config = {"model_id": model_id} if model_id else {} + agent.model.get_config.return_value = config + agent.model.__class__.__module__ = "strands.models" + agent.model.__class__.__qualname__ = "TestModel" + agent._session_manager = session_manager + return agent # type: ignore[return-value] + + +# ── build_session_manager_descriptor ───────────────────────────────────────── + + +class TestBuildSessionManagerDescriptor: + """Tests for build_session_manager_descriptor.""" + + def test_file_session_manager_descriptor(self, tmp_path): + """FileSessionManager → FileProviderDescriptor.""" + storage_dir = str(tmp_path / "sessions") + manager = FileSessionManager(session_id="sess-123", storage_dir=storage_dir) + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, FileProviderDescriptor) + assert descriptor.provider == "file" + assert descriptor.session_id == "sess-123" + assert descriptor.storage_dir == storage_dir + + def test_s3_session_manager_descriptor(self): + """S3SessionManager → S3ProviderDescriptor.""" + manager = Mock(spec=S3SessionManager) + manager.session_id = "sess-456" + manager.bucket = "my-bucket" + manager.prefix = "sessions/" + + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, S3ProviderDescriptor) + assert descriptor.provider == "s3" + assert descriptor.session_id == "sess-456" + assert descriptor.bucket == "my-bucket" + assert descriptor.prefix == "sessions/" + + def test_s3_session_manager_descriptor_empty_prefix(self): + """S3SessionManager with empty prefix.""" + manager = Mock(spec=S3SessionManager) + manager.session_id = "sess-789" + manager.bucket = "bucket" + manager.prefix = "" + + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, S3ProviderDescriptor) + assert descriptor.prefix == "" + + def test_agentcore_duck_typed_descriptor(self): + """Duck-typed AgentCore manager → AgentCoreProviderDescriptor.""" + manager = Mock(spec=[]) + manager.config = Mock(spec=["memory_id", "actor_id", "session_id"]) + manager.config.memory_id = "mem-123" + manager.config.actor_id = "actor-456" + manager.config.session_id = "sess-789" + + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, AgentCoreProviderDescriptor) + assert descriptor.provider == "agentcore" + assert descriptor.session_id == "sess-789" + assert descriptor.memory_id == "mem-123" + assert descriptor.actor_id == "actor-456" + + def test_custom_session_manager_descriptor(self): + """Unknown manager type → CustomProviderDescriptor.""" + manager = Mock(spec=[]) + manager.session_id = "sess-custom" + + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, CustomProviderDescriptor) + assert descriptor.provider == "custom" + assert descriptor.session_id == "sess-custom" + assert "Mock" in descriptor.class_name + + def test_custom_session_manager_descriptor_no_session_id(self): + """Custom manager without session_id → session_id is None.""" + manager = Mock(spec=[]) + + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, CustomProviderDescriptor) + assert descriptor.provider == "custom" + assert descriptor.session_id is None + + def test_custom_descriptor_class_name_fully_qualified(self): + """CustomProviderDescriptor.class_name is fully-qualified.""" + manager = Mock(spec=[]) + manager.session_id = None + + descriptor = build_session_manager_descriptor(manager) + + assert isinstance(descriptor, CustomProviderDescriptor) + assert "." in descriptor.class_name + assert descriptor.class_name.startswith("unittest.mock") + + +# ── build_manifest ─────────────────────────────────────────────────────────── + + +class TestBuildManifest: + """Tests for build_manifest.""" + + def test_manifest_with_single_agent(self): + """Manifest with one agent.""" + agent = _mock_agent(description="Test agent", model_id="gpt-4") + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert len(manifest.agents) == 1 + assert manifest.agents[0].name == "agent1" + assert manifest.agents[0].description == "Test agent" + assert manifest.agents[0].model.model_id == "gpt-4" + assert "TestModel" in manifest.agents[0].model.provider + assert manifest.agents[0].session_manager is None + + def test_manifest_agent_description_none(self): + """Agent with None description.""" + agent = _mock_agent(description=None) + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert manifest.agents[0].description is None + + def test_manifest_agent_model_id_from_dict_config(self): + """Extract model_id from dict config.""" + agent = _mock_agent(model_id="claude-3") + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert manifest.agents[0].model.model_id == "claude-3" + + def test_manifest_agent_model_id_from_object_config(self): + """Extract model_id from object config via getattr.""" + agent = Mock(spec=Agent) + agent.description = None + config_obj = Mock() + config_obj.model_id = "custom-model" + agent.model = Mock() + agent.model.get_config.return_value = config_obj + agent.model.__class__.__module__ = "custom" + agent.model.__class__.__qualname__ = "CustomModel" + agent._session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert manifest.agents[0].model.model_id == "custom-model" + + def test_manifest_agent_model_id_none_when_absent(self): + """model_id is None when not in config.""" + agent = _mock_agent() # no model_id + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert manifest.agents[0].model.model_id is None + + def test_manifest_agent_session_manager_none(self): + """session_manager is None when agent has no session manager.""" + agent = _mock_agent() + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert manifest.agents[0].session_manager is None + + def test_manifest_entry_not_found_raises_value_error(self): + """ValueError when entry not found in agents or orchestrators.""" + agent = _mock_agent() + other_agent = Mock(spec=Agent) + + with pytest.raises(ValueError, match="entry node not found"): + build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=other_agent, + ) + + def test_manifest_orchestration_kind_delegate(self): + """Agent orchestration → kind='delegate'.""" + agent = _mock_agent() + delegate = Mock(spec=Agent) + delegate._session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"delegate1": delegate}, + entry=delegate, + ) + + assert len(manifest.orchestrations) == 1 + assert manifest.orchestrations[0].kind == "delegate" + assert manifest.orchestrations[0].nodes == [] + assert manifest.orchestrations[0].edges is None + assert manifest.orchestrations[0].entry_node_id is None + + def test_manifest_orchestration_kind_swarm(self): + """Swarm orchestration → kind='swarm'.""" + agent = _mock_agent() + swarm = Mock(spec=Swarm) + swarm.nodes = {} + swarm.entry_point = None + swarm.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"swarm1": swarm}, + entry=swarm, + ) + + assert len(manifest.orchestrations) == 1 + assert manifest.orchestrations[0].kind == "swarm" + + def test_manifest_orchestration_kind_graph(self): + """Graph orchestration → kind='graph'.""" + agent = _mock_agent() + graph = Mock(spec=Graph) + graph.nodes = {} + graph.edges = set() + graph.entry_points = set() + graph.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"graph1": graph}, + entry=graph, + ) + + assert len(manifest.orchestrations) == 1 + assert manifest.orchestrations[0].kind == "graph" + + def test_manifest_orchestration_kind_unknown(self): + """Unknown orchestration type → kind='unknown'.""" + agent = _mock_agent() + unknown = Mock() # Not Agent, Swarm, or Graph + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"unknown1": unknown}, + entry=unknown, + ) + + assert len(manifest.orchestrations) == 1 + assert manifest.orchestrations[0].kind == "unknown" + + def test_manifest_entry_descriptor_agent(self): + """Entry descriptor for agent entry.""" + agent = _mock_agent() + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={}, + entry=agent, + ) + + assert manifest.entry.name == "agent1" + assert manifest.entry.kind == "agent" + + def test_manifest_entry_descriptor_orchestration(self): + """Entry descriptor for orchestration entry.""" + agent = _mock_agent() + swarm = Mock(spec=Swarm) + swarm.nodes = {} + swarm.entry_point = None + swarm.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"swarm1": swarm}, + entry=swarm, + ) + + assert manifest.entry.name == "swarm1" + assert manifest.entry.kind == "orchestration" + + def test_manifest_insertion_order_preserved_agents(self): + """Agent descriptor order matches dict insertion order.""" + agents = {f"agent{i}": _mock_agent() for i in range(3)} + + manifest = build_manifest( + agents=agents, + orchestrators={}, + entry=agents["agent0"], + ) + + assert [d.name for d in manifest.agents] == ["agent0", "agent1", "agent2"] + + def test_manifest_insertion_order_preserved_orchestrations(self): + """Orchestration descriptor order matches dict insertion order.""" + agent = _mock_agent() + + orchestrators = {} + for i in range(3): + orch = Mock(spec=Swarm) + orch.nodes = {} + orch.entry_point = None + orch.session_manager = None + orchestrators[f"orch{i}"] = orch + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators=orchestrators, + entry=agent, + ) + + assert [d.name for d in manifest.orchestrations] == ["orch0", "orch1", "orch2"] + + def test_manifest_swarm_nodes_and_entry_point(self): + """Swarm nodes and entry_point resolved correctly.""" + agent = _mock_agent() + + swarm_node1 = Mock() + swarm_node1.node_id = "node1" + swarm_node1.executor = agent + + swarm_node2 = Mock() + swarm_node2.node_id = "node2" + swarm_node2.executor = Mock(spec=Agent) + + swarm = Mock(spec=Swarm) + swarm.nodes = {"node1": swarm_node1, "node2": swarm_node2} + swarm.entry_point = agent + swarm.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"swarm1": swarm}, + entry=swarm, + ) + + orch_desc = manifest.orchestrations[0] + assert len(orch_desc.nodes) == 2 + assert orch_desc.nodes[0].id == "node1" + assert orch_desc.nodes[0].kind == "agent" + assert orch_desc.entry_node_id == "node1" + + def test_manifest_swarm_entry_point_none_uses_first_node(self): + """Swarm with no entry_point uses first node.""" + agent = _mock_agent() + + swarm_node = Mock() + swarm_node.node_id = "first" + swarm_node.executor = agent + + swarm = Mock(spec=Swarm) + swarm.nodes = {"first": swarm_node} + swarm.entry_point = None + swarm.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"swarm1": swarm}, + entry=swarm, + ) + + assert manifest.orchestrations[0].entry_node_id == "first" + + def test_manifest_swarm_empty_nodes_entry_node_id_none(self): + """Swarm with no nodes has entry_node_id=None.""" + agent = _mock_agent() + + swarm = Mock(spec=Swarm) + swarm.nodes = {} + swarm.entry_point = None + swarm.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"swarm1": swarm}, + entry=swarm, + ) + + assert manifest.orchestrations[0].entry_node_id is None + + def test_manifest_graph_nodes_and_edges(self): + """Graph nodes and edges resolved correctly.""" + agent = _mock_agent() + + graph_node1 = Mock() + graph_node1.node_id = "gnode1" + graph_node1.executor = agent + + graph_node2 = Mock() + graph_node2.node_id = "gnode2" + graph_node2.executor = Mock(spec=Agent) + + graph_edge = Mock() + graph_edge.from_node = graph_node1 + graph_edge.to_node = graph_node2 + + graph = Mock(spec=Graph) + graph.nodes = {"gnode1": graph_node1, "gnode2": graph_node2} + graph.edges = {graph_edge} + graph.entry_points = {graph_node1} + graph.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"graph1": graph}, + entry=graph, + ) + + orch_desc = manifest.orchestrations[0] + assert len(orch_desc.nodes) == 2 + assert orch_desc.nodes[0].id == "gnode1" + assert orch_desc.nodes[0].kind == "agent" + assert orch_desc.edges is not None + assert len(orch_desc.edges) == 1 + assert orch_desc.edges[0].from_id == "gnode1" + assert orch_desc.edges[0].to_id == "gnode2" + + def test_manifest_graph_entry_node_id_single(self): + """Graph with single entry point.""" + agent = _mock_agent() + + graph_node = Mock() + graph_node.node_id = "entry" + + graph = Mock(spec=Graph) + graph.nodes = {"entry": graph_node} + graph.edges = set() + graph.entry_points = {graph_node} + graph.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"graph1": graph}, + entry=graph, + ) + + assert manifest.orchestrations[0].entry_node_id == "entry" + + def test_manifest_graph_entry_node_id_multiple_comma_joined(self): + """Graph with multiple entry points → comma-joined.""" + agent = _mock_agent() + + graph_node1 = Mock() + graph_node1.node_id = "entry1" + + graph_node2 = Mock() + graph_node2.node_id = "entry2" + + graph = Mock(spec=Graph) + graph.nodes = {"entry1": graph_node1, "entry2": graph_node2} + graph.edges = set() + # Use a list to preserve order for testing + graph.entry_points = [graph_node1, graph_node2] + graph.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"graph1": graph}, + entry=graph, + ) + + entry_id = manifest.orchestrations[0].entry_node_id + assert entry_id is not None + assert "entry1" in entry_id + assert "entry2" in entry_id + assert "," in entry_id + + def test_manifest_graph_entry_node_id_none_when_empty(self): + """Graph with no entry points → entry_node_id=None.""" + agent = _mock_agent() + + graph = Mock(spec=Graph) + graph.nodes = {} + graph.edges = set() + graph.entry_points = set() + graph.session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"graph1": graph}, + entry=graph, + ) + + assert manifest.orchestrations[0].entry_node_id is None + + def test_manifest_delegate_empty_topology(self): + """Delegate orchestration has empty topology.""" + agent = _mock_agent() + delegate = Mock(spec=Agent) + delegate._session_manager = None + + manifest = build_manifest( + agents={"agent1": agent}, + orchestrators={"delegate1": delegate}, + entry=delegate, + ) + + orch_desc = manifest.orchestrations[0] + assert orch_desc.nodes == [] + assert orch_desc.edges is None + assert orch_desc.entry_node_id is None + + +# ── first_session_id ───────────────────────────────────────────────────────── + + +def _file_descriptor(session_id: str) -> FileProviderDescriptor: + return FileProviderDescriptor(provider="file", session_id=session_id, storage_dir="/tmp") + + +class TestFirstSessionId: + """Tests for first_session_id.""" + + def _empty_entry(self) -> EntryDescriptor: + return EntryDescriptor(name="x", kind="agent") + + def _agent(self, sm: SessionManagerDescriptor | None) -> AgentDescriptor: + return AgentDescriptor( + name="a", + description=None, + model=ModelDescriptor(model_id=None, provider="P"), + session_manager=sm, + ) + + def test_returns_none_when_no_managers(self): + manifest = SessionManifest(entry=self._empty_entry()) + assert first_session_id(manifest) is None + + def test_returns_first_agent_session_id(self): + manifest = SessionManifest( + agents=[ + self._agent(None), + self._agent(_file_descriptor("sess-1")), + self._agent(_file_descriptor("sess-2")), + ], + entry=self._empty_entry(), + ) + assert first_session_id(manifest) == "sess-1" + + def test_falls_back_to_orchestration_when_no_agents_have_sm(self): + from strands_compose.types import OrchestrationDescriptor + + manifest = SessionManifest( + agents=[self._agent(None)], + orchestrations=[ + OrchestrationDescriptor( + name="o", + kind="swarm", + session_manager=_file_descriptor("orch-sess"), + ), + ], + entry=self._empty_entry(), + ) + assert first_session_id(manifest) == "orch-sess" diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index e8aaf73..3bb0de3 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -1,4 +1,4 @@ -"""Tests for the EventType StrEnum.""" +"""Tests for the EventType StrEnum and Session Manifest models.""" from __future__ import annotations @@ -6,7 +6,20 @@ import pytest -from strands_compose.types import EventType +from strands_compose.types import ( + AgentCoreProviderDescriptor, + AgentDescriptor, + CustomProviderDescriptor, + EdgeRef, + EntryDescriptor, + EventType, + FileProviderDescriptor, + ModelDescriptor, + NodeRef, + OrchestrationDescriptor, + S3ProviderDescriptor, + SessionManifest, +) class TestEventTypeEnum: @@ -41,7 +54,18 @@ def test_string_comparison_works(self, member, expected_value): """StrEnum values compare equal to their plain string counterparts.""" assert EventType[member] == expected_value + def test_session_start_event_type_value(self): + """EventType.SESSION_START has the correct string value.""" + assert EventType.SESSION_START == "session_start" + assert isinstance(EventType.SESSION_START, str) + + def test_session_end_event_type_value(self): + """EventType.SESSION_END has the correct string value.""" + assert EventType.SESSION_END == "session_end" + assert isinstance(EventType.SESSION_END, str) + def test_all_members_present(self): + """All expected EventType members are present.""" expected = { "AGENT_START", "TOKEN", @@ -56,5 +80,449 @@ def test_all_members_present(self): "HANDOFF", "MULTIAGENT_START", "MULTIAGENT_COMPLETE", + "SESSION_START", + "SESSION_END", } assert set(EventType.__members__) == expected + + +class TestNodeRef: + """Tests for NodeRef Pydantic model.""" + + def test_node_ref_fields(self): + """NodeRef has the correct fields.""" + node = NodeRef(id="node-1", kind="agent") + assert node.id == "node-1" + assert node.kind == "agent" + + def test_node_ref_model_dump(self): + """NodeRef serializes correctly via model_dump.""" + node = NodeRef(id="node-1", kind="orchestration") + dumped = node.model_dump() + assert dumped == {"id": "node-1", "kind": "orchestration"} + + def test_node_ref_json_serializable(self): + """NodeRef is JSON-serializable.""" + import json + + node = NodeRef(id="node-1", kind="agent") + json_str = json.dumps(node.model_dump()) + assert json_str == '{"id": "node-1", "kind": "agent"}' + + +class TestEdgeRef: + """Tests for EdgeRef Pydantic model.""" + + def test_edge_ref_fields(self): + """EdgeRef has the correct fields.""" + edge = EdgeRef(from_id="node-1", to_id="node-2") + assert edge.from_id == "node-1" + assert edge.to_id == "node-2" + + def test_edge_ref_model_dump(self): + """EdgeRef serializes correctly via model_dump.""" + edge = EdgeRef(from_id="a", to_id="b") + dumped = edge.model_dump() + assert dumped == {"from_id": "a", "to_id": "b"} + + def test_edge_ref_json_serializable(self): + """EdgeRef is JSON-serializable.""" + import json + + edge = EdgeRef(from_id="node-1", to_id="node-2") + json_str = json.dumps(edge.model_dump()) + assert json_str == '{"from_id": "node-1", "to_id": "node-2"}' + + +class TestModelDescriptor: + """Tests for ModelDescriptor Pydantic model.""" + + def test_model_descriptor_fields(self): + """ModelDescriptor has the correct fields.""" + model = ModelDescriptor( + model_id="us.anthropic.claude-sonnet-4-6", + provider="strands.models.bedrock.BedrockModel", + ) + assert model.model_id == "us.anthropic.claude-sonnet-4-6" + assert model.provider == "strands.models.bedrock.BedrockModel" + + def test_model_descriptor_model_id_none(self): + """ModelDescriptor allows model_id to be None.""" + model = ModelDescriptor(model_id=None, provider="custom.CustomModel") + assert model.model_id is None + assert model.provider == "custom.CustomModel" + + def test_model_descriptor_model_dump(self): + """ModelDescriptor serializes correctly via model_dump.""" + model = ModelDescriptor(model_id="model-123", provider="Provider") + dumped = model.model_dump() + assert dumped == {"model_id": "model-123", "provider": "Provider"} + + def test_model_descriptor_json_serializable(self): + """ModelDescriptor is JSON-serializable.""" + import json + + model = ModelDescriptor(model_id=None, provider="Provider") + json_str = json.dumps(model.model_dump()) + assert "provider" in json_str + + +class TestFileProviderDescriptor: + """Tests for FileProviderDescriptor.""" + + def test_file_provider_descriptor_fields(self): + """FileProviderDescriptor has the correct fields.""" + desc = FileProviderDescriptor( + provider="file", + session_id="session-123", + storage_dir="/tmp/sessions", + ) + assert desc.provider == "file" + assert desc.session_id == "session-123" + assert desc.storage_dir == "/tmp/sessions" + + def test_file_provider_descriptor_model_dump(self): + """FileProviderDescriptor serializes correctly.""" + desc = FileProviderDescriptor( + provider="file", + session_id="s1", + storage_dir="/path", + ) + dumped = desc.model_dump() + assert dumped == { + "provider": "file", + "session_id": "s1", + "storage_dir": "/path", + } + + +class TestS3ProviderDescriptor: + """Tests for S3ProviderDescriptor.""" + + def test_s3_provider_descriptor_fields(self): + """S3ProviderDescriptor has the correct fields.""" + desc = S3ProviderDescriptor( + provider="s3", + session_id="session-123", + bucket="my-bucket", + prefix="sessions/", + ) + assert desc.provider == "s3" + assert desc.session_id == "session-123" + assert desc.bucket == "my-bucket" + assert desc.prefix == "sessions/" + + def test_s3_provider_descriptor_empty_prefix(self): + """S3ProviderDescriptor allows empty prefix.""" + desc = S3ProviderDescriptor( + provider="s3", + session_id="s1", + bucket="bucket", + prefix="", + ) + assert desc.prefix == "" + + +class TestAgentCoreProviderDescriptor: + """Tests for AgentCoreProviderDescriptor.""" + + def test_agentcore_provider_descriptor_fields(self): + """AgentCoreProviderDescriptor has the correct fields.""" + desc = AgentCoreProviderDescriptor( + provider="agentcore", + session_id="session-123", + memory_id="mem-456", + actor_id="actor-789", + ) + assert desc.provider == "agentcore" + assert desc.session_id == "session-123" + assert desc.memory_id == "mem-456" + assert desc.actor_id == "actor-789" + + +class TestCustomProviderDescriptor: + """Tests for CustomProviderDescriptor.""" + + def test_custom_provider_descriptor_fields(self): + """CustomProviderDescriptor has the correct fields.""" + desc = CustomProviderDescriptor( + provider="custom", + session_id="session-123", + class_name="my.module.CustomSessionManager", + ) + assert desc.provider == "custom" + assert desc.session_id == "session-123" + assert desc.class_name == "my.module.CustomSessionManager" + + def test_custom_provider_descriptor_session_id_none(self): + """CustomProviderDescriptor allows session_id to be None.""" + desc = CustomProviderDescriptor( + provider="custom", + session_id=None, + class_name="my.module.CustomSessionManager", + ) + assert desc.session_id is None + + +class TestAgentDescriptor: + """Tests for AgentDescriptor Pydantic model.""" + + def test_agent_descriptor_fields(self): + """AgentDescriptor has the correct fields.""" + model = ModelDescriptor(model_id="m1", provider="Provider") + agent = AgentDescriptor( + name="researcher", + description="Researches topics", + model=model, + session_manager=None, + ) + assert agent.name == "researcher" + assert agent.description == "Researches topics" + assert agent.model == model + assert agent.session_manager is None + + def test_agent_descriptor_description_none(self): + """AgentDescriptor allows description to be None.""" + model = ModelDescriptor(model_id=None, provider="Provider") + agent = AgentDescriptor( + name="agent", + description=None, + model=model, + session_manager=None, + ) + assert agent.description is None + + def test_agent_descriptor_with_session_manager(self): + """AgentDescriptor can include a session manager.""" + model = ModelDescriptor(model_id="m1", provider="Provider") + sm = FileProviderDescriptor( + provider="file", + session_id="s1", + storage_dir="/tmp", + ) + agent = AgentDescriptor( + name="agent", + description="desc", + model=model, + session_manager=sm, + ) + assert agent.session_manager == sm + + def test_agent_descriptor_model_dump(self): + """AgentDescriptor serializes correctly.""" + model = ModelDescriptor(model_id="m1", provider="Provider") + agent = AgentDescriptor( + name="agent", + description="desc", + model=model, + session_manager=None, + ) + dumped = agent.model_dump() + assert dumped["name"] == "agent" + assert dumped["description"] == "desc" + assert dumped["model"]["model_id"] == "m1" + assert dumped["session_manager"] is None + + +class TestOrchestrationDescriptor: + """Tests for OrchestrationDescriptor Pydantic model.""" + + def test_orchestration_descriptor_fields(self): + """OrchestrationDescriptor has the correct fields.""" + orch = OrchestrationDescriptor( + name="main", + kind="swarm", + session_manager=None, + nodes=[NodeRef(id="n1", kind="agent")], + edges=None, + entry_node_id="n1", + ) + assert orch.name == "main" + assert orch.kind == "swarm" + assert orch.session_manager is None + assert len(orch.nodes) == 1 + assert orch.edges is None + assert orch.entry_node_id == "n1" + + def test_orchestration_descriptor_empty_defaults(self): + """OrchestrationDescriptor has correct default values.""" + orch = OrchestrationDescriptor( + name="main", + kind="delegate", + session_manager=None, + ) + assert orch.nodes == [] + assert orch.edges is None + assert orch.entry_node_id is None + + def test_orchestration_descriptor_with_edges(self): + """OrchestrationDescriptor can include edges.""" + edges = [EdgeRef(from_id="n1", to_id="n2")] + orch = OrchestrationDescriptor( + name="graph", + kind="graph", + session_manager=None, + nodes=[NodeRef(id="n1", kind="agent"), NodeRef(id="n2", kind="agent")], + edges=edges, + entry_node_id="n1", + ) + assert orch.edges == edges + + +class TestEntryDescriptor: + """Tests for EntryDescriptor Pydantic model.""" + + def test_entry_descriptor_fields(self): + """EntryDescriptor has the correct fields.""" + entry = EntryDescriptor(name="main", kind="orchestration") + assert entry.name == "main" + assert entry.kind == "orchestration" + + def test_entry_descriptor_agent_kind(self): + """EntryDescriptor can have kind='agent'.""" + entry = EntryDescriptor(name="researcher", kind="agent") + assert entry.kind == "agent" + + def test_entry_descriptor_model_dump(self): + """EntryDescriptor serializes correctly.""" + entry = EntryDescriptor(name="main", kind="orchestration") + dumped = entry.model_dump() + assert dumped == {"name": "main", "kind": "orchestration"} + + +class TestSessionManifest: + """Tests for SessionManifest Pydantic model.""" + + def test_session_manifest_fields(self): + """SessionManifest has the correct fields.""" + entry = EntryDescriptor(name="main", kind="agent") + manifest = SessionManifest( + agents=[], + orchestrations=[], + entry=entry, + ) + assert manifest.agents == [] + assert manifest.orchestrations == [] + assert manifest.entry == entry + + def test_session_manifest_empty_defaults(self): + """SessionManifest defaults agents and orchestrations to empty lists.""" + entry = EntryDescriptor(name="main", kind="agent") + manifest = SessionManifest(entry=entry) + assert manifest.agents == [] + assert manifest.orchestrations == [] + + def test_session_manifest_with_agents(self): + """SessionManifest can include agents.""" + model = ModelDescriptor(model_id="m1", provider="Provider") + agent = AgentDescriptor( + name="researcher", + description="desc", + model=model, + session_manager=None, + ) + entry = EntryDescriptor(name="researcher", kind="agent") + manifest = SessionManifest( + agents=[agent], + orchestrations=[], + entry=entry, + ) + assert len(manifest.agents) == 1 + assert manifest.agents[0].name == "researcher" + + def test_session_manifest_with_orchestrations(self): + """SessionManifest can include orchestrations.""" + orch = OrchestrationDescriptor( + name="main", + kind="swarm", + session_manager=None, + ) + entry = EntryDescriptor(name="main", kind="orchestration") + manifest = SessionManifest( + agents=[], + orchestrations=[orch], + entry=entry, + ) + assert len(manifest.orchestrations) == 1 + assert manifest.orchestrations[0].name == "main" + + def test_session_manifest_model_dump(self): + """SessionManifest serializes correctly via model_dump.""" + entry = EntryDescriptor(name="main", kind="agent") + manifest = SessionManifest( + agents=[], + orchestrations=[], + entry=entry, + ) + dumped = manifest.model_dump() + assert dumped["agents"] == [] + assert dumped["orchestrations"] == [] + assert dumped["entry"]["name"] == "main" + assert dumped["entry"]["kind"] == "agent" + + def test_session_manifest_json_serializable(self): + """SessionManifest is JSON-serializable.""" + import json + + entry = EntryDescriptor(name="main", kind="agent") + manifest = SessionManifest( + agents=[], + orchestrations=[], + entry=entry, + ) + json_str = json.dumps(manifest.model_dump()) + assert "main" in json_str + assert "agent" in json_str + + def test_session_manifest_complex_example(self): + """SessionManifest works with a complex multi-agent setup.""" + model1 = ModelDescriptor(model_id="m1", provider="Provider1") + model2 = ModelDescriptor(model_id="m2", provider="Provider2") + + agent1 = AgentDescriptor( + name="researcher", + description="Researches topics", + model=model1, + session_manager=FileProviderDescriptor( + provider="file", + session_id="s1", + storage_dir="/tmp", + ), + ) + agent2 = AgentDescriptor( + name="writer", + description="Writes content", + model=model2, + session_manager=None, + ) + + orch = OrchestrationDescriptor( + name="main", + kind="swarm", + session_manager=None, + nodes=[ + NodeRef(id="researcher", kind="agent"), + NodeRef(id="writer", kind="agent"), + ], + edges=None, + entry_node_id="researcher", + ) + + entry = EntryDescriptor(name="main", kind="orchestration") + + manifest = SessionManifest( + agents=[agent1, agent2], + orchestrations=[orch], + entry=entry, + ) + + assert len(manifest.agents) == 2 + assert len(manifest.orchestrations) == 1 + assert manifest.entry.name == "main" + + # Verify it's JSON-serializable + import json + + json_str = json.dumps(manifest.model_dump()) + assert "researcher" in json_str + assert "writer" in json_str diff --git a/tests/unit/test_wire.py b/tests/unit/test_wire.py index 678eb2e..3429370 100644 --- a/tests/unit/test_wire.py +++ b/tests/unit/test_wire.py @@ -1,7 +1,9 @@ -"""Tests for core.wire — StreamEvent.""" +"""Tests for strands_compose.wire — StreamEvent dataclass.""" from __future__ import annotations +from datetime import datetime, timedelta, timezone + from strands_compose.types import EventType from strands_compose.wire import StreamEvent @@ -28,8 +30,6 @@ def test_from_dict_round_trips_timestamp(self): class TestStreamEventEquality: def test_eq_ignores_timestamp(self): - from datetime import datetime, timedelta, timezone - t1 = datetime.now(tz=timezone.utc) t2 = t1 + timedelta(seconds=5) e1 = StreamEvent(type=EventType.TOKEN, agent_name="a", timestamp=t1, data={"text": "hi"}) diff --git a/uv.lock b/uv.lock index 4b91196..230c213 100644 --- a/uv.lock +++ b/uv.lock @@ -1782,7 +1782,7 @@ openai = [ [[package]] name = "strands-compose" -version = "0.3.0" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "mcp" },