From 2535bf41dc804902b20946bb9411860b855d7473 Mon Sep 17 00:00:00 2001 From: David Hyrule Date: Mon, 29 Jun 2026 13:07:59 +0200 Subject: [PATCH] feat(collector): minimal FastAPI trace collector (agent_core.collector) POST /v1/trace + /v1/trace/batch (validated by the TraceEvent contract), GET /v1/trace, /healthz; async SQLAlchemy store (sqlite for tests, Postgres via HYRULE_COLLECTOR_DATABASE_URL in prod). Optional 'collector' extra so core contracts stay dependency-light; run with 'python -m agent_core.collector'. ruff + mypy clean; 36 tests pass. Co-Authored-By: Claude Opus 4.8 --- agent_core/collector/__init__.py | 7 +++ agent_core/collector/__main__.py | 21 ++++++++ agent_core/collector/app.py | 89 +++++++++++++++++++++++++++++++ agent_core/collector/db.py | 68 +++++++++++++++++++++++ pyproject.toml | 6 ++- tests/collector/test_collector.py | 47 ++++++++++++++++ 6 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 agent_core/collector/__init__.py create mode 100644 agent_core/collector/__main__.py create mode 100644 agent_core/collector/app.py create mode 100644 agent_core/collector/db.py create mode 100644 tests/collector/test_collector.py diff --git a/agent_core/collector/__init__.py b/agent_core/collector/__init__.py new file mode 100644 index 0000000..0b3515a --- /dev/null +++ b/agent_core/collector/__init__.py @@ -0,0 +1,7 @@ +"""Optional FastAPI trace collector. Install the ``collector`` extra to use it.""" + +from __future__ import annotations + +from agent_core.collector.app import create_app + +__all__ = ["create_app"] diff --git a/agent_core/collector/__main__.py b/agent_core/collector/__main__.py new file mode 100644 index 0000000..1f1fa53 --- /dev/null +++ b/agent_core/collector/__main__.py @@ -0,0 +1,21 @@ +"""Run the collector: ``python -m agent_core.collector``.""" + +from __future__ import annotations + +import os + +import uvicorn + +from agent_core.collector.app import create_app + + +def main() -> None: + uvicorn.run( + create_app(), + host=os.environ.get("HYRULE_COLLECTOR_BIND", "127.0.0.1"), + port=int(os.environ.get("HYRULE_COLLECTOR_PORT", "8770")), + ) + + +if __name__ == "__main__": + main() diff --git a/agent_core/collector/app.py b/agent_core/collector/app.py new file mode 100644 index 0000000..6776b05 --- /dev/null +++ b/agent_core/collector/app.py @@ -0,0 +1,89 @@ +"""Minimal FastAPI trace collector. + +Accepts agent-core ``TraceEvent`` payloads (validated by the shared contract) and stores +them in Postgres/sqlite. Part of the optional ``collector`` extra. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +from fastapi import FastAPI +from sqlalchemy import select + +from agent_core.collector.db import ( + TraceEventRow, + init_models, + make_engine, + make_sessionmaker, +) +from agent_core.contracts._base import utcnow +from agent_core.contracts.tracing import TraceEvent + + +def _row_from_event(event: TraceEvent) -> TraceEventRow: + cost = event.cost + return TraceEventRow( + event_id=event.event_id, + event_type=event.event_type, + run_id=event.run_id, + trace_id=event.trace_id, + graph_id=event.graph_id, + graph_version=event.graph_version, + node_id=event.node_id, + agent_role=event.agent_role, + environment=event.environment, + summary=event.summary, + model=cost.model if cost else None, + provider=cost.provider if cost else None, + cost_usd=cost.usd if cost else None, + input_tokens=cost.input_tokens if cost else None, + output_tokens=cost.output_tokens if cost else None, + timestamp=event.timestamp, + received_at=utcnow(), + event=event.model_dump(mode="json"), + ) + + +def create_app(database_url: str | None = None) -> FastAPI: + engine = make_engine(database_url) + sessionmaker = make_sessionmaker(engine) + + @asynccontextmanager + async def lifespan(_: FastAPI) -> AsyncIterator[None]: + await init_models(engine) + yield + await engine.dispose() + + app = FastAPI(title="agent-core trace collector", version="0.2.0", lifespan=lifespan) + + @app.get("/healthz") + async def healthz() -> dict[str, str]: + return {"status": "ok"} + + @app.post("/v1/trace") + async def ingest(event: TraceEvent) -> dict[str, str]: + async with sessionmaker() as session: + session.add(_row_from_event(event)) + await session.commit() + return {"status": "stored", "event_id": event.event_id} + + @app.post("/v1/trace/batch") + async def ingest_batch(events: list[TraceEvent]) -> dict[str, int]: + async with sessionmaker() as session: + session.add_all([_row_from_event(event) for event in events]) + await session.commit() + return {"stored": len(events)} + + @app.get("/v1/trace") + async def recent(run_id: str | None = None, limit: int = 50) -> list[dict[str, Any]]: + stmt = select(TraceEventRow).order_by(TraceEventRow.id.desc()).limit(min(limit, 500)) + if run_id: + stmt = stmt.where(TraceEventRow.run_id == run_id) + async with sessionmaker() as session: + rows = (await session.execute(stmt)).scalars().all() + return [row.event for row in rows] + + return app diff --git a/agent_core/collector/db.py b/agent_core/collector/db.py new file mode 100644 index 0000000..6255f25 --- /dev/null +++ b/agent_core/collector/db.py @@ -0,0 +1,68 @@ +"""Async SQLAlchemy store for collected TraceEvents. + +Part of the optional ``collector`` extra (not imported by ``agent_core.contracts``). +Defaults to a local sqlite file; set ``HYRULE_COLLECTOR_DATABASE_URL`` to a Postgres +async URL (``postgresql+asyncpg://...``) in production. +""" + +from __future__ import annotations + +import os +from datetime import datetime +from typing import Any + +from sqlalchemy import JSON, DateTime, Float, Integer, String +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +DEFAULT_DATABASE_URL = "sqlite+aiosqlite:///./collector.db" + + +def database_url() -> str: + return os.environ.get("HYRULE_COLLECTOR_DATABASE_URL", DEFAULT_DATABASE_URL) + + +class Base(DeclarativeBase): + pass + + +class TraceEventRow(Base): + __tablename__ = "trace_events" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + event_id: Mapped[str] = mapped_column(String(80), index=True) + event_type: Mapped[str] = mapped_column(String(64), index=True) + run_id: Mapped[str | None] = mapped_column(String(128), index=True, default=None) + trace_id: Mapped[str | None] = mapped_column(String(128), index=True, default=None) + graph_id: Mapped[str | None] = mapped_column(String(128), index=True, default=None) + graph_version: Mapped[str | None] = mapped_column(String(128), default=None) + node_id: Mapped[str | None] = mapped_column(String(128), default=None) + agent_role: Mapped[str | None] = mapped_column(String(128), default=None) + environment: Mapped[str | None] = mapped_column(String(64), default=None) + summary: Mapped[str] = mapped_column(String, default="") + model: Mapped[str | None] = mapped_column(String(128), default=None) + provider: Mapped[str | None] = mapped_column(String(64), default=None) + cost_usd: Mapped[float | None] = mapped_column(Float, default=None) + input_tokens: Mapped[int | None] = mapped_column(Integer, default=None) + output_tokens: Mapped[int | None] = mapped_column(Integer, default=None) + timestamp: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), default=None) + received_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + event: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict) + + +def make_engine(url: str | None = None) -> AsyncEngine: + return create_async_engine(url or database_url(), future=True) + + +def make_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: + return async_sessionmaker(engine, expire_on_commit=False) + + +async def init_models(engine: AsyncEngine) -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) diff --git a/pyproject.toml b/pyproject.toml index da05b2d..0e4fd56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,11 @@ requires-python = ">=3.11" dependencies = ["pydantic>=2,<3"] [project.optional-dependencies] -dev = ["pytest>=8", "pyyaml>=6", "ruff>=0.5", "mypy>=1.8", "types-PyYAML"] +collector = ["fastapi>=0.110", "uvicorn>=0.29", "sqlalchemy>=2", "asyncpg>=0.29"] +dev = [ + "pytest>=8", "pyyaml>=6", "ruff>=0.5", "mypy>=1.8", "types-PyYAML", + "fastapi>=0.110", "uvicorn>=0.29", "sqlalchemy>=2", "httpx>=0.27", "aiosqlite>=0.20", +] [tool.hatch.build.targets.wheel] packages = ["agent_core"] diff --git a/tests/collector/test_collector.py b/tests/collector/test_collector.py new file mode 100644 index 0000000..6b4aacc --- /dev/null +++ b/tests/collector/test_collector.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from fastapi.testclient import TestClient + +from agent_core.collector.app import create_app + + +def _app(tmp_path): + return create_app(f"sqlite+aiosqlite:///{tmp_path}/collector.db") + + +def test_healthz(tmp_path) -> None: + with TestClient(_app(tmp_path)) as client: + assert client.get("/healthz").json() == {"status": "ok"} + + +def test_ingest_and_read(tmp_path) -> None: + event = { + "event_type": "model_call", + "summary": "hello", + "run_id": "r1", + "cost": {"usd": 0.02, "input_tokens": 10, "output_tokens": 5}, + } + with TestClient(_app(tmp_path)) as client: + resp = client.post("/v1/trace", json=event) + assert resp.status_code == 200 + assert resp.json()["status"] == "stored" + got = client.get("/v1/trace", params={"run_id": "r1"}).json() + assert len(got) == 1 + assert got[0]["event_type"] == "model_call" + assert got[0]["cost"]["usd"] == 0.02 + + +def test_ingest_rejects_invalid(tmp_path) -> None: + with TestClient(_app(tmp_path)) as client: + # missing required event_type + assert client.post("/v1/trace", json={"summary": "x"}).status_code == 422 + + +def test_batch(tmp_path) -> None: + events = [ + {"event_type": "tool_call", "run_id": "r2"}, + {"event_type": "node_end", "run_id": "r2"}, + ] + with TestClient(_app(tmp_path)) as client: + assert client.post("/v1/trace/batch", json=events).json() == {"stored": 2} + assert len(client.get("/v1/trace", params={"run_id": "r2"}).json()) == 2