Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions backend/app/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,39 @@ async def emit_step_failed(
"step.failed", {"index": index, "node": node, "error": error, **data}
)

async def emit_step_updated(
self,
*,
index: int,
tokens_in: int | None = None,
tokens_out: int | None = None,
latency_ms: int | None = None,
**data: Any,
) -> None:
"""Flush deferred step metrics (tokens, latency) without completing the step."""
payload: dict[str, Any] = {"index": index, **data}
if tokens_in is not None:
payload["tokens_in"] = tokens_in
if tokens_out is not None:
payload["tokens_out"] = tokens_out
if latency_ms is not None:
payload["latency_ms"] = latency_ms
await self.emit("step.updated", payload)

async def emit_token_delta(
    self,
    *,
    step_index: int,
    delta: str,
    role: str = "assistant",
    **data: Any,
) -> None:
    """Emit a ``token.delta`` event for one incremental chunk of model output.

    The payload holds ``step_index``, ``delta`` and ``role`` followed by any
    extra keyword data. Intended for live SSE streaming only (no DB write).
    """
    event_body: dict[str, Any] = {
        "step_index": step_index,
        "delta": delta,
        "role": role,
    }
    event_body.update(data)
    await self.emit("token.delta", event_body)

async def emit_message(
self, *, role: str, content: str, name: str | None = None, **extra: Any
) -> None:
Expand Down
177 changes: 164 additions & 13 deletions backend/app/adapters/langgraph_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
{
"model": "openai/gpt-4o-mini",
"system_prompt": "You are a helpful coordinator.",
"stream_tokens": true, // emit token.delta SSE; defer tokens to step.updated
"tools": ["echo"], // tool registry keys
"graph": {
"nodes": [
Expand All @@ -35,6 +36,8 @@

from __future__ import annotations

import json
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
Expand Down Expand Up @@ -300,16 +303,35 @@ async def handler(state: dict[str, Any]) -> dict[str, Any]:
await ctx.emit_message(role="system", content=system_prompt)
await ctx.emit_message(role="user", content=user_input)

reply = await self._invoke_model(
stream_tokens = _coerce_bool(
run_state.config.get("stream_tokens"), default=True
)
started = time.monotonic()
reply, tokens_in, tokens_out = await self._invoke_model(
ctx,
step_idx,
model,
system_prompt,
user_input,
tool_keys=tool_keys if tool_keys else None,
stream_tokens=stream_tokens,
)
latency_ms = int((time.monotonic() - started) * 1000)

await ctx.emit_step_updated(
index=step_idx,
tokens_in=tokens_in,
tokens_out=tokens_out,
latency_ms=latency_ms,
)
await ctx.emit_message(role="assistant", content=reply)
await ctx.emit_step_completed(
index=step_idx, node=spec.id, output={"reply": reply}
index=step_idx,
node=spec.id,
output={"reply": reply},
tokens_in=tokens_in,
tokens_out=tokens_out,
latency_ms=latency_ms,
)
messages = list(state.get("messages") or [])
messages.extend(
Expand All @@ -325,24 +347,31 @@ async def handler(state: dict[str, Any]) -> dict[str, Any]:

async def _invoke_model(
self,
ctx: AdapterContext,
step_index: int,
model: str,
system_prompt: str,
user_input: str,
*,
tool_keys: list[str] | None = None,
) -> str:
"""Call the configured chat model.
stream_tokens: bool = True,
) -> tuple[str, int, int]:
"""Call the configured chat model, optionally streaming token deltas.

The MVP routes everything through an OpenAI-compatible Chat
Completions endpoint. When ``tool_keys`` are set, tool schemas are
attached so the model may request function calls (single round-trip).
Returns ``(reply, tokens_in, tokens_out)``. Token counts are taken from
provider usage when available; otherwise estimated from text length.
"""
settings = get_settings()
if not settings.openai_api_key:
suffix = ""
if tool_keys:
suffix = f" [tools={','.join(tool_keys)}]"
return f"[mock:{model}]{suffix} {user_input}"
return await self._invoke_mock(
ctx,
step_index,
model,
system_prompt,
user_input,
tool_keys=tool_keys,
stream_tokens=stream_tokens,
)

import httpx

Expand All @@ -356,6 +385,11 @@ async def _invoke_model(
if tool_keys:
payload["tools"] = tool_schemas(tool_keys)

if stream_tokens and not tool_keys:
return await self._invoke_openai_streaming(
ctx, step_index, settings, payload
)

async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{settings.openai_base_url}/chat/completions",
Expand All @@ -368,5 +402,122 @@ async def _invoke_model(
tool_calls = message.get("tool_calls")
if tool_calls:
names = [tc["function"]["name"] for tc in tool_calls]
return f"[tool_calls:{','.join(names)}]"
return message.get("content") or ""
reply = f"[tool_calls:{','.join(names)}]"
else:
reply = message.get("content") or ""
usage = data.get("usage") or {}
tokens_in = int(usage.get("prompt_tokens") or _estimate_tokens(system_prompt + user_input))
tokens_out = int(usage.get("completion_tokens") or _estimate_tokens(reply))
return reply, tokens_in, tokens_out

async def _invoke_mock(
    self,
    ctx: AdapterContext,
    step_index: int,
    model: str,
    system_prompt: str,
    user_input: str,
    *,
    tool_keys: list[str] | None = None,
    stream_tokens: bool = True,
) -> tuple[str, int, int]:
    """Produce a deterministic fake model reply without any network call.

    The reply echoes the model name, the requested tool keys (if any) and
    the user input. When ``stream_tokens`` is true the reply is also sent
    to ``ctx`` as a sequence of ``token.delta`` chunks. Returns
    ``(reply, tokens_in, tokens_out)`` with both counts estimated from
    text length.
    """
    tool_note = f" [tools={','.join(tool_keys)}]" if tool_keys else ""
    mock_reply = f"[mock:{model}]{tool_note} {user_input}"

    # Estimate usage up front, before any chunk is emitted.
    prompt_estimate = _estimate_tokens(system_prompt + user_input)
    reply_estimate = _estimate_tokens(mock_reply)

    if not stream_tokens:
        return mock_reply, prompt_estimate, reply_estimate

    for piece in _chunk_text(mock_reply):
        await ctx.emit_token_delta(step_index=step_index, delta=piece)
    return mock_reply, prompt_estimate, reply_estimate

async def _invoke_openai_streaming(
    self,
    ctx: AdapterContext,
    step_index: int,
    settings: Any,
    payload: dict[str, Any],
) -> tuple[str, int, int]:
    """Stream a Chat Completions response, forwarding each content delta.

    Sends ``payload`` with ``stream`` enabled and ``include_usage``
    requested, then reads the server-sent event stream line by line. Every
    non-empty ``delta.content`` is appended to the reply and forwarded via
    ``ctx.emit_token_delta``.

    Args:
        ctx: Adapter context that receives the ``token.delta`` events.
        step_index: Index of the step the streamed tokens belong to.
        settings: Object exposing ``openai_base_url`` and ``openai_api_key``.
        payload: Chat Completions request body; it is copied, not mutated.

    Returns:
        ``(reply, tokens_in, tokens_out)``. Counts come from the stream's
        ``usage`` object when one is present; otherwise they are estimated
        from text length.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status.
    """
    import httpx

    payload = {
        **payload,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    parts: list[str] = []
    tokens_in = 0
    tokens_out = 0
    # The SSE format makes the space after the colon optional, so match the
    # bare field name and strip whitespace from the value afterwards.
    data_prefix = "data:"

    async with httpx.AsyncClient(timeout=60.0) as client:
        async with client.stream(
            "POST",
            f"{settings.openai_base_url}/chat/completions",
            headers={"Authorization": f"Bearer {settings.openai_api_key}"},
            json=payload,
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith(data_prefix):
                    continue
                data_str = line[len(data_prefix):].strip()
                if data_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    # Best effort: skip malformed frames, keep streaming.
                    continue
                if not isinstance(chunk, dict):
                    # Valid JSON that is not an object carries no delta.
                    continue
                usage = chunk.get("usage")
                if usage:
                    tokens_in = int(usage.get("prompt_tokens") or tokens_in)
                    tokens_out = int(
                        usage.get("completion_tokens") or tokens_out
                    )
                for choice in chunk.get("choices") or []:
                    delta = choice.get("delta") or {}
                    content = delta.get("content")
                    if content:
                        parts.append(content)
                        await ctx.emit_token_delta(
                            step_index=step_index, delta=content
                        )

    reply = "".join(parts)
    if not tokens_in:
        # Estimate from the prompt text itself, not the repr of the message
        # list, so dict keys and punctuation are not counted as tokens.
        # Assumes messages are dicts with a string "content" — TODO confirm
        # against the payload builder; other shapes fall back to str().
        messages = payload.get("messages") or []
        prompt_text = "".join(
            message["content"]
            for message in messages
            if isinstance(message, dict)
            and isinstance(message.get("content"), str)
        )
        if not prompt_text and messages:
            prompt_text = str(messages)
        tokens_in = _estimate_tokens(prompt_text)
    if not tokens_out:
        tokens_out = _estimate_tokens(reply)
    return reply, tokens_in, tokens_out


def _estimate_tokens(text: str) -> int:
"""Rough token estimate when the provider omits usage (≈4 chars/token)."""
if not text:
return 0
return max(1, len(text) // 4)


def _coerce_bool(value: Any, *, default: bool) -> bool:
"""Parse boolean-ish config values without treating arbitrary strings as truthy."""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return default
if isinstance(value, int):
return bool(value)
return default


def _chunk_text(text: str, *, size: int = 8) -> list[str]:
"""Split text into small chunks for mock streaming."""
return [text[i : i + size] for i in range(0, len(text), size)] or [text]
2 changes: 2 additions & 0 deletions backend/app/schemas/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ class RunRead(BaseModel):
"run.failed",
"run.cancelled",
"step.started",
"step.updated",
"step.completed",
"step.failed",
"token.delta",
"message.created",
"tool_call.started",
"tool_call.completed",
Expand Down
14 changes: 14 additions & 0 deletions backend/app/services/run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,16 @@ async def _handle_event(
)
self.session.add(step)
await self.session.commit()
elif event_type == "step.updated":
step = await self._find_step(run_id, data["index"])
if step is not None:
if "latency_ms" in data:
step.latency_ms = data["latency_ms"]
if "tokens_in" in data:
step.tokens_in = data["tokens_in"]
if "tokens_out" in data:
step.tokens_out = data["tokens_out"]
await self.session.commit()
elif event_type == "step.completed":
step = await self._find_step(run_id, data["index"])
if step is not None:
Expand All @@ -238,6 +248,10 @@ async def _handle_event(
step.tokens_in = data.get("tokens_in")
step.tokens_out = data.get("tokens_out")
await self.session.commit()
elif event_type == "token.delta":
# SSE-only: avoid per-chunk DB commits during streaming.
await self._broadcast(event_type, run_id, data)
return
elif event_type == "step.failed":
step = await self._find_step(run_id, data["index"])
if step is not None:
Expand Down
67 changes: 66 additions & 1 deletion backend/tests/test_langgraph_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from app.adapters.base import AdapterContext, AdapterResult
from app.adapters.base import AdapterContext
from app.adapters.langgraph_adapter import GraphSpec, LangGraphAdapter
from app.adapters.tool_registry import get_tool, list_tools, register_tool, resolve_tools
from app.models.run import RunStatus
Expand Down Expand Up @@ -125,3 +125,68 @@ async def test_langgraph_unknown_tool_fails():
result = await adapter.run(ctx)
assert result.status == RunStatus.FAILED
assert "Unknown tool" in (result.error or "")


@pytest.mark.asyncio
async def test_langgraph_streams_token_deltas_and_defers_step_tokens():
    """Streaming emits deltas that rebuild the reply, plus one step.updated."""
    ctx = _RecordingContext()
    ctx.agent_config = {"model": "openai/gpt-4o-mini", "stream_tokens": True}

    result = await LangGraphAdapter().run(ctx)
    assert result.status == RunStatus.SUCCEEDED

    def payloads_for(event_name):
        # Payloads of every recorded event with the given name, in order.
        return [data for event, data in ctx.events if event == event_name]

    # Token deltas: all belong to step 0 and concatenate to the final reply.
    deltas = payloads_for("token.delta")
    assert len(deltas) > 0
    assert all(item["step_index"] == 0 for item in deltas)
    expected_reply = (result.output or {}).get("reply", "")
    assert "".join(item["delta"] for item in deltas) == expected_reply

    # Exactly one metrics flush, with populated counters.
    updates = payloads_for("step.updated")
    assert len(updates) == 1
    flushed = updates[0]
    assert flushed["tokens_in"] > 0
    assert flushed["tokens_out"] > 0
    assert flushed["latency_ms"] >= 0

    # The completion event for the model node repeats the same token counts.
    model_completions = [
        data
        for data in payloads_for("step.completed")
        if data["node"] == "call_model"
    ]
    first_completion = model_completions[0]
    assert first_completion["tokens_in"] == flushed["tokens_in"]
    assert first_completion["tokens_out"] == flushed["tokens_out"]


@pytest.mark.asyncio
async def test_langgraph_stream_tokens_disabled():
    """With ``stream_tokens`` off, no token.delta event is recorded."""
    ctx = _RecordingContext()
    ctx.agent_config = {"model": "openai/gpt-4o-mini", "stream_tokens": False}

    await LangGraphAdapter().run(ctx)

    recorded_names = [event for event, _ in ctx.events]
    assert "token.delta" not in recorded_names


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("raw_value", "should_stream"),
    [("false", False), ("true", True)],
)
async def test_langgraph_stream_tokens_string_values(
    raw_value: str, should_stream: bool
):
    """String spellings of ``stream_tokens`` are honoured like booleans."""
    ctx = _RecordingContext()
    ctx.agent_config = {
        "model": "openai/gpt-4o-mini",
        "stream_tokens": raw_value,
    }

    await LangGraphAdapter().run(ctx)

    streamed = any(event == "token.delta" for event, _ in ctx.events)
    assert streamed is should_stream
Loading