Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/lorebinders/agent/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ async def run_agent_async(
f"Agent run completed with model {model}",
meta,
)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
try:
usage = res.usage()
emit_observation(
on_observe,
ObservationType.METRIC,
"agent",
f"Token usage for model {model}",
{
"model": model,
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"total_tokens": usage.input_tokens + usage.output_tokens,
},
)
except Exception as e:
logger.warning(f"Failed to collect token usage metrics: {e}")
return res.output
except Exception as e:
logger.error(f"Agent run failed: {e}")
Expand Down
52 changes: 50 additions & 2 deletions tests/unit/agents/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import pytest
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.test import TestModel

from lorebinders.agent.factory import _is_moderation_error, create_agent
from lorebinders.models import AgentDeps, ExtractionResult
from lorebinders.agent.factory import (
_is_moderation_error,
create_agent,
create_extraction_agent,
load_prompt_from_assets,
run_agent_async,
)
from lorebinders.models import (
AgentDeps,
ExtractionResult,
ObservationEvent,
ObservationType,
)
from lorebinders.settings import get_settings


def test_is_moderation_error_true() -> None:
Expand Down Expand Up @@ -40,3 +53,38 @@ def test_create_agent_with_fallback_wraps_in_fallback_model() -> None:
fallback=fallback,
)
assert isinstance(agent.model, FallbackModel)


@pytest.mark.anyio
async def test_run_agent_async_emits_metric_event() -> None:
    """Test that run_agent_async emits a METRIC event with token counts.

    The run must produce exactly one METRIC observation whose metadata
    names the model and carries integer, non-negative, self-consistent
    input/output/total token counts.
    """
    observations: list[ObservationEvent] = []

    def on_observe(event: ObservationEvent) -> None:
        observations.append(event)

    agent = create_extraction_agent()
    # Replace the configured model with pydantic-ai's TestModel so the run
    # is served by the test double instead of the agent's real model.
    agent.model = TestModel()

    deps = AgentDeps(
        settings=get_settings(),
        prompt_loader=load_prompt_from_assets,
    )

    await run_agent_async(
        agent, "test prompt", deps=deps, on_observe=on_observe
    )

    metric_events = [
        o for o in observations if o.type == ObservationType.METRIC
    ]
    assert len(metric_events) == 1
    meta = metric_events[0].metadata

    # The payload identifies which model the usage belongs to.
    assert "model" in meta

    for key in ("input_tokens", "output_tokens", "total_tokens"):
        assert isinstance(meta[key], int)
        # Negative counts would indicate a broken usage calculation.
        assert meta[key] >= 0

    expected_total = meta["input_tokens"] + meta["output_tokens"]
    assert meta["total_tokens"] == expected_total
Comment on lines +82 to +89
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen assertions on token values by checking for non-negative counts.

Currently the test only checks that token counts are integers and that total_tokens equals the sum of input and output tokens. Please also assert that input_tokens, output_tokens, and total_tokens are >= 0 so tests fail if negative token values are ever emitted.

assert "model" in meta