@pytest.mark.anyio
async def test_run_agent_async_emits_metric_event() -> None:
    """Check that run_agent_async reports token usage via one METRIC event.

    The agent is run once with an observer callback that records every
    event. Exactly one recorded event must have type METRIC, and its
    metadata must hold non-negative integer token counts whose total is
    the sum of input and output, together with a "model" entry.
    """
    captured: list[ObservationEvent] = []

    def record(event: ObservationEvent) -> None:
        # Collect every observation so the METRIC ones can be picked out.
        captured.append(event)

    extraction_agent = create_extraction_agent()
    # Replace the agent's model with pydantic_ai's TestModel for this run.
    extraction_agent.model = TestModel()

    agent_deps = AgentDeps(
        settings=get_settings(),
        prompt_loader=load_prompt_from_assets,
    )

    await run_agent_async(
        extraction_agent,
        "test prompt",
        deps=agent_deps,
        on_observe=record,
    )

    metrics = [evt for evt in captured if evt.type == ObservationType.METRIC]
    assert len(metrics) == 1

    payload = metrics[0].metadata
    token_keys = ("input_tokens", "output_tokens", "total_tokens")

    # Every count must be an int ...
    for key in token_keys:
        assert isinstance(payload[key], int)
    # ... and none may be negative.
    for key in token_keys:
        assert payload[key] >= 0

    expected_total = payload["input_tokens"] + payload["output_tokens"]
    assert payload["total_tokens"] == expected_total
    assert "model" in payload