Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ public void setTtlSeconds(long ttlSeconds) {

public static class Events {
private String channelPrefix = "agentflow:run:";
private String streamSuffix = ":log";
private int streamMaxLen = 10_000;
private long sseHeartbeatSeconds = 15;

public String getChannelPrefix() {
Expand All @@ -145,6 +147,22 @@ public void setChannelPrefix(String channelPrefix) {
this.channelPrefix = channelPrefix;
}

public String getStreamSuffix() {
return streamSuffix;
}

public void setStreamSuffix(String streamSuffix) {
this.streamSuffix = streamSuffix;
}

public int getStreamMaxLen() {
return streamMaxLen;
}

public void setStreamMaxLen(int streamMaxLen) {
this.streamMaxLen = streamMaxLen;
}

public long getSseHeartbeatSeconds() {
return sseHeartbeatSeconds;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

Expand All @@ -19,7 +21,12 @@ public EventsController(EventStreamService events) {
}

@GetMapping(value = "/{runId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter stream(@PathVariable String runId) {
return events.subscribe(runId);
public SseEmitter stream(
@PathVariable String runId,
@RequestHeader(value = "Last-Event-ID", required = false) String lastEventId,
@RequestParam(value = "last_event_id", required = false) String lastEventIdParam) {
String after =
lastEventId != null && !lastEventId.isBlank() ? lastEventId : lastEventIdParam;
return events.subscribe(runId, after);
}
}
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
package io.agentflow.api.service;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.agentflow.api.config.AgentflowProperties;
import io.agentflow.api.dto.RunEvent;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.domain.Range;
import org.springframework.data.redis.connection.stream.MapRecord;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
* Bridges the worker's Redis pub/sub stream to per-run SSE emitters. Each
* subscriber gets its own {@link MessageListener} on the channel; lifecycle
* is tied to the {@link SseEmitter} (completion/timeout/error all remove
* the listener).
* subscriber gets its own {@link org.springframework.data.redis.connection.MessageListener}
* on the channel; lifecycle is tied to the {@link SseEmitter}.
*
* <p>Events are also persisted in a Redis Stream per run so clients can
* reconnect with {@code Last-Event-ID} and receive any missed frames.
*/
@Service
public class EventStreamService {
Expand All @@ -34,6 +40,7 @@ public class EventStreamService {
Set.of("run.completed", "run.failed", "run.cancelled");

private final RedisMessageListenerContainer listenerContainer;
private final StringRedisTemplate redis;
private final ObjectMapper mapper;
private final AgentflowProperties props;
private final ScheduledExecutorService heartbeatExecutor =
Expand All @@ -43,30 +50,55 @@ public class EventStreamService {
t.setDaemon(true);
return t;
});
private final Map<SseEmitter, MessageListener> active = new ConcurrentHashMap<>();
private final Map<SseEmitter, org.springframework.data.redis.connection.MessageListener> active =
new ConcurrentHashMap<>();

public EventStreamService(
RedisMessageListenerContainer listenerContainer,
StringRedisTemplate redis,
ObjectMapper mapper,
AgentflowProperties props) {
this.listenerContainer = listenerContainer;
this.redis = redis;
this.mapper = mapper;
this.props = props;
}

public SseEmitter subscribe(String runId) {
return subscribe(runId, null);
}

public SseEmitter subscribe(String runId, String afterEventId) {
SseEmitter emitter = new SseEmitter(0L);
ChannelTopic topic = new ChannelTopic(props.getEvents().getChannelPrefix() + runId);
String streamKey = streamKey(runId);
AtomicReference<String> lastSentId = new AtomicReference<>(afterEventId);

try {
if (!replay(emitter, streamKey, afterEventId, lastSentId)) {
return emitter;
}
} catch (Exception e) {
log.warn("Failed to replay SSE events for run {}", runId, e);
emitter.completeWithError(e);
return emitter;
}

MessageListener listener = (message, pattern) -> {
org.springframework.data.redis.connection.MessageListener listener = (message, pattern) -> {
try {
String body = new String(message.getBody());
RunEvent event = mapper.readValue(body, RunEvent.class);
emitter.send(
SseEmitter.event()
.name(event.type())
.data(body));
if (TERMINAL_TYPES.contains(event.type())) {
DeliveredEvent delivered = parseEnvelope(new String(message.getBody()));
if (delivered.eventId() != null
&& lastSentId.get() != null
&& !isAfter(delivered.eventId(), lastSentId.get())) {
return;
}
if (!send(emitter, delivered)) {
return;
}
if (delivered.eventId() != null) {
lastSentId.set(delivered.eventId());
}
if (TERMINAL_TYPES.contains(delivered.event().type())) {
emitter.complete();
}
} catch (IllegalStateException ignored) {
Expand All @@ -82,6 +114,12 @@ public SseEmitter subscribe(String runId) {
listenerContainer.addMessageListener(listener, topic);
active.put(emitter, listener);

try {
replay(emitter, streamKey, lastSentId.get(), lastSentId);
} catch (Exception e) {
log.warn("Failed catch-up replay for run {}", runId, e);
}

long heartbeatSeconds = Math.max(1, props.getEvents().getSseHeartbeatSeconds());
ScheduledFuture<?> heartbeat = heartbeatExecutor.scheduleAtFixedRate(
() -> {
Expand All @@ -97,7 +135,7 @@ public SseEmitter subscribe(String runId) {

Runnable cleanup = () -> {
heartbeat.cancel(true);
MessageListener removed = active.remove(emitter);
org.springframework.data.redis.connection.MessageListener removed = active.remove(emitter);
if (removed != null) {
listenerContainer.removeMessageListener(removed, topic);
}
Expand All @@ -108,4 +146,76 @@ public SseEmitter subscribe(String runId) {

return emitter;
}

private String streamKey(String runId) {
return props.getEvents().getChannelPrefix() + runId + props.getEvents().getStreamSuffix();
}

/** @return {@code false} when a terminal event was replayed. */
private boolean replay(
SseEmitter emitter,
String streamKey,
String afterEventId,
AtomicReference<String> lastSentId)
throws IOException {
Range<String> range =
afterEventId == null || afterEventId.isBlank()
? Range.unbounded()
: Range.of(Range.Bound.exclusive(afterEventId), Range.Bound.unbounded());

List<MapRecord<String, Object, Object>> records =
redis.opsForStream().range(streamKey, range);
if (records == null) {
return true;
}

for (MapRecord<String, Object, Object> record : records) {
Object payload = record.getValue().get("payload");
if (payload == null) {
continue;
}
RunEvent event = mapper.readValue(payload.toString(), RunEvent.class);
DeliveredEvent delivered = new DeliveredEvent(record.getId().getValue(), event);
if (!send(emitter, delivered)) {
return false;
}
lastSentId.set(delivered.eventId());
if (TERMINAL_TYPES.contains(event.type())) {
emitter.complete();
return false;
}
}
return true;
}

private boolean send(SseEmitter emitter, DeliveredEvent delivered) throws IOException {
String body = mapper.writeValueAsString(delivered.event());
SseEmitter.SseEventBuilder builder =
SseEmitter.event().name(delivered.event().type()).data(body);
if (delivered.eventId() != null) {
builder.id(delivered.eventId());
}
emitter.send(builder);
return true;
}

private DeliveredEvent parseEnvelope(String body) throws IOException {
JsonNode root = mapper.readTree(body);
if (root.has("id") && root.has("event")) {
String eventId = root.get("id").asText();
RunEvent event = mapper.treeToValue(root.get("event"), RunEvent.class);
return new DeliveredEvent(eventId, event);
}
return new DeliveredEvent(null, mapper.readValue(body, RunEvent.class));
}

private static boolean isAfter(String candidate, String lastId) {
try {
return Long.parseLong(candidate) > Long.parseLong(lastId);
} catch (NumberFormatException ex) {
return candidate.compareTo(lastId) > 0;
}
}

private record DeliveredEvent(String eventId, RunEvent event) {}
}
2 changes: 2 additions & 0 deletions backend-java/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ agentflow:
ttl-seconds: 86400
events:
channel-prefix: "agentflow:run:"
stream-suffix: ":log"
stream-max-len: 10000
sse-heartbeat-seconds: 15
otel:
enabled: ${AGENTFLOW_OTEL_ENABLED:false}
Expand Down
55 changes: 50 additions & 5 deletions backend/app/api/v1/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Clients subscribe with `GET /v1/events/{run_id}`. The stream stays open
until the run reaches a terminal state (`succeeded`, `failed`, `cancelled`)
or the client disconnects.

On reconnect, clients may send the standard `Last-Event-ID` header (or the
`last_event_id` query parameter) to receive any events published while
disconnected before resuming the live stream.
"""

from __future__ import annotations
Expand All @@ -14,35 +18,76 @@
from sse_starlette.sse import EventSourceResponse

from app.events import EventBus, get_event_bus
from app.events.bus import _is_after
from app.schemas.run import RunEvent

router = APIRouter(prefix="/events", tags=["events"])

_TERMINAL_TYPES = {"run.completed", "run.failed", "run.cancelled"}


def _resolve_last_event_id(request: Request) -> str | None:
header = request.headers.get("last-event-id")
if header:
return header.strip()
query = request.query_params.get("last_event_id")
if query:
return query.strip()
return None


def _sse_frame(event_id: str, event: RunEvent) -> dict[str, str]:
return {
"id": event_id,
"event": event.type,
"data": event.model_dump_json(),
}


@router.get("/{run_id}")
async def stream_run_events(
run_id: str,
request: Request,
bus: EventBus = Depends(get_event_bus),
) -> EventSourceResponse:
after_id = _resolve_last_event_id(request)

async def generator() -> AsyncIterator[dict[str, str]]:
last_id = after_id

async for event_id, event in bus.replay(run_id, after_id):
if await request.is_disconnected():
return
yield _sse_frame(event_id, event)
last_id = event_id
if event.type in _TERMINAL_TYPES:
return

async with bus.subscribe(run_id) as queue:
async for event_id, event in bus.replay(run_id, last_id):
if await request.is_disconnected():
return
yield _sse_frame(event_id, event)
last_id = event_id
if event.type in _TERMINAL_TYPES:
return

while True:
if await request.is_disconnected():
break

try:
event: RunEvent = await asyncio.wait_for(queue.get(), timeout=15.0)
event_id, event = await asyncio.wait_for(queue.get(), timeout=15.0)
except TimeoutError:
yield {"event": "ping", "data": "{}"}
continue

yield {
"event": event.type,
"data": event.model_dump_json(),
}
if event_id and last_id and not _is_after(event_id, last_id):
continue
if event_id:
last_id = event_id

yield _sse_frame(event_id or "0", event)

if event.type in _TERMINAL_TYPES:
break
Expand Down
13 changes: 13 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ class Settings(BaseSettings):
default=None,
description="Optional redis URL. When unset, the in-memory event bus is used.",
)
event_channel_prefix: str = Field(
default="agentflow:run:",
description="Redis pub/sub channel prefix for live run events.",
)
event_stream_suffix: str = Field(
default=":log",
description="Suffix appended to the channel prefix for the replay stream key.",
)
event_stream_max_len: int = Field(
default=10_000,
ge=100,
description="Approximate max entries kept in each run's Redis event stream.",
)
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"

default_adapter: str = Field(
Expand Down
Loading
Loading