diff --git a/tests/test_envelope_recovery.py b/tests/test_envelope_recovery.py new file mode 100644 index 0000000..d48a610 --- /dev/null +++ b/tests/test_envelope_recovery.py @@ -0,0 +1,106 @@ +"""Coverage tests for ``runtime.graph._try_recover_envelope_from_raw`` (graph.py:583-610). + +The recovery helper is invoked by the agent runner when LangGraph's +structured-output pass raises ``OutputParserException`` — it tries +several candidate substrings to dig an :class:`AgentTurnOutput` out of +free-form LLM text. The function is pure and behaves identically for +every input, so a table-driven test suite pins each branch. +""" +from __future__ import annotations + +import json + +import pytest + +from runtime.agents.turn_output import AgentTurnOutput +from runtime.graph import _try_recover_envelope_from_raw + + +def _envelope_dict(*, content: str = "ok", confidence: float = 0.85, + rationale: str = "stub", signal: str | None = None) -> dict: + return { + "content": content, + "confidence": confidence, + "confidence_rationale": rationale, + "signal": signal, + } + + +def _envelope_json(**overrides) -> str: + return json.dumps(_envelope_dict(**overrides)) + + +class TestEmptyInput: + @pytest.mark.parametrize("raw", ["", " ", "\n\n \t\n"]) + def test_empty_or_whitespace_returns_none(self, raw): + assert _try_recover_envelope_from_raw(raw) is None + + +class TestPlainJsonInput: + def test_valid_envelope_json_parses(self): + out = _try_recover_envelope_from_raw(_envelope_json(content="hello")) + assert isinstance(out, AgentTurnOutput) + assert out.content == "hello" + + def test_valid_envelope_with_signal(self): + out = _try_recover_envelope_from_raw(_envelope_json(signal="reconcile")) + assert out is not None + assert out.signal == "reconcile" + + +class TestMarkdownFencedJson: + def test_fenced_with_json_tag(self): + raw = f"```json\n{_envelope_json()}\n```" + out = _try_recover_envelope_from_raw(raw) + assert isinstance(out, AgentTurnOutput) + + def 
test_fenced_without_json_tag(self): + raw = f"```\n{_envelope_json(confidence=0.42)}\n```" + out = _try_recover_envelope_from_raw(raw) + assert out is not None + assert out.confidence == 0.42 + + def test_fenced_with_surrounding_chatter(self): + raw = ( + "Here is my structured response:\n\n" + f"```json\n{_envelope_json(content='fenced')}\n```\n\n" + "Hope that helps!" + ) + out = _try_recover_envelope_from_raw(raw) + assert out is not None + assert out.content == "fenced" + + +class TestGreedyBraceMatch: + def test_chatter_then_json_then_chatter(self): + # No fences — should fall through to the greedy first-{...-last-} scan. + raw = ( + f"Sure, here's the answer: {_envelope_json(content='greedy')} " + "Let me know if you need more!" + ) + out = _try_recover_envelope_from_raw(raw) + assert out is not None + assert out.content == "greedy" + + +class TestUnrecoverableInput: + def test_invalid_json_returns_none(self): + assert _try_recover_envelope_from_raw("{not valid json}") is None + + def test_no_braces_returns_none(self): + assert _try_recover_envelope_from_raw("Just a plain sentence.") is None + + def test_json_array_not_dict_returns_none(self): + # A JSON array has no braces for the greedy scan to match, and a + # non-dict JSON value is not an envelope, so recovery must return None. + assert _try_recover_envelope_from_raw('["a", "b"]') is None + + def test_valid_dict_missing_required_fields_returns_none(self): + # `{"foo": "bar"}` parses but fails AgentTurnOutput validation. + assert _try_recover_envelope_from_raw('{"foo": "bar"}') is None + + def test_dict_with_invalid_field_types_returns_none(self): + # confidence must be 0..1; this should fail validation on every candidate. 
+ assert _try_recover_envelope_from_raw( + '{"content": "x", "confidence": 5.0, "confidence_rationale": "y"}' + ) is None diff --git a/tests/test_handle_agent_failure.py b/tests/test_handle_agent_failure.py new file mode 100644 index 0000000..30885e4 --- /dev/null +++ b/tests/test_handle_agent_failure.py @@ -0,0 +1,145 @@ +"""Coverage tests for ``runtime.graph._handle_agent_failure`` (graph.py:613-644). + +This helper is invoked by the agent runner when an agent body raises a +non-pause exception (anything other than ``GraphInterrupt``). It reloads +the session (absorbing partial tool writes), appends a failure +``AgentRun``, marks the session ``status='error'``, persists, and returns +the state dict the LangGraph node yields. +""" +from __future__ import annotations + +import pytest + +from runtime.config import EmbeddingConfig, MetadataConfig, ProviderConfig +from runtime.graph import _handle_agent_failure +from runtime.state import AgentRun, Session +from runtime.storage.embeddings import build_embedder +from runtime.storage.engine import build_engine +from runtime.storage.models import Base +from runtime.storage.session_store import SessionStore + + +@pytest.fixture +def store(tmp_path) -> SessionStore: + eng = build_engine(MetadataConfig(url=f"sqlite:///{tmp_path}/test.db")) + Base.metadata.create_all(eng) + embedder = build_embedder( + EmbeddingConfig(provider="s", model="x", dim=1024), + {"s": ProviderConfig(kind="stub")}, + ) + return SessionStore(engine=eng, embedder=embedder) + + +def _seed_session(store: SessionStore, *, agents_run: list[AgentRun] | None = None) -> Session: + """Create + persist a baseline session and return it.""" + inc = store.create(query="probe", environment="dev", + reporter_id="u1", reporter_team="t") + if agents_run: + inc.agents_run.extend(agents_run) + store.save(inc) + inc = store.load(inc.id) + return inc + + +class TestHappyPath: + def test_failure_run_appended_and_status_set_to_error(self, store): + inc = 
_seed_session(store) + result = _handle_agent_failure( + skill_name="triage", + started_at="2026-05-15T00:00:00Z", + exc=RuntimeError("upstream blew up"), + inc_id=inc.id, + store=store, + fallback=inc, + ) + # Returned state dict + assert result["last_agent"] == "triage" + assert result["next_route"] is None + assert result["error"] == "upstream blew up" + assert isinstance(result["session"], Session) + # Persisted session reflects the failure + loaded = store.load(inc.id) + assert loaded.status == "error" + assert len(loaded.agents_run) == 1 + run = loaded.agents_run[0] + assert run.agent == "triage" + assert run.summary == "agent failed: upstream blew up" + + def test_appends_to_existing_run_history(self, store): + prior = AgentRun( + agent="intake", + started_at="2026-05-15T00:00:00Z", + ended_at="2026-05-15T00:00:01Z", + summary="completed: routed to triage", + ) + inc = _seed_session(store, agents_run=[prior]) + _handle_agent_failure( + skill_name="triage", + started_at="2026-05-15T00:00:02Z", + exc=TimeoutError("provider hung"), + inc_id=inc.id, + store=store, + fallback=inc, + ) + loaded = store.load(inc.id) + assert [r.agent for r in loaded.agents_run] == ["intake", "triage"] + assert "agent failed: provider hung" in loaded.agents_run[1].summary + + def test_preserves_partial_tool_writes_via_reload(self, store): + """If a tool wrote to the session before the agent raised, + the reload-then-append pattern must keep that tool's write.""" + inc = _seed_session(store) + # Simulate a tool write that already persisted. + from runtime.state import ToolCall + inc.tool_calls.append(ToolCall( + agent="triage", + tool="lookup_similar_incidents", + args={"query": "x"}, + result={"hits": []}, + ts="2026-05-15T00:00:00Z", + )) + store.save(inc) + # Caller's stale `fallback` reference does not have the tool call. 
+ stale = inc.model_copy(deep=True) + stale.tool_calls = [] + _handle_agent_failure( + skill_name="triage", + started_at="2026-05-15T00:00:02Z", + exc=RuntimeError("oops"), + inc_id=inc.id, + store=store, + fallback=stale, + ) + loaded = store.load(inc.id) + # Tool call survived because _handle_agent_failure reloaded + # before appending its failure run. + assert len(loaded.tool_calls) == 1 + assert loaded.tool_calls[0].tool == "lookup_similar_incidents" + + +class TestFallbackPath: + def test_uses_fallback_when_session_missing_from_store(self, store): + # Session never persisted; store.load(inc_id) raises FileNotFoundError. + from runtime.state import Session + ghost = Session( + id="INC-20260515-999", + status="in_progress", + created_at="2026-05-15T00:00:00Z", + updated_at="2026-05-15T00:00:00Z", + ) + result = _handle_agent_failure( + skill_name="intake", + started_at="2026-05-15T00:00:00Z", + exc=RuntimeError("dropped on the floor"), + inc_id="INC-20260515-999", + store=store, + fallback=ghost, + ) + # The fallback was used, the failure run was appended, + # and the now-populated fallback was saved. + assert result["session"].status == "error" + loaded = store.load("INC-20260515-999") + assert loaded.id == "INC-20260515-999" + assert loaded.status == "error" + assert len(loaded.agents_run) == 1 + assert loaded.agents_run[0].agent == "intake" diff --git a/tests/test_llm_stub_structured_output.py b/tests/test_llm_stub_structured_output.py new file mode 100644 index 0000000..eba7396 --- /dev/null +++ b/tests/test_llm_stub_structured_output.py @@ -0,0 +1,72 @@ +"""Coverage tests for ``StubChatModel.with_structured_output`` (llm.py:141-160, 171-177). + +The structured-output runnable was previously only exercised indirectly +via ``langchain.agents.create_agent``. These tests pin the direct +contract: the stub returns a Runnable-like that yields a valid schema +instance per ``invoke`` / ``ainvoke``, populated from the canned text and +``stub_envelope_*`` parameters. 
+ +The permissive ``model_validate`` fallback (lines 161-169) is genuinely +defensive: pydantic v2's ``model_validate`` internally calls ``__init__`` +too, so any schema whose constructor raises also fails the fallback. +The fallback exists for hypothetical schemas with custom +``__pydantic_validator__`` overrides, which the framework doesn't ship +and tests can't construct without monkey-patching pydantic internals. +""" +from __future__ import annotations + +import pytest + +from runtime.agents.turn_output import AgentTurnOutput +from runtime.llm import StubChatModel + + +def _stub(*, confidence: float = 0.85, rationale: str = "stub rationale", + signal: str | None = None, role: str = "intake", + canned: dict[str, str] | None = None) -> StubChatModel: + return StubChatModel( + role=role, + canned_responses=canned if canned is not None else {role: "stub body text"}, + stub_envelope_confidence=confidence, + stub_envelope_rationale=rationale, + stub_envelope_signal=signal, + ) + + +class TestStubStructuredOutputHappyPath: + """Happy path: schema(...) 
keyword constructor succeeds.""" + + def test_invoke_returns_schema_instance(self): + runnable = _stub().with_structured_output(AgentTurnOutput) + out = runnable.invoke("any input") + assert isinstance(out, AgentTurnOutput) + assert out.content == "stub body text" + assert out.confidence == 0.85 + assert out.confidence_rationale == "stub rationale" + assert out.signal is None + + @pytest.mark.asyncio + async def test_ainvoke_returns_schema_instance(self): + runnable = _stub(confidence=0.42, rationale="hedge", signal="retry").with_structured_output(AgentTurnOutput) + out = await runnable.ainvoke("any input") + assert isinstance(out, AgentTurnOutput) + assert out.confidence == 0.42 + assert out.confidence_rationale == "hedge" + assert out.signal == "retry" + + def test_canned_response_missing_uses_default_marker(self): + runnable = _stub(role="ghost", canned={}).with_structured_output(AgentTurnOutput) + out = runnable.invoke("x") + assert out.content.startswith("[stub:ghost]") + + def test_include_raw_kwarg_is_accepted(self): + # langchain passes include_raw=True/False on the call site; the stub + # accepts the kwarg but doesn't change behaviour. + runnable = _stub().with_structured_output(AgentTurnOutput, include_raw=True) + assert runnable.invoke("x").content == "stub body text" + + def test_extra_kwargs_are_swallowed(self): + runnable = _stub().with_structured_output(AgentTurnOutput, method="json_mode", strict=True) + assert runnable.invoke("x").confidence == 0.85 + + diff --git a/tests/test_orchestrator_extract_last_error.py b/tests/test_orchestrator_extract_last_error.py new file mode 100644 index 0000000..d76c52b --- /dev/null +++ b/tests/test_orchestrator_extract_last_error.py @@ -0,0 +1,129 @@ +"""Coverage tests for ``Orchestrator._extract_last_error`` (orchestrator.py:945-998). + +Pure mapping from a Session's failed-AgentRun summary string to a +representative typed exception (used by :func:`runtime.policy.should_retry`'s +``isinstance`` checks). 
Table-driven across every branch. +""" +from __future__ import annotations + +import pydantic +import pytest + +from runtime.agents.turn_output import EnvelopeMissingError +from runtime.orchestrator import Orchestrator +from runtime.state import AgentRun, Session + + +def _session_with_failures(*summaries: str) -> Session: + """Build a Session with the given run summaries (oldest first). + + Each summary is stored verbatim; an empty string stays empty (nothing is substituted). + """ + runs = [] + for i, summary in enumerate(summaries): + runs.append(AgentRun( + agent=f"agent-{i}", + started_at="2026-05-15T00:00:00Z", + ended_at="2026-05-15T00:00:01Z", + summary=summary, + )) + return Session( + id="INC-test", + status="error", + created_at="2026-05-15T00:00:00Z", + updated_at="2026-05-15T00:00:01Z", + agents_run=runs, + ) + + +class TestNoFailedRun: + def test_empty_runs_returns_none(self): + inc = _session_with_failures() + assert Orchestrator._extract_last_error(inc) is None + + def test_only_successful_runs_returns_none(self): + inc = _session_with_failures( + "completed: triage routed to investigate", + "completed: investigate found root cause", + ) + assert Orchestrator._extract_last_error(inc) is None + + def test_run_with_empty_summary_returns_none(self): + inc = _session_with_failures("") + assert Orchestrator._extract_last_error(inc) is None + + +class TestEnvelopeMissingMapping: + def test_envelope_missing_error_matched(self): + inc = _session_with_failures( + "agent failed: EnvelopeMissingError: confidence (agent=intake)", + ) + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, EnvelopeMissingError) + assert err.agent == "agent-0" + assert err.field == "confidence" + + +class TestValidationErrorMapping: + def test_capitalised_validation_error(self): + inc = _session_with_failures("agent failed: ValidationError on field foo") + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, pydantic.ValidationError) + + def 
test_lowercase_validation_error(self): + inc = _session_with_failures("agent failed: pydantic raised a validation error here") + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, pydantic.ValidationError) + + +class TestTimeoutMapping: + @pytest.mark.parametrize("body", [ + "agent failed: TimeoutError: provider hung", + "agent failed: request timed out after 30s", + "agent failed: asyncio.TimeoutError", + ]) + def test_timeout_variants_match(self, body): + inc = _session_with_failures(body) + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, TimeoutError) + + +class TestOSErrorMapping: + @pytest.mark.parametrize("body", [ + "agent failed: OSError: too many open files", + "agent failed: ConnectionError: refused", + ]) + def test_oserror_variants_match(self, body): + inc = _session_with_failures(body) + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, OSError) + + +class TestRuntimeErrorFallback: + def test_unknown_failure_returns_runtime_error(self): + inc = _session_with_failures("agent failed: KeyError: something weird") + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, RuntimeError) + + +class TestNewestFailureWins: + def test_reversed_iteration_picks_newest_failure(self): + # First failed (older) is OSError, second failed (newer) is Timeout. + # Reversed iteration should hit the timeout first and return it. + inc = _session_with_failures( + "agent failed: OSError: stale", + "agent failed: TimeoutError: fresh", + ) + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, TimeoutError) + assert "fresh" in str(err) + + def test_skips_non_failure_summaries(self): + # Only one failure in the middle of successes — should still be found. 
+ inc = _session_with_failures( + "completed: triage ok", + "agent failed: OSError: middle failure", + "completed: another success", + ) + err = Orchestrator._extract_last_error(inc) + assert isinstance(err, OSError) diff --git a/tests/test_retry_session_locked_post_policy.py b/tests/test_retry_session_locked_post_policy.py new file mode 100644 index 0000000..852771d --- /dev/null +++ b/tests/test_retry_session_locked_post_policy.py @@ -0,0 +1,182 @@ +"""Coverage tests for ``Orchestrator._retry_session_locked`` post-policy +execution path (orchestrator.py:1552-1587). + +The retry method has three early-exit branches (session-not-found, +not-in-error-state, policy-rejected) that other tests already cover. +This file exercises what happens *after* the policy accepts: the +failed-AgentRun filter, retry_count + active_thread_id pinning, the +graph re-stream, and the pause-vs-finalize fork at the tail. +""" +from __future__ import annotations + +import pytest +from sqlalchemy import create_engine + +from runtime.config import OrchestratorConfig +from runtime.locks import SessionLockRegistry +from runtime.orchestrator import Orchestrator +from runtime.state import AgentRun +from runtime.storage.models import Base +from runtime.storage.session_store import SessionStore + + +# --------------------------------------------------------------------------- +# Stub orchestrator that pulls in the real `_retry_session_locked` body but +# substitutes the surrounding integration points (graph, finalize, pause). 
+# --------------------------------------------------------------------------- +class _StubOrch: + """Minimum surface needed by ``_retry_session_locked`` lines 1515-1588.""" + + def __init__(self, store: SessionStore, *, paused: bool, finalized: str | None, + graph_events: list[dict] | None = None) -> None: + self.store = store + self.cfg = type("_Cfg", (), {"orchestrator": OrchestratorConfig()})() + self._locks = SessionLockRegistry() + self._retries_in_flight: set[str] = set() + self._paused = paused + self._finalized = finalized + self._graph_events = graph_events or [ + {"event": "on_chain_start", "name": "intake"}, + {"event": "on_chain_end", "name": "intake"}, + ] + self._streamed_state = None # captures the GraphState handed to astream + # Stub graph object + outer = self + + class _Graph: + async def astream_events(self, state, *, version, config): + outer._streamed_state = state + outer._streamed_config = config + for ev in outer._graph_events: + yield ev + + self.graph = _Graph() + + # Pull real method bodies in directly. The static helpers must be + # re-wrapped because reading them off the class strips the + # staticmethod descriptor and they would bind as instance methods. 
+ retry_session = Orchestrator.retry_session + _retry_session_locked = Orchestrator._retry_session_locked + _to_ui_event = staticmethod(Orchestrator._to_ui_event) + _extract_last_error = staticmethod(Orchestrator._extract_last_error) + _extract_last_confidence = staticmethod(Orchestrator._extract_last_confidence) + + def _thread_config(self, sid: str) -> dict: + return {"configurable": {"thread_id": sid}} + + async def _is_graph_paused(self, sid: str) -> bool: + return self._paused + + async def _finalize_session_status_async(self, sid: str) -> str | None: + return self._finalized + + +@pytest.fixture +def store(tmp_path) -> SessionStore: + eng = create_engine(f"sqlite:///{tmp_path/'t.db'}") + Base.metadata.create_all(eng) + return SessionStore(engine=eng) + + +def _seed_failed_session(store: SessionStore) -> str: + inc = store.create(query="probe", environment="dev", + reporter_id="u", reporter_team="t") + inc.status = "error" + inc.agents_run = [ + AgentRun( + agent="intake", + started_at="2026-05-15T00:00:00Z", + ended_at="2026-05-15T00:00:01Z", + summary="completed: routed", + ), + AgentRun( + agent="triage", + started_at="2026-05-15T00:00:02Z", + ended_at="2026-05-15T00:00:03Z", + summary="agent failed: TimeoutError: provider hung", + ), + ] + store.save(inc) + return inc.id + + +@pytest.mark.asyncio +async def test_retry_completes_drops_failed_runs_bumps_thread_finalizes(store): + sid = _seed_failed_session(store) + orch = _StubOrch(store, paused=False, finalized="resolved") + + events = [] + async for ev in orch.retry_session(sid): + events.append(ev) + + kinds = [e["event"] for e in events] + # Started + at least one streamed event + status_auto_finalized + completed. + assert "retry_started" in kinds + assert "status_auto_finalized" in kinds + assert kinds[-1] == "retry_completed" + assert "session_paused" not in kinds # finalize branch took it + + # Persisted state reflects all post-policy mutations. 
+ inc = store.load(sid) + # Failed AgentRun was filtered; only the successful intake run remains + # in the pre-streamed timeline (the stub graph doesn't add new runs). + assert all(not (r.summary or "").startswith("agent failed:") for r in inc.agents_run) + assert [r.agent for r in inc.agents_run] == ["intake"] + # retry_count bumped from 0 to 1. + assert inc.extra_fields.get("retry_count") == 1 + # active_thread_id pinned to retry-1. + assert inc.extra_fields.get("active_thread_id") == f"{sid}:retry-1" + # Status flipped back to in_progress when entering the retry stream. + # (Subsequent _finalize would normally update it, but the stub returns + # "resolved" without writing back to the DB — that's the orchestrator's + # _finalize_session_status_async responsibility, which we stubbed out.) + assert inc.status == "in_progress" + + # The status_auto_finalized event carries the stub's "resolved" status. + finalized_event = next(e for e in events if e["event"] == "status_auto_finalized") + assert finalized_event["status"] == "resolved" + + +@pytest.mark.asyncio +async def test_retry_pause_branch_yields_session_paused_not_finalized(store): + sid = _seed_failed_session(store) + orch = _StubOrch(store, paused=True, finalized="should-not-be-used") + + events = [] + async for ev in orch.retry_session(sid): + events.append(ev) + + kinds = [e["event"] for e in events] + assert "retry_started" in kinds + assert "session_paused" in kinds + assert "status_auto_finalized" not in kinds # pause branch took it + assert kinds[-1] == "retry_completed" + + +@pytest.mark.asyncio +async def test_retry_increments_retry_count_across_calls(store): + sid = _seed_failed_session(store) + orch = _StubOrch(store, paused=False, finalized=None) + + # First retry: 0 -> 1 + async for _ in orch.retry_session(sid): + pass + assert store.load(sid).extra_fields["retry_count"] == 1 + + # Reset the session back to error so a second retry is allowed. 
+ inc = store.load(sid) + inc.status = "error" + inc.agents_run.append(AgentRun( + agent="triage", + started_at="2026-05-15T00:00:10Z", + ended_at="2026-05-15T00:00:11Z", + summary="agent failed: TimeoutError: still broken", + )) + store.save(inc) + + # Second retry: 1 -> 2; thread id pin reflects the new count. + async for _ in orch.retry_session(sid): + pass + inc = store.load(sid) + assert inc.extra_fields["retry_count"] == 2 + assert inc.extra_fields["active_thread_id"] == f"{sid}:retry-2" diff --git a/tests/test_service_run_exception_branches.py b/tests/test_service_run_exception_branches.py new file mode 100644 index 0000000..0bad067 --- /dev/null +++ b/tests/test_service_run_exception_branches.py @@ -0,0 +1,165 @@ +"""Coverage tests for ``OrchestratorService.start_session._run`` exception +branches (service.py:541-568). + +The inner ``_run`` task supervises a single graph turn. Three exception +classes get distinct treatment: + + 1. ``asyncio.CancelledError`` — propagated as-is. + 2. ``GraphInterrupt`` — propagated WITHOUT marking + ``registry.status='error'`` (HITL pause is not a failure). + 3. Anything else — registry entry is stamped ``status='error'`` so a + concurrent snapshot observes the failure before the done-callback + evicts the entry. +""" +from __future__ import annotations + +import asyncio + +import pytest +from langgraph.errors import GraphInterrupt + +from runtime.config import ( + AppConfig, + LLMConfig, + MCPConfig, + MCPServerConfig, + MetadataConfig, + Paths, + StorageConfig, +) +from runtime.service import OrchestratorService + + +@pytest.fixture +def cfg(tmp_path): + """AppConfig wired to in-process MCP servers so the example skills + (which reference ``get_logs`` etc.) 
pass validation.""" + return AppConfig( + llm=LLMConfig.stub(), + mcp=MCPConfig(servers=[ + MCPServerConfig(name="local_inc", transport="in_process", + module="examples.incident_management.mcp_server", + category="incident_management"), + MCPServerConfig(name="local_obs", transport="in_process", + module="examples.incident_management.mcp_servers.observability", + category="observability"), + MCPServerConfig(name="local_rem", transport="in_process", + module="examples.incident_management.mcp_servers.remediation", + category="remediation"), + MCPServerConfig(name="local_user", transport="in_process", + module="examples.incident_management.mcp_servers.user_context", + category="user_context"), + ]), + storage=StorageConfig( + metadata=MetadataConfig(url=f"sqlite:///{tmp_path}/test.db") + ), + paths=Paths( + skills_dir="examples/incident_management/skills", + incidents_dir=str(tmp_path), + ), + ) + + +@pytest.fixture +def service(cfg): + """Started OrchestratorService; teardown calls shutdown().""" + svc = OrchestratorService.get_or_create(cfg) + svc.start() + try: + yield svc + finally: + svc.shutdown() + + +def _trap_graph_with(service: OrchestratorService, exc_factory): + """Replace ``orch.graph.ainvoke`` with a coroutine that captures the + in-flight registry entry then raises the supplied exception. + + Returns the captured-entries list (populated by the time the task + has run); the caller asserts on ``captured[0].status``. 
+ """ + captured: list = [] + + async def _setup_trap(): + orch = await service._ensure_orchestrator() + + async def _trapped(state, *, config): + sid = state["session"].id + entry = service._registry.get(sid) + captured.append(entry) + raise exc_factory() + + orch.graph.ainvoke = _trapped + return None + + service.submit_and_wait(_setup_trap(), timeout=10.0) + return captured + + +class TestGraphInterruptBranch: + """GraphInterrupt must NOT flip registry.status to 'error'.""" + + def test_pause_keeps_status_running(self, service): + captured = _trap_graph_with( + service, + lambda: GraphInterrupt(), + ) + sid = service.start_session(query="probe", state_overrides={"environment": "dev"}) + # Wait for the background task to finish (it will raise + # GraphInterrupt, which is caught by the Task and surfaces on + # ``await``). + async def _await_done(): + entry = service._registry.get(sid) or (captured[0] if captured else None) + if entry is not None and entry.task is not None: + with __import__("contextlib").suppress(BaseException): + await entry.task + service.submit_and_wait(_await_done(), timeout=10.0) + + assert captured, "trapped ainvoke never ran" + # Phase 11 / D-11-04: pause must NOT mark the entry as failed. 
+ assert captured[0].status == "running" + + +class TestGenericExceptionBranch: + """Non-pause exceptions must flip registry.status to 'error'.""" + + def test_generic_failure_marks_status_error(self, service): + captured = _trap_graph_with( + service, + lambda: ValueError("boom"), + ) + sid = service.start_session(query="probe", state_overrides={"environment": "dev"}) + + async def _await_done(): + entry = service._registry.get(sid) or (captured[0] if captured else None) + if entry is not None and entry.task is not None: + with __import__("contextlib").suppress(BaseException): + await entry.task + service.submit_and_wait(_await_done(), timeout=10.0) + + assert captured, "trapped ainvoke never ran" + # Generic exception → status flipped before the re-raise. + assert captured[0].status == "error" + + +class TestCancelledErrorBranch: + """CancelledError must propagate without modifying registry.status.""" + + def test_cancellation_does_not_mark_error(self, service): + captured = _trap_graph_with( + service, + lambda: asyncio.CancelledError(), + ) + sid = service.start_session(query="probe", state_overrides={"environment": "dev"}) + + async def _await_done(): + entry = service._registry.get(sid) or (captured[0] if captured else None) + if entry is not None and entry.task is not None: + with __import__("contextlib").suppress(BaseException): + await entry.task + service.submit_and_wait(_await_done(), timeout=10.0) + + assert captured, "trapped ainvoke never ran" + # CancelledError takes the early-return branch (line 541-542) — + # the generic-exception status flip never runs. + assert captured[0].status == "running" diff --git a/tests/test_sse_tail_loop.py b/tests/test_sse_tail_loop.py new file mode 100644 index 0000000..8ebd446 --- /dev/null +++ b/tests/test_sse_tail_loop.py @@ -0,0 +1,183 @@ +"""Coverage tests for the SSE ``_stream`` tail loop in +``runtime.api.build_app``. 
+ +The existing ``test_sse_events_replays_backlog`` covers the backlog +drain by forcing ``is_disconnected`` to return True before the tail +loop runs. These tests target the tail-poll branch: + + * one tail iteration delivers a newly-recorded event, then a + second is_disconnected check exits the loop; + * task cancellation surfaces as ``asyncio.CancelledError``, which + the loop must propagate (the loop previously swallowed it via + ``except CancelledError: return``; PR #13 removed the suppressing + handler so cancellation propagates by Python's default). +""" +from __future__ import annotations + +import asyncio +import json + +import pytest +from fastapi.testclient import TestClient +from starlette.requests import Request as StarletteRequest + +from runtime.api import build_app +from runtime.config import ( + AppConfig, + LLMConfig, + MCPConfig, + MCPServerConfig, + Paths, + RuntimeConfig, +) + + +@pytest.fixture +def cfg(tmp_path): + return AppConfig( + llm=LLMConfig.stub(), + mcp=MCPConfig(servers=[ + MCPServerConfig(name="local_inc", transport="in_process", + module="examples.incident_management.mcp_server", + category="incident_management"), + MCPServerConfig(name="local_obs", transport="in_process", + module="examples.incident_management.mcp_servers.observability", + category="observability"), + MCPServerConfig(name="local_rem", transport="in_process", + module="examples.incident_management.mcp_servers.remediation", + category="remediation"), + MCPServerConfig(name="local_user", transport="in_process", + module="examples.incident_management.mcp_servers.user_context", + category="user_context"), + ]), + paths=Paths( + skills_dir="examples/incident_management/skills", + incidents_dir=str(tmp_path), + ), + runtime=RuntimeConfig(state_class=None), + ) + + +def _build_request(app, sid: str, *, is_disconnected) -> StarletteRequest: + scope = { + "type": "http", "method": "GET", + "path": f"/sessions/{sid}/events", + "query_string": b"since=0", + "headers": [], 
"app": app, + } + request = StarletteRequest(scope) + request.is_disconnected = is_disconnected # type: ignore[method-assign] + return request + + +def _sse_route(app): + return next( + r for r in app.router.routes + if getattr(r, "path", "") == "/sessions/{session_id}/events" + ) + + +@pytest.mark.asyncio +async def test_tail_loop_delivers_post_drain_event(cfg, monkeypatch): + """One backlog event → drain. Then is_disconnected returns False + once (tail iterates), record a new event, second is_disconnected + returns True (exit). Verifies the tail-loop body (lines 880-888) + runs and yields the new frame. + """ + # Cut sleep duration so the test stays fast. Capture the original + # before patching so the replacement doesn't recurse into itself. + real_sleep = asyncio.sleep + + async def _instant_sleep(_seconds: float) -> None: + await real_sleep(0) + + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + app = build_app(cfg) + with TestClient(app): + orch = app.state.orchestrator + orch.event_log.record("SES-TAIL", "agent_started", agent="triage") + + disconnect_results = iter([False, True]) + + async def _is_disconnected() -> bool: + try: + v = next(disconnect_results) + except StopIteration: + return True + # First "False" return: record a NEW event so the next + # iter_for(since=last_seq) call inside the tail body + # yields something. This proves the tail loop ran. 
+ if v is False: + orch.event_log.record( + "SES-TAIL", "tool_invoked", + tool="post-drain", agent="triage", + ) + return v + + request = _build_request(app, "SES-TAIL", is_disconnected=_is_disconnected) + response = await _sse_route(app).endpoint( + session_id="SES-TAIL", request=request, since=0, + ) + + frames: list[dict] = [] + async for chunk in response.body_iterator: + text = chunk.decode() if isinstance(chunk, bytes) else chunk + for line in text.splitlines(): + if line.startswith("data: "): + frames.append(json.loads(line[len("data: "):])) + + kinds = [f["kind"] for f in frames] + # The backlog drain delivered the first event; the tail loop + # delivered the post-drain event added inside _is_disconnected. + assert "agent_started" in kinds + assert "tool_invoked" in kinds, ( + f"tail loop never yielded a post-drain frame; got kinds={kinds}" + ) + + +@pytest.mark.asyncio +async def test_tail_loop_propagates_cancellation(cfg, monkeypatch): + """Cancelling the SSE generator must raise ``CancelledError`` out + of the ``_stream`` body — not be swallowed. + + This pins the PR #13 bug fix: the original code swallowed + ``CancelledError`` via ``except: return``; the post-fix loop has + no ``except`` wrapper at all, so cancellation propagates by + Python's default. Previously this test would observe + ``StopAsyncIteration`` instead of ``CancelledError``. + """ + # Force the sleep inside the tail loop to raise CancelledError on + # first call so we can observe propagation without racing on real + # cancellation timing. 
+ async def _cancel_immediately(*_a, **_k): + raise asyncio.CancelledError() + + monkeypatch.setattr("asyncio.sleep", _cancel_immediately) + + app = build_app(cfg) + with TestClient(app): + orch = app.state.orchestrator + orch.event_log.record("SES-CANCEL", "agent_started", agent="triage") + + async def _never_disconnect() -> bool: + return False + + request = _build_request(app, "SES-CANCEL", is_disconnected=_never_disconnect) + response = await _sse_route(app).endpoint( + session_id="SES-CANCEL", request=request, since=0, + ) + + frames: list[dict] = [] + with pytest.raises(asyncio.CancelledError): + async for chunk in response.body_iterator: + text = chunk.decode() if isinstance(chunk, bytes) else chunk + for line in text.splitlines(): + if line.startswith("data: "): + frames.append(json.loads(line[len("data: "):])) + + # The backlog drain delivered the seeded event before the tail + # loop's first sleep raised — proves CancelledError propagated + # naturally out of the generator (not swallowed). + assert any(f["kind"] == "agent_started" for f in frames)