diff --git a/src/praisonai-agents/praisonaiagents/gateway/protocols.py b/src/praisonai-agents/praisonaiagents/gateway/protocols.py index b44f0f226..5ed0408ad 100644 --- a/src/praisonai-agents/praisonaiagents/gateway/protocols.py +++ b/src/praisonai-agents/praisonaiagents/gateway/protocols.py @@ -28,12 +28,25 @@ runtime_checkable, ) +# Gateway protocol versioning constants +GATEWAY_PROTOCOL_VERSION = 1 +MIN_CLIENT_PROTOCOL_VERSION = 1 + if TYPE_CHECKING: from praisonai.gateway.pairing import PairedChannel from ..agent import Agent from ..bots.presentation import MessagePresentation +class ConnectErrorCode(str, Enum): + """Structured error codes for connection failures.""" + AUTH_REQUIRED = "auth_required" + AUTH_UNAUTHORIZED = "auth_unauthorized" + PROTOCOL_UNSUPPORTED = "protocol_unsupported" + PAIRING_REQUIRED = "pairing_required" + AGENT_NOT_FOUND = "agent_not_found" + + class EventType(str, Enum): """Standard gateway event types.""" @@ -87,6 +100,65 @@ class EventType(str, Enum): # Polling events POLL_REQUEST = "poll_request" POLL_RESPONSE = "poll_response" + + # Handshake events + HELLO = "hello" + HELLO_OK = "hello_ok" + HELLO_ERROR = "hello_error" + + +@dataclass +class HelloParams: + """Parameters for initiating a versioned handshake. + + Attributes: + agent_id: The agent to connect to + protocol_min: Minimum protocol version the client supports + protocol_max: Maximum protocol version the client supports + capabilities: Optional list of capability tokens the client supports + session_id: Optional session to resume + since: Optional cursor for event replay + """ + agent_id: str + protocol_min: int + protocol_max: int + capabilities: List[str] = field(default_factory=list) + session_id: Optional[str] = None + since: Optional[int] = None + + +@dataclass +class HelloResult: + """Result of a successful handshake negotiation. + + Attributes: + protocol: The negotiated protocol version + features: Supported methods and events + policy: Gateway policy limits (max_payload, heartbeat_ms, etc.) + session_id: The session ID for this connection + resumed: Whether an existing session was resumed + cursor: Current event cursor position + """ + protocol: int + features: Dict[str, List[str]] # {"methods": [...], "events": [...]} + policy: Dict[str, int] # {"max_payload": ..., "heartbeat_ms": ...} + session_id: str + resumed: bool + cursor: int + + +@dataclass +class HelloError: + """Error response for failed handshake. + + Attributes: + code: Structured error code + message: Human-readable error message + next_action: Suggested next action (e.g., "upgrade_client", "pair_device") + """ + code: ConnectErrorCode + message: str + next_action: Optional[str] = None @dataclass diff --git a/src/praisonai/praisonai/gateway/server.py b/src/praisonai/praisonai/gateway/server.py index 28c551725..746b8eeb7 100644 --- a/src/praisonai/praisonai/gateway/server.py +++ b/src/praisonai/praisonai/gateway/server.py @@ -26,6 +26,13 @@ GatewayMessage, EventType, ) +from praisonaiagents.gateway.protocols import ( + ConnectErrorCode, + HelloResult, + HelloError, + GATEWAY_PROTOCOL_VERSION, + MIN_CLIENT_PROTOCOL_VERSION, +) from praisonaiagents.session.protocols import SessionStoreProtocol from praisonaiagents.session.store import DefaultSessionStore @@ -951,7 +958,181 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> """Handle a message from a client.""" msg_type = data.get("type", "message") - if msg_type == "join": + # Handle versioned handshake + if msg_type == "hello": + agent_id = data.get("agent_id") + + # Check if agent exists + if not agent_id or agent_id not in self._agents: + error = HelloError( + code=ConnectErrorCode.AGENT_NOT_FOUND, + message=f"Agent not found: {agent_id}", + next_action="check_agent_id" + ) + await self._send_to_client(client_id, { + "type": "hello_error", + "code": error.code.value, + "message": error.message, + "next": error.next_action, + }) + return + + # Parse protocol version from client + # Support both HelloParams format (protocol_min/max as direct fields) + # and legacy format (nested under protocol dict) + if "protocol_min" in data or "protocol_max" in data: + # HelloParams format + client_min = data.get("protocol_min", 1) + client_max = data.get("protocol_max", 1) + else: + # Legacy format or missing + protocol_info = data.get("protocol", {}) + if isinstance(protocol_info, dict): + client_min = protocol_info.get("min", 1) + client_max = protocol_info.get("max", 1) + else: + # Backwards compatibility: treat missing protocol as v1 + client_min = client_max = 1 + + # Negotiate protocol version + if client_max < MIN_CLIENT_PROTOCOL_VERSION: + error = HelloError( + code=ConnectErrorCode.PROTOCOL_UNSUPPORTED, + message=f"Protocol version {client_max} is too old, minimum required is {MIN_CLIENT_PROTOCOL_VERSION}", + next_action="upgrade_client" + ) + await self._send_to_client(client_id, { + "type": "hello_error", + "code": error.code.value, + "message": error.message, + "next": error.next_action, + }) + return + + if client_min > GATEWAY_PROTOCOL_VERSION: + error = HelloError( + code=ConnectErrorCode.PROTOCOL_UNSUPPORTED, + message=f"Protocol version {client_min} is too new, server supports up to {GATEWAY_PROTOCOL_VERSION}", + next_action="use_older_client" + ) + await self._send_to_client(client_id, { + "type": "hello_error", + "code": error.code.value, + "message": error.message, + "next": error.next_action, + }) + return + + # Select the highest mutually supported version + negotiated_version = min(client_max, GATEWAY_PROTOCOL_VERSION) + + # Get client capabilities + # Support both HelloParams format (capabilities) and legacy format (caps) + client_caps = data.get("capabilities", data.get("caps", [])) + # Guard against null/None values + if client_caps is None or not isinstance(client_caps, list): + client_caps = [] + + # Resume or create session + session_id = data.get("session_id") + since_cursor = data.get("since") + session, replay_events = self.resume_or_create_session( + session_id=session_id, + agent_id=agent_id, + client_id=client_id, + since_cursor=since_cursor, + ) + + # Validate session belongs to requested agent + if hasattr(session, 'agent_id') and session.agent_id != agent_id: + error = HelloError( + code=ConnectErrorCode.AUTH_UNAUTHORIZED, + message="Session does not belong to the requested agent", + next_action="start_new_session" + ) + await self._send_to_client(client_id, { + "type": "hello_error", + "code": error.code.value, + "message": error.message, + "next": error.next_action, + }) + return + + # Rebind client_id to session for correct routing + if hasattr(session, '_client_id'): + session._client_id = client_id + + self._client_sessions[client_id] = session.session_id + + # Build features list - only advertise implemented features + features = { + "methods": ["message", "leave"], # abort not implemented + "events": [ + EventType.MESSAGE.value, + EventType.ERROR.value, + ], + } + + # Add streaming events if client supports streaming + if "streaming" in client_caps: + features["events"].extend([ + EventType.TOKEN_STREAM.value, + EventType.TOOL_CALL_STREAM.value, + EventType.STREAM_END.value, + ]) + + # Add optional features based on client capabilities + if "presence" in client_caps and hasattr(self, '_presence_tracker') and self._presence_tracker: + features["events"].extend([ + EventType.PRESENCE_JOIN.value, + EventType.PRESENCE_LEAVE.value, + EventType.PRESENCE_UPDATE.value, + ]) + + if "ack" in client_caps and hasattr(self, '_delivery_tracker') and self._delivery_tracker: + features["events"].extend([ + EventType.MESSAGE_ACK.value, + EventType.MESSAGE_NACK.value, + EventType.DELIVERY_RETRY.value, + ]) + + # Build policy limits - use configured values where available + heartbeat_interval = getattr(self.config, 'heartbeat_interval', 30) + policy = { + "max_payload": getattr(self.config, 'max_payload', 1048576), # 1MB default + "max_buffered_bytes": getattr(self.config, 'max_buffered_bytes', 8388608), # 8MB default + "heartbeat_ms": int(heartbeat_interval * 1000), # Convert seconds to ms + } + + # Send successful handshake response + result = HelloResult( + protocol=negotiated_version, + features=features, + policy=policy, + session_id=session.session_id, + resumed=session._was_resumed, + cursor=session._event_cursor, + ) + + await self._send_to_client(client_id, { + "type": "hello_ok", + "protocol": result.protocol, + "features": result.features, + "policy": result.policy, + "session_id": result.session_id, + "resumed": result.resumed, + "cursor": result.cursor, + }) + + # Replay missed events if any + for event in replay_events: + await self._send_to_client(client_id, { + "type": "replay", + "event": event.to_dict(), + }) + + # Keep backward compatibility with old join message + elif msg_type == "join": agent_id = data.get("agent_id") if agent_id and agent_id in self._agents: # Support reconnection with existing session @@ -968,7 +1149,7 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> self._client_sessions[client_id] = session.session_id - # Send join confirmation + # Send join confirmation (old format for backward compatibility) await self._send_to_client(client_id, { "type": "joined", "session_id": session.session_id,