-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
fix: gateway resume can silently lose events - add integrity checks #2155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -112,10 +112,45 @@ def add_event(self, event: GatewayEvent) -> int: | |||||||||||||||
| self._events = self._events[-self._max_messages:] | ||||||||||||||||
| return self._event_cursor | ||||||||||||||||
|
|
||||||||||||||||
| def get_oldest_cursor(self) -> int: | ||||||||||||||||
| """Get the oldest event cursor still retained in the buffer. | ||||||||||||||||
|
|
||||||||||||||||
| When the buffer is empty, returns the current cursor position, | ||||||||||||||||
| which correctly indicates that any cursor < _event_cursor would | ||||||||||||||||
| require resync (since no events are retained). | ||||||||||||||||
| """ | ||||||||||||||||
| if self._events: | ||||||||||||||||
| return self._events[0].data.get('cursor', self._event_cursor) | ||||||||||||||||
| return self._event_cursor | ||||||||||||||||
|
|
||||||||||||||||
| def get_events_since(self, cursor: int) -> List[GatewayEvent]: | ||||||||||||||||
| """Get events since the given cursor.""" | ||||||||||||||||
| return [e for e in self._events if e.data.get('cursor', 0) > cursor] | ||||||||||||||||
|
|
||||||||||||||||
| def check_resync_required(self, since_cursor: Optional[int]) -> bool: | ||||||||||||||||
| """Check if resync is required based on the requested cursor.""" | ||||||||||||||||
| if since_cursor is None: | ||||||||||||||||
| return False | ||||||||||||||||
| oldest_cursor = self.get_oldest_cursor() | ||||||||||||||||
| return since_cursor < oldest_cursor | ||||||||||||||||
|
|
||||||||||||||||
| def get_snapshot(self) -> Dict[str, Any]: | ||||||||||||||||
| """Get a snapshot of the current session state for resync.""" | ||||||||||||||||
| return { | ||||||||||||||||
| "session_id": self._session_id, | ||||||||||||||||
| "agent_id": self._agent_id, | ||||||||||||||||
| "state": dict(self._state), | ||||||||||||||||
| "messages": [{ | ||||||||||||||||
| "content": msg.content, | ||||||||||||||||
| "sender_id": msg.sender_id, | ||||||||||||||||
| "session_id": msg.session_id, | ||||||||||||||||
| "message_id": msg.message_id, | ||||||||||||||||
| "timestamp": msg.timestamp, | ||||||||||||||||
| "metadata": msg.metadata, | ||||||||||||||||
| } for msg in self._messages], | ||||||||||||||||
| "event_cursor": self._event_cursor, | ||||||||||||||||
| } | ||||||||||||||||
|
Comment on lines
+137
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The snapshot sent on resync includes |
||||||||||||||||
|
|
||||||||||||||||
| def to_dict(self) -> Dict[str, Any]: | ||||||||||||||||
| """Serialize session to dictionary for persistence.""" | ||||||||||||||||
| return { | ||||||||||||||||
|
|
@@ -956,7 +991,18 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> | |||||||||||||||
| if agent_id and agent_id in self._agents: | ||||||||||||||||
| # Support reconnection with existing session | ||||||||||||||||
| session_id = data.get("session_id") # Optional: existing session to resume | ||||||||||||||||
| since_cursor = data.get("since") # Optional: cursor for event replay | ||||||||||||||||
| # Parse and validate the since parameter | ||||||||||||||||
| since_raw = data.get("since") # Optional: cursor for event replay | ||||||||||||||||
| since_cursor = None | ||||||||||||||||
| if since_raw is not None: | ||||||||||||||||
| try: | ||||||||||||||||
| since_cursor = int(since_raw) | ||||||||||||||||
| except (TypeError, ValueError): | ||||||||||||||||
| await self._send_to_client(client_id, { | ||||||||||||||||
| "type": "error", | ||||||||||||||||
| "message": "Invalid 'since' cursor. Must be an integer.", | ||||||||||||||||
| }) | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| # Resume or create session | ||||||||||||||||
| session, replay_events = self.resume_or_create_session( | ||||||||||||||||
|
|
@@ -968,21 +1014,39 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> | |||||||||||||||
|
|
||||||||||||||||
| self._client_sessions[client_id] = session.session_id | ||||||||||||||||
|
|
||||||||||||||||
| # Send join confirmation | ||||||||||||||||
| # Check if resync is required | ||||||||||||||||
| resync_required = session.check_resync_required(since_cursor) | ||||||||||||||||
| oldest_cursor = session.get_oldest_cursor() | ||||||||||||||||
|
Comment on lines
+1017
to
+1019
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normalize
Suggested fix- since_cursor = data.get("since") # Optional: cursor for event replay
+ since_raw = data.get("since") # Optional: cursor for event replay
+ since_cursor: Optional[int] = None
+ if since_raw is not None:
+ try:
+ since_cursor = int(since_raw)
+ except (TypeError, ValueError):
+ await self._send_to_client(client_id, {
+ "type": "error",
+ "message": "Invalid 'since' cursor. Send an integer from your last received cursor.",
+ })
+ return🤖 Prompt for AI Agents |
||||||||||||||||
|
|
||||||||||||||||
| # Send join confirmation with integrity check info | ||||||||||||||||
| await self._send_to_client(client_id, { | ||||||||||||||||
| "type": "joined", | ||||||||||||||||
| "session_id": session.session_id, | ||||||||||||||||
| "agent_id": agent_id, | ||||||||||||||||
| "resumed": session._was_resumed, | ||||||||||||||||
| "cursor": session._event_cursor, | ||||||||||||||||
| "oldest_cursor": oldest_cursor, | ||||||||||||||||
| "resync_required": resync_required, | ||||||||||||||||
| }) | ||||||||||||||||
|
|
||||||||||||||||
| # Replay missed events if any | ||||||||||||||||
| for event in replay_events: | ||||||||||||||||
| if resync_required: | ||||||||||||||||
| # Send authoritative snapshot instead of partial replay | ||||||||||||||||
| snapshot = session.get_snapshot() | ||||||||||||||||
| await self._send_to_client(client_id, { | ||||||||||||||||
| "type": "replay", | ||||||||||||||||
| "event": event.to_dict(), | ||||||||||||||||
| "type": "snapshot", | ||||||||||||||||
| "state": snapshot, | ||||||||||||||||
| }) | ||||||||||||||||
| else: | ||||||||||||||||
| # Replay missed events if any | ||||||||||||||||
| for event in replay_events: | ||||||||||||||||
| event_data = event.to_dict() | ||||||||||||||||
| # Include top-level sequence number from the cursor | ||||||||||||||||
| seq = event.data.get('cursor', 0) | ||||||||||||||||
| await self._send_to_client(client_id, { | ||||||||||||||||
| "type": "replay", | ||||||||||||||||
| "event": event_data, | ||||||||||||||||
| "seq": seq, | ||||||||||||||||
| }) | ||||||||||||||||
| else: | ||||||||||||||||
| await self._send_to_client(client_id, { | ||||||||||||||||
| "type": "error", | ||||||||||||||||
|
|
@@ -1166,7 +1230,14 @@ async def _send_to_client(self, client_id: str, data: Dict[str, Any]) -> None: | |||||||||||||||
| if ws: | ||||||||||||||||
| try: | ||||||||||||||||
| # Track event in session BEFORE sending if it's a response or important event | ||||||||||||||||
| if data.get("type") in ["response", "message", "stream_end", "error"]: | ||||||||||||||||
| if data.get("type") in [ | ||||||||||||||||
| "response", | ||||||||||||||||
| "message", | ||||||||||||||||
| "stream_end", | ||||||||||||||||
| "error", | ||||||||||||||||
| "token_stream", | ||||||||||||||||
| "tool_call_stream", | ||||||||||||||||
| ]: | ||||||||||||||||
| session_id = self._client_sessions.get(client_id) | ||||||||||||||||
| if session_id: | ||||||||||||||||
| session = self._sessions.get(session_id) | ||||||||||||||||
|
|
@@ -1180,6 +1251,8 @@ async def _send_to_client(self, client_id: str, data: Dict[str, Any]) -> None: | |||||||||||||||
| cursor = session.add_event(event) | ||||||||||||||||
| # Add cursor to the data BEFORE sending | ||||||||||||||||
| data["cursor"] = cursor | ||||||||||||||||
| # Add top-level sequence number for integrity checking | ||||||||||||||||
| data["seq"] = cursor | ||||||||||||||||
|
Comment on lines
1253
to
+1255
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||||||||||||
|
|
||||||||||||||||
| # Send ONCE with cursor already attached if applicable | ||||||||||||||||
| await ws.send_json(data) | ||||||||||||||||
|
|
@@ -1470,7 +1543,9 @@ def resume_or_create_session( | |||||||||||||||
|
|
||||||||||||||||
| Returns: | ||||||||||||||||
| Tuple of (session, replay_events) where replay_events are events | ||||||||||||||||
| that occurred after since_cursor | ||||||||||||||||
| that occurred after since_cursor. Note: Callers must check | ||||||||||||||||
| session.check_resync_required(since_cursor) before using replay_events, | ||||||||||||||||
| as the events may not include the full gap if buffer was trimmed. | ||||||||||||||||
| """ | ||||||||||||||||
| replay_events = [] | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.