diff --git a/src/praisonai-agents/praisonaiagents/gateway/protocols.py b/src/praisonai-agents/praisonaiagents/gateway/protocols.py index b44f0f226..1e04e274f 100644 --- a/src/praisonai-agents/praisonaiagents/gateway/protocols.py +++ b/src/praisonai-agents/praisonaiagents/gateway/protocols.py @@ -100,6 +100,19 @@ class GatewayEvent: timestamp: Event creation time source: Source identifier (agent_id, client_id, etc.) target: Target identifier (optional, for directed events) + + Wire Protocol Extensions: + When events are sent over the gateway, additional fields are added: + - seq: Top-level monotonic sequence number for gap detection + - cursor: Event cursor position (also stored in data['cursor']) + + Resume Protocol: + The 'joined' acknowledgment includes: + - cursor: Current head cursor position + - oldest_cursor: Oldest event still in buffer + - resync_required: True if requested 'since' is below oldest_cursor + + When resync_required=true, a 'snapshot' message follows with full state. """ type: Union[EventType, str] diff --git a/src/praisonai/praisonai/gateway/server.py b/src/praisonai/praisonai/gateway/server.py index 28c551725..07a79b67b 100644 --- a/src/praisonai/praisonai/gateway/server.py +++ b/src/praisonai/praisonai/gateway/server.py @@ -112,10 +112,45 @@ def add_event(self, event: GatewayEvent) -> int: self._events = self._events[-self._max_messages:] return self._event_cursor + def get_oldest_cursor(self) -> int: + """Get the oldest event cursor still retained in the buffer. + + When the buffer is empty, returns the current cursor position, + which correctly indicates that any cursor < _event_cursor would + require resync (since no events are retained). + """ + if self._events: + return self._events[0].data.get('cursor', self._event_cursor) + return self._event_cursor + def get_events_since(self, cursor: int) -> List[GatewayEvent]: """Get events since the given cursor.""" return [e for e in self._events if e.data.get('cursor', 0) > cursor] + def check_resync_required(self, since_cursor: Optional[int]) -> bool: + """Check if resync is required based on the requested cursor.""" + if since_cursor is None: + return False + oldest_cursor = self.get_oldest_cursor() + return since_cursor < oldest_cursor + + def get_snapshot(self) -> Dict[str, Any]: + """Get a snapshot of the current session state for resync.""" + return { + "session_id": self._session_id, + "agent_id": self._agent_id, + "state": dict(self._state), + "messages": [{ + "content": msg.content, + "sender_id": msg.sender_id, + "session_id": msg.session_id, + "message_id": msg.message_id, + "timestamp": msg.timestamp, + "metadata": msg.metadata, + } for msg in self._messages], + "event_cursor": self._event_cursor, + } + def to_dict(self) -> Dict[str, Any]: """Serialize session to dictionary for persistence.""" return { @@ -956,7 +991,18 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> if agent_id and agent_id in self._agents: # Support reconnection with existing session session_id = data.get("session_id") # Optional: existing session to resume - since_cursor = data.get("since") # Optional: cursor for event replay + # Parse and validate the since parameter + since_raw = data.get("since") # Optional: cursor for event replay + since_cursor = None + if since_raw is not None: + try: + since_cursor = int(since_raw) + except (TypeError, ValueError): + await self._send_to_client(client_id, { + "type": "error", + "message": "Invalid 'since' cursor. Must be an integer.", + }) + return # Resume or create session session, replay_events = self.resume_or_create_session( @@ -968,21 +1014,39 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> self._client_sessions[client_id] = session.session_id - # Send join confirmation + # Check if resync is required + resync_required = session.check_resync_required(since_cursor) + oldest_cursor = session.get_oldest_cursor() + + # Send join confirmation with integrity check info await self._send_to_client(client_id, { "type": "joined", "session_id": session.session_id, "agent_id": agent_id, "resumed": session._was_resumed, "cursor": session._event_cursor, + "oldest_cursor": oldest_cursor, + "resync_required": resync_required, }) - # Replay missed events if any - for event in replay_events: + if resync_required: + # Send authoritative snapshot instead of partial replay + snapshot = session.get_snapshot() await self._send_to_client(client_id, { - "type": "replay", - "event": event.to_dict(), + "type": "snapshot", + "state": snapshot, }) + else: + # Replay missed events if any + for event in replay_events: + event_data = event.to_dict() + # Include top-level sequence number from the cursor + seq = event.data.get('cursor', 0) + await self._send_to_client(client_id, { + "type": "replay", + "event": event_data, + "seq": seq, + }) else: await self._send_to_client(client_id, { "type": "error", @@ -1166,7 +1230,14 @@ async def _send_to_client(self, client_id: str, data: Dict[str, Any]) -> None: if ws: try: # Track event in session BEFORE sending if it's a response or important event - if data.get("type") in ["response", "message", "stream_end", "error"]: + if data.get("type") in [ + "response", + "message", + "stream_end", + "error", + "token_stream", + "tool_call_stream", + ]: session_id = self._client_sessions.get(client_id) if session_id: session = self._sessions.get(session_id) @@ -1180,6 +1251,8 @@ async def _send_to_client(self, client_id: str, data: Dict[str, Any]) -> None: cursor = session.add_event(event) # Add cursor to the data BEFORE sending data["cursor"] = cursor + # Add top-level sequence number for integrity checking + data["seq"] = cursor # Send ONCE with cursor already attached if applicable await ws.send_json(data) @@ -1470,7 +1543,9 @@ def resume_or_create_session( Returns: Tuple of (session, replay_events) where replay_events are events - that occurred after since_cursor + that occurred after since_cursor. Note: Callers must check + session.check_resync_required(since_cursor) before using replay_events, + as the events may not include the full gap if buffer was trimmed. """ replay_events = []