diff --git a/backend-java/src/main/java/io/agentflow/api/config/AgentflowProperties.java b/backend-java/src/main/java/io/agentflow/api/config/AgentflowProperties.java index 2908ad5..fa29083 100644 --- a/backend-java/src/main/java/io/agentflow/api/config/AgentflowProperties.java +++ b/backend-java/src/main/java/io/agentflow/api/config/AgentflowProperties.java @@ -135,6 +135,8 @@ public void setTtlSeconds(long ttlSeconds) { public static class Events { private String channelPrefix = "agentflow:run:"; + private String streamSuffix = ":log"; + private int streamMaxLen = 10_000; private long sseHeartbeatSeconds = 15; public String getChannelPrefix() { @@ -145,6 +147,22 @@ public void setChannelPrefix(String channelPrefix) { this.channelPrefix = channelPrefix; } + public String getStreamSuffix() { + return streamSuffix; + } + + public void setStreamSuffix(String streamSuffix) { + this.streamSuffix = streamSuffix; + } + + public int getStreamMaxLen() { + return streamMaxLen; + } + + public void setStreamMaxLen(int streamMaxLen) { + this.streamMaxLen = streamMaxLen; + } + public long getSseHeartbeatSeconds() { return sseHeartbeatSeconds; } diff --git a/backend-java/src/main/java/io/agentflow/api/controller/EventsController.java b/backend-java/src/main/java/io/agentflow/api/controller/EventsController.java index 8bbb5ce..98799f7 100644 --- a/backend-java/src/main/java/io/agentflow/api/controller/EventsController.java +++ b/backend-java/src/main/java/io/agentflow/api/controller/EventsController.java @@ -4,7 +4,9 @@ import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -19,7 +21,12 @@ public EventsController(EventStreamService events) { } @GetMapping(value = "/{runId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public SseEmitter stream(@PathVariable String runId) { - return events.subscribe(runId); + public SseEmitter stream( + @PathVariable String runId, + @RequestHeader(value = "Last-Event-ID", required = false) String lastEventId, + @RequestParam(value = "last_event_id", required = false) String lastEventIdParam) { + String after = + lastEventId != null && !lastEventId.isBlank() ? lastEventId : lastEventIdParam; + return events.subscribe(runId, after); } } diff --git a/backend-java/src/main/java/io/agentflow/api/service/EventStreamService.java b/backend-java/src/main/java/io/agentflow/api/service/EventStreamService.java index 1498908..2b86ec9 100644 --- a/backend-java/src/main/java/io/agentflow/api/service/EventStreamService.java +++ b/backend-java/src/main/java/io/agentflow/api/service/EventStreamService.java @@ -1,10 +1,11 @@ package io.agentflow.api.service; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import io.agentflow.api.config.AgentflowProperties; import io.agentflow.api.dto.RunEvent; import java.io.IOException; -import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -12,9 +13,12 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.data.redis.connection.MessageListener; +import org.springframework.data.domain.Range; +import org.springframework.data.redis.connection.stream.MapRecord; +import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.listener.ChannelTopic; import org.springframework.data.redis.listener.RedisMessageListenerContainer; import org.springframework.stereotype.Service; @@ -22,9 +26,11 @@ /** * Bridges the worker's Redis pub/sub stream to per-run SSE emitters. Each - * subscriber gets its own {@link MessageListener} on the channel; lifecycle - * is tied to the {@link SseEmitter} (completion/timeout/error all remove - * the listener). + * subscriber gets its own {@link org.springframework.data.redis.connection.MessageListener} + * on the channel; lifecycle is tied to the {@link SseEmitter}. + * + *

Events are also persisted in a Redis Stream per run so clients can + * reconnect with {@code Last-Event-ID} and receive any missed frames. */ @Service public class EventStreamService { @@ -34,6 +40,7 @@ public class EventStreamService { Set.of("run.completed", "run.failed", "run.cancelled"); private final RedisMessageListenerContainer listenerContainer; + private final StringRedisTemplate redis; private final ObjectMapper mapper; private final AgentflowProperties props; private final ScheduledExecutorService heartbeatExecutor = @@ -43,30 +50,55 @@ public class EventStreamService { t.setDaemon(true); return t; }); - private final Map active = new ConcurrentHashMap<>(); + private final Map active = + new ConcurrentHashMap<>(); public EventStreamService( RedisMessageListenerContainer listenerContainer, + StringRedisTemplate redis, ObjectMapper mapper, AgentflowProperties props) { this.listenerContainer = listenerContainer; + this.redis = redis; this.mapper = mapper; this.props = props; } public SseEmitter subscribe(String runId) { + return subscribe(runId, null); + } + + public SseEmitter subscribe(String runId, String afterEventId) { SseEmitter emitter = new SseEmitter(0L); ChannelTopic topic = new ChannelTopic(props.getEvents().getChannelPrefix() + runId); + String streamKey = streamKey(runId); + AtomicReference lastSentId = new AtomicReference<>(afterEventId); + + try { + if (!replay(emitter, streamKey, afterEventId, lastSentId)) { + return emitter; + } + } catch (Exception e) { + log.warn("Failed to replay SSE events for run {}", runId, e); + emitter.completeWithError(e); + return emitter; + } - MessageListener listener = (message, pattern) -> { + org.springframework.data.redis.connection.MessageListener listener = (message, pattern) -> { try { - String body = new String(message.getBody()); - RunEvent event = mapper.readValue(body, RunEvent.class); - emitter.send( - SseEmitter.event() - .name(event.type()) - .data(body)); - if (TERMINAL_TYPES.contains(event.type())) { + DeliveredEvent delivered = parseEnvelope(new String(message.getBody())); + if (delivered.eventId() != null + && lastSentId.get() != null + && !isAfter(delivered.eventId(), lastSentId.get())) { + return; + } + if (!send(emitter, delivered)) { + return; + } + if (delivered.eventId() != null) { + lastSentId.set(delivered.eventId()); + } + if (TERMINAL_TYPES.contains(delivered.event().type())) { emitter.complete(); } } catch (IllegalStateException ignored) { @@ -82,6 +114,12 @@ public SseEmitter subscribe(String runId) { listenerContainer.addMessageListener(listener, topic); active.put(emitter, listener); + try { + replay(emitter, streamKey, lastSentId.get(), lastSentId); + } catch (Exception e) { + log.warn("Failed catch-up replay for run {}", runId, e); + } + long heartbeatSeconds = Math.max(1, props.getEvents().getSseHeartbeatSeconds()); ScheduledFuture heartbeat = heartbeatExecutor.scheduleAtFixedRate( () -> { @@ -97,7 +135,7 @@ public SseEmitter subscribe(String runId) { Runnable cleanup = () -> { heartbeat.cancel(true); - MessageListener removed = active.remove(emitter); + org.springframework.data.redis.connection.MessageListener removed = active.remove(emitter); if (removed != null) { listenerContainer.removeMessageListener(removed, topic); } @@ -108,4 +146,76 @@ public SseEmitter subscribe(String runId) { return emitter; } + + private String streamKey(String runId) { + return props.getEvents().getChannelPrefix() + runId + props.getEvents().getStreamSuffix(); + } + + /** @return {@code false} when a terminal event was replayed. */ + private boolean replay( + SseEmitter emitter, + String streamKey, + String afterEventId, + AtomicReference lastSentId) + throws IOException { + Range range = + afterEventId == null || afterEventId.isBlank() + ? Range.unbounded() + : Range.of(Range.Bound.exclusive(afterEventId), Range.Bound.unbounded()); + + List> records = + redis.opsForStream().range(streamKey, range); + if (records == null) { + return true; + } + + for (MapRecord record : records) { + Object payload = record.getValue().get("payload"); + if (payload == null) { + continue; + } + RunEvent event = mapper.readValue(payload.toString(), RunEvent.class); + DeliveredEvent delivered = new DeliveredEvent(record.getId().getValue(), event); + if (!send(emitter, delivered)) { + return false; + } + lastSentId.set(delivered.eventId()); + if (TERMINAL_TYPES.contains(event.type())) { + emitter.complete(); + return false; + } + } + return true; + } + + private boolean send(SseEmitter emitter, DeliveredEvent delivered) throws IOException { + String body = mapper.writeValueAsString(delivered.event()); + SseEmitter.SseEventBuilder builder = + SseEmitter.event().name(delivered.event().type()).data(body); + if (delivered.eventId() != null) { + builder.id(delivered.eventId()); + } + emitter.send(builder); + return true; + } + + private DeliveredEvent parseEnvelope(String body) throws IOException { + JsonNode root = mapper.readTree(body); + if (root.has("id") && root.has("event")) { + String eventId = root.get("id").asText(); + RunEvent event = mapper.treeToValue(root.get("event"), RunEvent.class); + return new DeliveredEvent(eventId, event); + } + return new DeliveredEvent(null, mapper.readValue(body, RunEvent.class)); + } + + private static boolean isAfter(String candidate, String lastId) { + try { + return Long.parseLong(candidate) > Long.parseLong(lastId); + } catch (NumberFormatException ex) { + return candidate.compareTo(lastId) > 0; + } + } + + private record DeliveredEvent(String eventId, RunEvent event) {} } diff --git a/backend-java/src/main/resources/application.yml b/backend-java/src/main/resources/application.yml index 00c2484..76a051e 100644 --- a/backend-java/src/main/resources/application.yml +++ b/backend-java/src/main/resources/application.yml @@ -36,6 +36,8 @@ agentflow: ttl-seconds: 86400 events: channel-prefix: "agentflow:run:" + stream-suffix: ":log" + stream-max-len: 10000 sse-heartbeat-seconds: 15 otel: enabled: ${AGENTFLOW_OTEL_ENABLED:false} diff --git a/backend/app/api/v1/events.py b/backend/app/api/v1/events.py index b9a3b22..d717185 100644 --- a/backend/app/api/v1/events.py +++ b/backend/app/api/v1/events.py @@ -3,6 +3,10 @@ Clients subscribe with `GET /v1/events/{run_id}`. The stream stays open until the run reaches a terminal state (`succeeded`, `failed`, `cancelled`) or the client disconnects. + +On reconnect, clients may send the standard `Last-Event-ID` header (or the +`last_event_id` query parameter) to receive any events published while +disconnected before resuming the live stream. """ from __future__ import annotations @@ -14,6 +18,7 @@ from sse_starlette.sse import EventSourceResponse from app.events import EventBus, get_event_bus +from app.events.bus import _is_after from app.schemas.run import RunEvent router = APIRouter(prefix="/events", tags=["events"]) @@ -21,28 +26,68 @@ _TERMINAL_TYPES = {"run.completed", "run.failed", "run.cancelled"} +def _resolve_last_event_id(request: Request) -> str | None: + header = request.headers.get("last-event-id") + if header: + return header.strip() + query = request.query_params.get("last_event_id") + if query: + return query.strip() + return None + + +def _sse_frame(event_id: str, event: RunEvent) -> dict[str, str]: + return { + "id": event_id, + "event": event.type, + "data": event.model_dump_json(), + } + + @router.get("/{run_id}") async def stream_run_events( run_id: str, request: Request, bus: EventBus = Depends(get_event_bus), ) -> EventSourceResponse: + after_id = _resolve_last_event_id(request) + async def generator() -> AsyncIterator[dict[str, str]]: + last_id = after_id + + async for event_id, event in bus.replay(run_id, after_id): + if await request.is_disconnected(): + return + yield _sse_frame(event_id, event) + last_id = event_id + if event.type in _TERMINAL_TYPES: + return + async with bus.subscribe(run_id) as queue: + async for event_id, event in bus.replay(run_id, last_id): + if await request.is_disconnected(): + return + yield _sse_frame(event_id, event) + last_id = event_id + if event.type in _TERMINAL_TYPES: + return + while True: if await request.is_disconnected(): break try: - event: RunEvent = await asyncio.wait_for(queue.get(), timeout=15.0) + event_id, event = await asyncio.wait_for(queue.get(), timeout=15.0) except TimeoutError: yield {"event": "ping", "data": "{}"} continue - yield { - "event": event.type, - "data": event.model_dump_json(), - } + if event_id and last_id and not _is_after(event_id, last_id): + continue + if event_id: + last_id = event_id + + yield _sse_frame(event_id or "0", event) if event.type in _TERMINAL_TYPES: break diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 61ec6ac..1c0a175 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -23,6 +23,19 @@ class Settings(BaseSettings): default=None, description="Optional redis URL. When unset, the in-memory event bus is used.", ) + event_channel_prefix: str = Field( + default="agentflow:run:", + description="Redis pub/sub channel prefix for live run events.", + ) + event_stream_suffix: str = Field( + default=":log", + description="Suffix appended to the channel prefix for the replay stream key.", + ) + event_stream_max_len: int = Field( + default=10_000, + ge=100, + description="Approximate max entries kept in each run's Redis event stream.", + ) log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO" default_adapter: str = Field( diff --git a/backend/app/events/bus.py b/backend/app/events/bus.py index 429597a..dd48e6d 100644 --- a/backend/app/events/bus.py +++ b/backend/app/events/bus.py @@ -1,11 +1,11 @@ -"""Per-run pub/sub event bus. +"""Per-run pub/sub event bus with a durable replay log. The bus has two implementations selected at runtime: -* `InMemoryEventBus` – per-process asyncio queues. Suitable for single-node - development and for tests. -* `RedisEventBus` – Redis pub/sub channels keyed by `agentflow:run:{id}`. - Enabled automatically when `AGENTFLOW_REDIS_URL` is configured. +* `InMemoryEventBus` – per-process asyncio queues plus an in-process event + log. Suitable for single-node development and for tests. +* `RedisEventBus` – Redis pub/sub for live delivery and a Redis Stream per + run for `Last-Event-ID` replay. Enabled when `AGENTFLOW_REDIS_URL` is set. Both implementations expose the same async interface and are interchangeable. """ @@ -25,30 +25,59 @@ logger = get_logger("events") +EventRecord = tuple[str, RunEvent] + + +def _is_after(entry_id: str, after_id: str | None) -> bool: + if after_id is None: + return True + try: + return int(entry_id) > int(after_id) + except ValueError: + return entry_id > after_id + class EventBus(Protocol): - async def publish(self, event: RunEvent) -> None: ... + async def publish(self, event: RunEvent) -> str: ... + + def replay( + self, run_id: str, after_id: str | None = None + ) -> AsyncIterator[EventRecord]: ... @asynccontextmanager - def subscribe(self, run_id: str) -> AsyncIterator[asyncio.Queue[RunEvent]]: ... + def subscribe(self, run_id: str) -> AsyncIterator[asyncio.Queue[EventRecord]]: ... async def aclose(self) -> None: ... class InMemoryEventBus: def __init__(self) -> None: - self._subscribers: dict[str, set[asyncio.Queue[RunEvent]]] = defaultdict(set) + self._subscribers: dict[str, set[asyncio.Queue[EventRecord]]] = defaultdict(set) + self._log: dict[str, list[EventRecord]] = defaultdict(list) + self._seq: dict[str, int] = defaultdict(int) self._lock = asyncio.Lock() - async def publish(self, event: RunEvent) -> None: + async def publish(self, event: RunEvent) -> str: async with self._lock: + self._seq[event.run_id] += 1 + event_id = str(self._seq[event.run_id]) + record = (event_id, event) + self._log[event.run_id].append(record) queues = list(self._subscribers.get(event.run_id, ())) for queue in queues: - await queue.put(event) + await queue.put(record) + return event_id + + async def replay( + self, run_id: str, after_id: str | None = None + ) -> AsyncIterator[EventRecord]: + for event_id, event in self._log.get(run_id, ()): + if _is_after(event_id, after_id): + yield event_id, event @asynccontextmanager - async def subscribe(self, run_id: str) -> AsyncIterator[asyncio.Queue[RunEvent]]: - queue: asyncio.Queue[RunEvent] = asyncio.Queue() + async def subscribe(self, run_id: str) -> AsyncIterator[asyncio.Queue[EventRecord]]: + queue: asyncio.Queue[EventRecord] = asyncio.Queue() async with self._lock: self._subscribers[run_id].add(queue) try: @@ -64,31 +93,51 @@ async def aclose(self) -> None: # pragma: no cover - nothing to do class RedisEventBus: - """Redis-backed pub/sub. - - Imported lazily to keep the in-memory path free of redis dependency for - unit tests. - """ + """Redis-backed pub/sub with a per-run stream for replay.""" def __init__(self, url: str) -> None: import redis.asyncio as redis # local import self._redis = redis.from_url(url, decode_responses=True) - - @staticmethod - def _channel(run_id: str) -> str: - return f"agentflow:run:{run_id}" - - async def publish(self, event: RunEvent) -> None: - await self._redis.publish( - self._channel(event.run_id), - event.model_dump_json(), + settings = get_settings() + self._channel_prefix = settings.event_channel_prefix + self._stream_suffix = settings.event_stream_suffix + self._stream_max_len = settings.event_stream_max_len + + def _channel(self, run_id: str) -> str: + return f"{self._channel_prefix}{run_id}" + + def _stream(self, run_id: str) -> str: + return f"{self._channel_prefix}{run_id}{self._stream_suffix}" + + async def publish(self, event: RunEvent) -> str: + stream = self._stream(event.run_id) + event_id = await self._redis.xadd( + stream, + {"payload": event.model_dump_json()}, + maxlen=self._stream_max_len, + approximate=True, ) + envelope = json.dumps({"id": event_id, "event": event.model_dump(mode="json")}) + await self._redis.publish(self._channel(event.run_id), envelope) + return event_id + + async def replay( + self, run_id: str, after_id: str | None = None + ) -> AsyncIterator[EventRecord]: + stream = self._stream(run_id) + start = f"({after_id}" if after_id else "-" + entries = await self._redis.xrange(stream, min=start, max="+") + for entry_id, fields in entries: + try: + yield entry_id, RunEvent(**json.loads(fields["payload"])) + except Exception: # pragma: no cover - defensive + logger.exception("event_replay_decode_failed", run_id=run_id) @asynccontextmanager - async def subscribe(self, run_id: str) -> AsyncIterator[asyncio.Queue[RunEvent]]: + async def subscribe(self, run_id: str) -> AsyncIterator[asyncio.Queue[EventRecord]]: pubsub = self._redis.pubsub() - queue: asyncio.Queue[RunEvent] = asyncio.Queue() + queue: asyncio.Queue[EventRecord] = asyncio.Queue() await pubsub.subscribe(self._channel(run_id)) async def reader() -> None: @@ -97,7 +146,13 @@ async def reader() -> None: continue try: payload = json.loads(message["data"]) - await queue.put(RunEvent(**payload)) + if "id" in payload and "event" in payload: + event_id = payload["id"] + event = RunEvent(**payload["event"]) + else: + event = RunEvent(**payload) + event_id = "" + await queue.put((event_id, event)) except Exception: # pragma: no cover - defensive logger.exception("event_decode_failed") diff --git a/backend/tests/test_event_replay.py b/backend/tests/test_event_replay.py new file mode 100644 index 0000000..6157d8e --- /dev/null +++ b/backend/tests/test_event_replay.py @@ -0,0 +1,186 @@ +"""Tests for SSE event replay via Last-Event-ID.""" + +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime + +import fakeredis.aioredis as fakeredis +import pytest +from httpx import ASGITransport, AsyncClient + +from app.events.bus import InMemoryEventBus, RedisEventBus +from app.schemas.run import RunEvent + + +def _event(event_type: str, run_id: str, **data: object) -> RunEvent: + return RunEvent( + type=event_type, + run_id=run_id, + at=datetime.now(UTC), + data=dict(data), + ) + + +async def _collect_sse( + client: AsyncClient, + path: str, + *, + headers: dict[str, str] | None = None, + max_frames: int = 20, +) -> list[dict[str, str | None]]: + frames: list[dict[str, str | None]] = [] + async with client.stream( + "GET", + path, + headers=headers or {}, + timeout=5.0, + ) as response: + assert response.status_code == 200 + event_name: str | None = None + event_id: str | None = None + data_lines: list[str] = [] + + async for line in response.aiter_lines(): + if line == "": + if event_name is not None or data_lines: + frames.append( + { + "id": event_id, + "event": event_name, + "data": "\n".join(data_lines) if data_lines else None, + } + ) + event_name = None + event_id = None + data_lines = [] + if len(frames) >= max_frames: + break + continue + + if line.startswith(":"): + continue + if line.startswith("id:"): + event_id = line[3:].strip() + elif line.startswith("event:"): + event_name = line[6:].strip() + elif line.startswith("data:"): + data_lines.append(line[5:].strip()) + + return frames + + +@pytest.mark.asyncio +async def test_sse_replay_after_last_event_id(monkeypatch): + import app.events.bus as bus_module + + bus = InMemoryEventBus() + monkeypatch.setattr(bus_module, "_bus", bus) + + from app.main import app + + run_id = "run-replay-1" + await bus.publish(_event("run.created", run_id)) + await bus.publish(_event("run.started", run_id)) + await bus.publish(_event("step.started", run_id, index=0, node="plan")) + await bus.publish(_event("step.completed", run_id, index=0)) + await bus.publish(_event("run.completed", run_id, output={"reply": "ok"})) + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with app.router.lifespan_context(app): + replay = await _collect_sse( + client, + f"/v1/events/{run_id}", + headers={"Last-Event-ID": "1"}, + max_frames=10, + ) + + event_names = [frame["event"] for frame in replay if frame["event"] != "ping"] + assert event_names == [ + "run.started", + "step.started", + "step.completed", + "run.completed", + ] + assert all(frame["id"] for frame in replay if frame["event"] != "ping") + + +@pytest.mark.asyncio +async def test_sse_replay_query_param(monkeypatch): + import app.events.bus as bus_module + + bus = InMemoryEventBus() + monkeypatch.setattr(bus_module, "_bus", bus) + + from app.main import app + + run_id = "run-replay-2" + await bus.publish(_event("run.created", run_id)) + await bus.publish(_event("run.started", run_id)) + await bus.publish(_event("run.completed", run_id, output={})) + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with app.router.lifespan_context(app): + second = await _collect_sse( + client, + f"/v1/events/{run_id}?last_event_id=1", + max_frames=5, + ) + + event_names = [frame["event"] for frame in second if frame["event"] != "ping"] + assert event_names == ["run.started", "run.completed"] + + +@pytest.mark.asyncio +async def test_sse_live_stream_includes_event_ids(monkeypatch): + import app.events.bus as bus_module + + bus = InMemoryEventBus() + monkeypatch.setattr(bus_module, "_bus", bus) + + from app.main import app + + run_id = "run-live-ids" + + async def publish_run() -> None: + await asyncio.sleep(0.05) + await bus.publish(_event("run.created", run_id)) + await bus.publish(_event("run.completed", run_id, output={})) + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + async with app.router.lifespan_context(app): + publisher = asyncio.create_task(publish_run()) + frames = await _collect_sse( + client, + f"/v1/events/{run_id}", + max_frames=5, + ) + await publisher + + ids = [frame["id"] for frame in frames if frame["event"] != "ping"] + assert ids == ["1", "2"] + + +@pytest.mark.asyncio +async def test_redis_event_bus_replay_with_fakeredis(): + redis = fakeredis.FakeRedis(decode_responses=True) + bus = RedisEventBus.__new__(RedisEventBus) + bus._redis = redis + bus._channel_prefix = "agentflow:run:" + bus._stream_suffix = ":log" + bus._stream_max_len = 1000 + + run_id = "run-redis-1" + id1 = await bus.publish(_event("run.created", run_id)) + id2 = await bus.publish(_event("run.started", run_id)) + id3 = await bus.publish(_event("run.completed", run_id, output={})) + + replayed = [event_id async for event_id, _ in bus.replay(run_id, id1)] + assert replayed == [id2, id3] + assert id1 not in replayed + + full = [event.type async for _, event in bus.replay(run_id, None)] + assert full == ["run.created", "run.started", "run.completed"] diff --git a/docs/api-contract.md b/docs/api-contract.md index 454bd85..d21bd6b 100644 --- a/docs/api-contract.md +++ b/docs/api-contract.md @@ -133,6 +133,11 @@ reaches a terminal state (`run.completed`, `run.failed`, `run.cancelled`) or the client disconnects. The server sends `event: ping` heartbeats roughly every 15 seconds. +On reconnect, send the standard `Last-Event-ID` header or `?last_event_id=` +query parameter to replay missed events before resuming the live stream. Each +persisted frame includes an SSE `id` field; the log is stored in Redis Stream +`agentflow:run:{run_id}:log` (or in-memory when Redis is unset). + Each SSE frame uses the event type as the SSE `event` field and the JSON payload below as `data`: diff --git a/frontend/lib/run-event-connection.ts b/frontend/lib/run-event-connection.ts index c551132..00b4001 100644 --- a/frontend/lib/run-event-connection.ts +++ b/frontend/lib/run-event-connection.ts @@ -13,10 +13,11 @@ export type RunEventConnectionStatus = export interface StoredRunEvent extends RunEvent { clientSeq: number; + eventId?: string; } export interface RunEventConnectionHandlers { - onEvent?: (event: RunEvent, clientSeq: number) => void; + onEvent?: (event: RunEvent, clientSeq: number, eventId?: string) => void; onTerminal?: (event: RunEvent) => void; onStatusChange?: (status: RunEventConnectionStatus) => void; } @@ -24,6 +25,13 @@ export interface RunEventConnectionHandlers { const INITIAL_RECONNECT_MS = 1_000; const MAX_RECONNECT_MS = 30_000; +function streamUrl(runId: string, lastEventId: string | null): string { + const base = eventStreamUrl(runId); + if (!lastEventId) return base; + const params = new URLSearchParams({ last_event_id: lastEventId }); + return `${base}?${params.toString()}`; +} + export function createRunEventConnection( runId: string, handlers: RunEventConnectionHandlers, @@ -34,6 +42,7 @@ export function createRunEventConnection( let reconnectAttempt = 0; let terminal = false; let clientSeq = 0; + let lastEventId: string | null = null; const setStatus = (status: RunEventConnectionStatus) => { if (!cancelled) handlers.onStatusChange?.(status); @@ -58,7 +67,7 @@ export function createRunEventConnection( clearReconnectTimer(); closeSource(); setStatus("closed"); - onTerminalRef.current?.(event); + handlers.onTerminal?.(event); }; const scheduleReconnect = () => { @@ -77,10 +86,13 @@ export function createRunEventConnection( const onMessage = (event: MessageEvent) => { if (cancelled || terminal) return; + if (event.lastEventId) { + lastEventId = event.lastEventId; + } try { const data: RunEvent = JSON.parse(event.data); clientSeq += 1; - onEventRef.current?.(data, clientSeq); + handlers.onEvent?.(data, clientSeq, event.lastEventId || undefined); if (isTerminalRunEvent(data.type)) { finishTerminal(data); } @@ -95,7 +107,7 @@ export function createRunEventConnection( closeSource(); setStatus(reconnectAttempt === 0 ? "connecting" : "reconnecting"); - const next = new EventSource(eventStreamUrl(runId)); + const next = new EventSource(streamUrl(runId, lastEventId)); source = next; next.addEventListener("open", () => { diff --git a/frontend/lib/useRunEventSource.ts b/frontend/lib/useRunEventSource.ts index e94cbf7..f326655 100644 --- a/frontend/lib/useRunEventSource.ts +++ b/frontend/lib/useRunEventSource.ts @@ -47,8 +47,8 @@ export function useRunEventSource({ const connection = createRunEventConnection(runId, { onStatusChange: setStatus, - onEvent: (event, clientSeq) => { - setEvents((prev) => [...prev, { ...event, clientSeq }]); + onEvent: (event, clientSeq, eventId) => { + setEvents((prev) => [...prev, { ...event, clientSeq, eventId }]); onEventRef.current?.(event); }, onTerminal: (event) => {