Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/praisonai-agents/praisonaiagents/gateway/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@
runtime_checkable,
)

# Gateway protocol versioning constants
GATEWAY_PROTOCOL_VERSION = 1
MIN_CLIENT_PROTOCOL_VERSION = 1

if TYPE_CHECKING:
from praisonai.gateway.pairing import PairedChannel
from ..agent import Agent
from ..bots.presentation import MessagePresentation


class ConnectErrorCode(str, Enum):
    """Structured error codes for connection failures.

    Each member also subclasses ``str``, so it compares equal to its wire
    value (for example ``ConnectErrorCode.AUTH_REQUIRED == "auth_required"``)
    and can be looked up from that value with ``ConnectErrorCode(value)``.
    """

    AUTH_REQUIRED = "auth_required"
    AUTH_UNAUTHORIZED = "auth_unauthorized"
    PROTOCOL_UNSUPPORTED = "protocol_unsupported"
    PAIRING_REQUIRED = "pairing_required"
    AGENT_NOT_FOUND = "agent_not_found"


class EventType(str, Enum):
"""Standard gateway event types."""

Expand Down Expand Up @@ -87,6 +100,65 @@ class EventType(str, Enum):
# Polling events
POLL_REQUEST = "poll_request"
POLL_RESPONSE = "poll_response"

# Handshake events
HELLO = "hello"
HELLO_OK = "hello_ok"
HELLO_ERROR = "hello_error"


@dataclass
class HelloParams:
    """Parameters for initiating a versioned handshake.

    Attributes:
        agent_id: The agent to connect to.
        protocol_min: Minimum protocol version the client supports.
        protocol_max: Maximum protocol version the client supports.
        capabilities: Optional list of capability tokens the client supports.
        session_id: Optional session to resume.
        since: Optional cursor for event replay.
    """

    agent_id: str
    protocol_min: int
    protocol_max: int
    # default_factory gives every instance its own list instead of a shared one.
    capabilities: List[str] = field(default_factory=list)
    session_id: Optional[str] = None
    since: Optional[int] = None


@dataclass
class HelloResult:
    """Result of a successful handshake negotiation.

    Attributes:
        protocol: The negotiated protocol version.
        features: Supported methods and events.
        policy: Gateway policy limits (max_payload, heartbeat_ms, etc.).
        session_id: The session ID for this connection.
        resumed: Whether an existing session was resumed.
        cursor: Current event cursor position.
    """

    protocol: int
    features: Dict[str, List[str]]  # {"methods": [...], "events": [...]}
    policy: Dict[str, int]  # {"max_payload": ..., "heartbeat_ms": ...}
    session_id: str
    resumed: bool
    cursor: int


@dataclass
class HelloError:
    """Error response for failed handshake.

    Attributes:
        code: Structured error code.
        message: Human-readable error message.
        next_action: Suggested next action (e.g., "upgrade_client",
            "pair_device"); ``None`` when there is no suggestion.
    """

    code: ConnectErrorCode
    message: str
    next_action: Optional[str] = None


@dataclass
Expand Down
185 changes: 183 additions & 2 deletions src/praisonai/praisonai/gateway/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
GatewayMessage,
EventType,
)
from praisonaiagents.gateway.protocols import (
ConnectErrorCode,
HelloResult,
HelloError,
GATEWAY_PROTOCOL_VERSION,
MIN_CLIENT_PROTOCOL_VERSION,
)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
from praisonaiagents.session.protocols import SessionStoreProtocol
from praisonaiagents.session.store import DefaultSessionStore

Expand Down Expand Up @@ -951,7 +958,181 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) ->
"""Handle a message from a client."""
msg_type = data.get("type", "message")

if msg_type == "join":
# Handle versioned handshake
if msg_type == "hello":
agent_id = data.get("agent_id")

# Check if agent exists
if not agent_id or agent_id not in self._agents:
error = HelloError(
code=ConnectErrorCode.AGENT_NOT_FOUND,
message=f"Agent not found: {agent_id}",
next_action="check_agent_id"
)
await self._send_to_client(client_id, {
"type": "hello_error",
"code": error.code.value,
"message": error.message,
"next": error.next_action,
})
return

# Parse protocol version from client
# Support both HelloParams format (protocol_min/max as direct fields)
# and legacy format (nested under protocol dict)
if "protocol_min" in data or "protocol_max" in data:
# HelloParams format
client_min = data.get("protocol_min", 1)
client_max = data.get("protocol_max", 1)
else:
# Legacy format or missing
protocol_info = data.get("protocol", {})
if isinstance(protocol_info, dict):
client_min = protocol_info.get("min", 1)
client_max = protocol_info.get("max", 1)
else:
# Backwards compatibility: treat missing protocol as v1
client_min = client_max = 1

# Negotiate protocol version
if client_max < MIN_CLIENT_PROTOCOL_VERSION:
error = HelloError(
code=ConnectErrorCode.PROTOCOL_UNSUPPORTED,
message=f"Protocol version {client_max} is too old, minimum required is {MIN_CLIENT_PROTOCOL_VERSION}",
next_action="upgrade_client"
)
await self._send_to_client(client_id, {
"type": "hello_error",
"code": error.code.value,
"message": error.message,
"next": error.next_action,
})
return
Comment on lines +983 to +1010

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Null protocol_min/protocol_max causes unhandled TypeError

data.get("protocol_min", 1) returns None — not 1 — when the client explicitly sends "protocol_min": null, because .get() only substitutes its default when the key is absent. The same applies to protocol_max. The first comparison client_max < MIN_CLIENT_PROTOCOL_VERSION then raises TypeError: '<' not supported between instances of 'NoneType' and 'int', which propagates out of the hello handler and is swallowed by the outer WebSocket loop, silently dropping the connection instead of sending a hello_error to the client. This is the same class of null-vs-absent bug that was fixed for caps on line 1033 but was not applied here. Add a type/null guard immediately after extracting client_min/client_max, defaulting or rejecting non-integer values before any comparison.


if client_min > GATEWAY_PROTOCOL_VERSION:
error = HelloError(
code=ConnectErrorCode.PROTOCOL_UNSUPPORTED,
message=f"Protocol version {client_min} is too new, server supports up to {GATEWAY_PROTOCOL_VERSION}",
next_action="use_older_client"
)
await self._send_to_client(client_id, {
"type": "hello_error",
"code": error.code.value,
"message": error.message,
"next": error.next_action,
})
return

# Select the highest mutually supported version
negotiated_version = min(client_max, GATEWAY_PROTOCOL_VERSION)

# Get client capabilities
# Support both HelloParams format (capabilities) and legacy format (caps)
client_caps = data.get("capabilities", data.get("caps", []))
# Guard against null/None values
if client_caps is None or not isinstance(client_caps, list):
client_caps = []

Comment on lines +980 to +1035

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Align the wire parser with HelloParams.

Line 982 reads protocol.{min,max} and Line 1023 reads caps, but the shared contract defines protocol_min, protocol_max, and capabilities. SDK clients serializing HelloParams will be treated as protocol v1 with no capabilities.

🔧 Proposed contract-compatible parser
-            protocol_info = data.get("protocol", {})
-            if isinstance(protocol_info, dict):
-                client_min = protocol_info.get("min", 1)
-                client_max = protocol_info.get("max", 1)
-            else:
-                # Backwards compatibility: treat missing protocol as v1
-                client_min = client_max = 1
+            protocol_info = data.get("protocol")
+            if "protocol_min" in data or "protocol_max" in data:
+                client_min = data.get("protocol_min", 1)
+                client_max = data.get("protocol_max", 1)
+            elif isinstance(protocol_info, dict):
+                client_min = protocol_info.get("min", 1)
+                client_max = protocol_info.get("max", 1)
+            else:
+                # Backwards compatibility: treat missing protocol as v1
+                client_min = client_max = 1
+
+            try:
+                client_min = int(client_min)
+                client_max = int(client_max)
+            except (TypeError, ValueError):
+                error = HelloError(
+                    code=ConnectErrorCode.PROTOCOL_UNSUPPORTED,
+                    message="Protocol versions must be integers",
+                    next_action="send_integer_protocol_min_and_protocol_max",
+                )
+                await self._send_to_client(client_id, {
+                    "type": "hello_error",
+                    "code": error.code.value,
+                    "message": error.message,
+                    "next": error.next_action,
+                })
+                return
+
+            if client_min > client_max:
+                error = HelloError(
+                    code=ConnectErrorCode.PROTOCOL_UNSUPPORTED,
+                    message="protocol_min cannot be greater than protocol_max",
+                    next_action="fix_protocol_version_range",
+                )
+                await self._send_to_client(client_id, {
+                    "type": "hello_error",
+                    "code": error.code.value,
+                    "message": error.message,
+                    "next": error.next_action,
+                })
+                return
@@
-            client_caps = data.get("caps", [])
+            client_caps = data.get("capabilities", data.get("caps", []))
+            if not isinstance(client_caps, list):
+                client_caps = []
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/praisonai/praisonai/gateway/server.py` around lines 981 - 1024, The wire
parser is reading incorrect field names from the incoming data, causing clients
that properly serialize HelloParams to fall back to protocol v1 with no
capabilities. Update the parsing logic around line 982-983 to read protocol_min
and protocol_max as direct fields (not nested under protocol), and update line
1023 to read capabilities instead of caps, so that client data aligned with the
HelloParams contract is correctly parsed.

# Resume or create session
session_id = data.get("session_id")
since_cursor = data.get("since")
session, replay_events = self.resume_or_create_session(
session_id=session_id,
agent_id=agent_id,
client_id=client_id,
since_cursor=since_cursor,
)

# Validate session belongs to requested agent
if hasattr(session, 'agent_id') and session.agent_id != agent_id:
error = HelloError(
code=ConnectErrorCode.AUTH_UNAUTHORIZED,
message="Session does not belong to the requested agent",
next_action="start_new_session"
)
await self._send_to_client(client_id, {
"type": "hello_error",
"code": error.code.value,
"message": error.message,
"next": error.next_action,
})
return

# Rebind client_id to session for correct routing
if hasattr(session, '_client_id'):
session._client_id = client_id

self._client_sessions[client_id] = session.session_id
Comment on lines +1036 to +1065

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Validate and rebind resumed sessions before attaching this client.

resume_or_create_session() may return an existing session for a different agent_id and with an old client_id; after Line 1035, later responses use session.client_id, so resumed messages can be routed to the wrong websocket or expose another agent's session.

🔒 Suggested guard before replay/attachment
             session, replay_events = self.resume_or_create_session(
                 session_id=session_id,
                 agent_id=agent_id,
                 client_id=client_id,
                 since_cursor=since_cursor,
             )
+
+            if session.agent_id != agent_id:
+                error = HelloError(
+                    code=ConnectErrorCode.AUTH_UNAUTHORIZED,
+                    message="Session does not belong to the requested agent",
+                    next_action="start_new_session",
+                )
+                await self._send_to_client(client_id, {
+                    "type": "hello_error",
+                    "code": error.code.value,
+                    "message": error.message,
+                    "next": error.next_action,
+                })
+                return
+
+            # Prefer moving this into `resume_or_create_session()` or a
+            # `GatewaySession.rebind_client()` helper so the legacy join path
+            # gets the same protection.
+            session._client_id = client_id
             
             self._client_sessions[client_id] = session.session_id
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Resume or create session
session_id = data.get("session_id")
since_cursor = data.get("since")
session, replay_events = self.resume_or_create_session(
session_id=session_id,
agent_id=agent_id,
client_id=client_id,
since_cursor=since_cursor,
)
self._client_sessions[client_id] = session.session_id
# Resume or create session
session_id = data.get("session_id")
since_cursor = data.get("since")
session, replay_events = self.resume_or_create_session(
session_id=session_id,
agent_id=agent_id,
client_id=client_id,
since_cursor=since_cursor,
)
if session.agent_id != agent_id:
error = HelloError(
code=ConnectErrorCode.AUTH_UNAUTHORIZED,
message="Session does not belong to the requested agent",
next_action="start_new_session",
)
await self._send_to_client(client_id, {
"type": "hello_error",
"code": error.code.value,
"message": error.message,
"next": error.next_action,
})
return
# Prefer moving this into `resume_or_create_session()` or a
# `GatewaySession.rebind_client()` helper so the legacy join path
# gets the same protection.
session._client_id = client_id
self._client_sessions[client_id] = session.session_id
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/praisonai/praisonai/gateway/server.py` around lines 1025 - 1035, The
resumed session returned from resume_or_create_session() may belong to a
different agent_id or contain a stale client_id, which causes incorrect message
routing later since responses use session.client_id. Before storing the session
in self._client_sessions, validate that the returned session's agent_id matches
the requested agent_id parameter, and rebind the current client_id to the
session to ensure the correct websocket receives the replayed messages and
subsequent responses.


# Build features list - only advertise implemented features
features = {
"methods": ["message", "leave"], # abort not implemented
"events": [
EventType.MESSAGE.value,
EventType.ERROR.value,
],
}

# Add streaming events if client supports streaming
if "streaming" in client_caps:
features["events"].extend([
EventType.TOKEN_STREAM.value,
EventType.TOOL_CALL_STREAM.value,
EventType.STREAM_END.value,
])

# Add optional features based on client capabilities
if "presence" in client_caps and hasattr(self, '_presence_tracker') and self._presence_tracker:
features["events"].extend([
EventType.PRESENCE_JOIN.value,
EventType.PRESENCE_LEAVE.value,
EventType.PRESENCE_UPDATE.value,
])

if "ack" in client_caps and hasattr(self, '_delivery_tracker') and self._delivery_tracker:
features["events"].extend([
EventType.MESSAGE_ACK.value,
EventType.MESSAGE_NACK.value,
EventType.DELIVERY_RETRY.value,
])

# Build policy limits - use configured values where available
heartbeat_interval = getattr(self.config, 'heartbeat_interval', 30)
policy = {
"max_payload": getattr(self.config, 'max_payload', 1048576), # 1MB default
"max_buffered_bytes": getattr(self.config, 'max_buffered_bytes', 8388608), # 8MB default
"heartbeat_ms": int(heartbeat_interval * 1000), # Convert seconds to ms
}

# Send successful handshake response
result = HelloResult(
protocol=negotiated_version,
features=features,
policy=policy,
session_id=session.session_id,
resumed=session._was_resumed,
cursor=session._event_cursor,
)

await self._send_to_client(client_id, {
"type": "hello_ok",
"protocol": result.protocol,
"features": result.features,
"policy": result.policy,
"session_id": result.session_id,
"resumed": result.resumed,
"cursor": result.cursor,
})

# Replay missed events if any
for event in replay_events:
await self._send_to_client(client_id, {
"type": "replay",
"event": event.to_dict(),
})

# Keep backward compatibility with old join message
elif msg_type == "join":
agent_id = data.get("agent_id")
if agent_id and agent_id in self._agents:
# Support reconnection with existing session
Expand All @@ -968,7 +1149,7 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) ->

self._client_sessions[client_id] = session.session_id

# Send join confirmation
# Send join confirmation (old format for backward compatibility)
await self._send_to_client(client_id, {
"type": "joined",
"session_id": session.session_id,
Expand Down
Loading