From bbcd27cd2558856c8c232c286e5975bf80637087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B6=85=E6=B8=A1=E6=B3=95=E5=B8=AB?= Date: Tue, 30 Jun 2026 03:50:42 +0000 Subject: [PATCH 1/3] feat(acp): implement ACP Server with WebSocket transport (Phase 1 MVP) Add an ACP-compliant server endpoint at GET /acp that accepts WebSocket connections and speaks JSON-RPC 2.0 per the Agent Client Protocol spec. Implements: - WebSocket upgrade with Bearer token auth (OPENAB_ACP_AUTH_KEY) - initialize: capability negotiation - session/new: create OAB session mapped to internal channel - session/prompt: dispatch to agent via GatewayEvent, stream back AgentMessageChunk notifications from GatewayReply - session/cancel: placeholder for Phase 2 Architecture: ACP Client --WS JSON-RPC--> /acp endpoint --> GatewayEvent GatewayReply --> AcpReplyRegistry --> AgentMessageChunk notification --> Client Feature-gated behind 'acp' feature flag. Enable with: OPENAB_ACP_ENABLED=true cargo run --features acp Refs: ADR PR #1258 --- Cargo.toml | 1 + crates/openab-gateway/Cargo.toml | 2 + .../openab-gateway/src/adapters/acp_server.rs | 561 ++++++++++++++++++ crates/openab-gateway/src/adapters/mod.rs | 2 + crates/openab-gateway/src/lib.rs | 41 ++ 5 files changed, 607 insertions(+) create mode 100644 crates/openab-gateway/src/adapters/acp_server.rs diff --git a/Cargo.toml b/Cargo.toml index 69b6f01ae..e50e92ae8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ feishu = ["dep:openab-gateway", "dep:axum", "openab-gateway/feishu"] googlechat = ["dep:openab-gateway", "dep:axum", "openab-gateway/googlechat"] wecom = ["dep:openab-gateway", "dep:axum", "openab-gateway/wecom"] teams = ["dep:openab-gateway", "dep:axum", "openab-gateway/teams"] +acp = ["dep:openab-gateway", "dep:axum", "openab-gateway/acp"] [dev-dependencies] tempfile = "3.27.0" diff --git a/crates/openab-gateway/Cargo.toml b/crates/openab-gateway/Cargo.toml index 26ee00db9..f2bfbba5e 100644 --- a/crates/openab-gateway/Cargo.toml +++ b/crates/openab-gateway/Cargo.toml @@ -30,6 +30,7 @@ quick-xml = "0.37" image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } parking_lot = "0.12" urlencoding = "2" +agent-client-protocol = { version = "1.0", optional = true } [dev-dependencies] wiremock = "0.6" @@ -42,3 +43,4 @@ feishu = [] googlechat = [] wecom = [] teams = [] +acp = ["dep:agent-client-protocol"] diff --git a/crates/openab-gateway/src/adapters/acp_server.rs b/crates/openab-gateway/src/adapters/acp_server.rs new file mode 100644 index 000000000..3d6107114 --- /dev/null +++ b/crates/openab-gateway/src/adapters/acp_server.rs @@ -0,0 +1,561 @@ +//! ACP (Agent Client Protocol) Server adapter. +//! +//! Exposes OAB as an ACP-compliant server over WebSocket at `GET /acp`. +//! Any ACP client (Zed, JetBrains, desktop apps, web apps, CLIs) can connect +//! and interact with OAB's multi-agent platform using the standard protocol. +//! +//! Protocol flow: +//! Client connects via WebSocket → sends `initialize` → `session/new` → `session/prompt` +//! Server streams back `AgentMessageChunk` notifications, then the prompt response. +//! +//! Internally, prompts are converted to `GatewayEvent` and dispatched through OAB's +//! existing event pipeline. Replies (`GatewayReply`) are translated back into ACP +//! notifications and streamed to the client. + +use crate::schema::*; +use axum::extract::ws::{Message, WebSocket}; +use axum::extract::{Query, State, WebSocketUpgrade}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use futures_util::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use tracing::{info, warn}; +use uuid::Uuid; + +// --------------------------------------------------------------------------- +// ACP Configuration +// --------------------------------------------------------------------------- + +pub struct AcpConfig { + pub auth_key: Option, +} + +impl AcpConfig { + pub fn from_env() -> Option { + let enabled = std::env::var("OPENAB_ACP_ENABLED") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + if !enabled { + return None; + } + let auth_key = std::env::var("OPENAB_ACP_AUTH_KEY").ok(); + if auth_key.is_none() { + warn!("OPENAB_ACP_AUTH_KEY not set — ACP endpoint is UNAUTHENTICATED"); + } + Some(Self { auth_key }) + } +} + +// --------------------------------------------------------------------------- +// ACP Session tracking +// --------------------------------------------------------------------------- + +/// Tracks an active ACP session and its reply channel. +struct AcpSession { + /// Channel ID used in GatewayEvent (maps replies back to this session) + channel_id: String, + /// Sender for streaming reply chunks back to the WebSocket handler + reply_tx: mpsc::UnboundedSender, +} + +pub enum ReplyChunk { + /// Incremental text snapshot (full text so far) + Text(String), + /// Agent finished responding + Done, +} + +/// Registry of active ACP sessions: channel_id → reply sender +pub type AcpReplyRegistry = Arc>>>; + +pub fn new_reply_registry() -> AcpReplyRegistry { + Arc::new(Mutex::new(HashMap::new())) +} + +// --------------------------------------------------------------------------- +// JSON-RPC types (minimal subset for ACP) +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +struct JsonRpcRequest { + jsonrpc: String, + method: String, + #[serde(default)] + id: Option, + #[serde(default)] + params: Option, +} + +#[derive(Debug, Serialize)] +struct JsonRpcResponse { + jsonrpc: &'static str, + id: Value, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +#[derive(Debug, Serialize)] +struct JsonRpcError { + code: i32, + message: String, +} + +#[derive(Debug, Serialize)] +struct JsonRpcNotification { + jsonrpc: &'static str, + method: String, + params: Value, +} + +impl JsonRpcResponse { + fn success(id: Value, result: Value) -> Self { + Self { + jsonrpc: "2.0", + id, + result: Some(result), + error: None, + } + } + + fn error(id: Value, code: i32, message: impl Into) -> Self { + Self { + jsonrpc: "2.0", + id, + result: None, + error: Some(JsonRpcError { + code, + message: message.into(), + }), + } + } +} + +// --------------------------------------------------------------------------- +// WebSocket upgrade handler: GET /acp +// --------------------------------------------------------------------------- + +pub async fn ws_upgrade( + State(state): State>, + query: Query>, + headers: axum::http::HeaderMap, + ws: WebSocketUpgrade, +) -> axum::response::Response { + // Auth: Bearer token from Authorization header or ?token= query param + let token = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + .or_else(|| query.get("token").map(|s| s.as_str())); + + let expected = state.acp.as_ref().and_then(|c| c.auth_key.as_ref()); + if let Some(expected) = expected { + if token != Some(expected.as_str()) { + warn!("ACP WebSocket rejected: invalid or missing token"); + return StatusCode::UNAUTHORIZED.into_response(); + } + } + + ws.on_upgrade(move |socket| handle_acp_connection(state, socket)) +} + +// --------------------------------------------------------------------------- +// ACP Connection handler +// --------------------------------------------------------------------------- + +async fn handle_acp_connection(state: Arc, socket: WebSocket) { + let (mut ws_tx, mut ws_rx) = socket.split(); + let connection_id = format!("acp_conn_{}", Uuid::new_v4()); + + info!(connection = %connection_id, "ACP client connected"); + + // Session state for this connection + let sessions: Arc>> = Arc::new(Mutex::new(HashMap::new())); + let mut initialized = false; + + // Channel for sending messages back to the client + let (out_tx, mut out_rx) = mpsc::unbounded_channel::(); + + // Forward outbound messages to WebSocket + let send_task = tokio::spawn(async move { + while let Some(msg) = out_rx.recv().await { + if ws_tx.send(Message::Text(msg.into())).await.is_err() { + break; + } + } + }); + + // Process incoming messages + while let Some(Ok(msg)) = ws_rx.next().await { + let Message::Text(text) = msg else { + continue; + }; + + let req: JsonRpcRequest = match serde_json::from_str(&text) { + Ok(r) => r, + Err(e) => { + let err_resp = JsonRpcResponse::error( + Value::Null, + -32700, + format!("Parse error: {e}"), + ); + let _ = out_tx.send(serde_json::to_string(&err_resp).unwrap()); + continue; + } + }; + + let id = req.id.clone().unwrap_or(Value::Null); + + match req.method.as_str() { + "initialize" => { + let resp = handle_initialize(&connection_id, &req); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + initialized = true; + } + "session/new" => { + if !initialized { + let resp = JsonRpcResponse::error(id, -32002, "Not initialized"); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + continue; + } + let resp = handle_session_new( + &state, + &sessions, + id.clone(), + req.params.as_ref(), + ) + .await; + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + } + "session/prompt" => { + if !initialized { + let resp = JsonRpcResponse::error(id, -32002, "Not initialized"); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + continue; + } + // session/prompt is async — spawn a task to handle streaming + let state_clone = state.clone(); + let sessions_clone = sessions.clone(); + let out_tx_clone = out_tx.clone(); + tokio::spawn(async move { + handle_session_prompt( + &state_clone, + &sessions_clone, + id, + req.params.as_ref(), + &out_tx_clone, + ) + .await; + }); + } + "session/cancel" => { + // TODO: implement cancellation in Phase 2 + let resp = JsonRpcResponse::success(id, json!({})); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + } + _ => { + let resp = JsonRpcResponse::error( + id, + -32601, + format!("Method not found: {}", req.method), + ); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + } + } + } + + // Cleanup: remove all sessions for this connection from the reply registry + if let Some(ref registry) = state.acp_reply_registry { + let sessions_guard = sessions.lock().await; + let mut reg = registry.lock().await; + for (_, session) in sessions_guard.iter() { + reg.remove(&session.channel_id); + } + } + + send_task.abort(); + info!(connection = %connection_id, "ACP client disconnected"); +} + +// --------------------------------------------------------------------------- +// Method handlers +// --------------------------------------------------------------------------- + +fn handle_initialize(connection_id: &str, _req: &JsonRpcRequest) -> JsonRpcResponse { + let id = _req.id.clone().unwrap_or(Value::Null); + JsonRpcResponse::success( + id, + json!({ + "protocolVersion": "v1", + "connectionId": connection_id, + "agentCapabilities": { + "streaming": true, + "promptCapabilities": { + "image": false, + "audio": false, + "embeddedContext": false + } + }, + "agentInfo": { + "name": "openab", + "version": env!("CARGO_PKG_VERSION") + } + }), + ) +} + +async fn handle_session_new( + state: &Arc, + sessions: &Arc>>, + id: Value, + _params: Option<&Value>, +) -> JsonRpcResponse { + let session_id = format!("sess_{}", Uuid::new_v4()); + let channel_id = format!("acp_{}", Uuid::new_v4()); + + // Create reply channel for this session + let (reply_tx, _reply_rx) = mpsc::unbounded_channel::(); + + // Register in the global reply registry so handle_reply can find it + if let Some(ref registry) = state.acp_reply_registry { + registry.lock().await.insert(channel_id.clone(), reply_tx.clone()); + } + + // Store session locally + sessions.lock().await.insert( + session_id.clone(), + AcpSession { + channel_id, + reply_tx, + }, + ); + + info!(session = %session_id, "ACP session created"); + + JsonRpcResponse::success( + id, + json!({ + "sessionId": session_id, + "models": { + "current": "openab", + "available": [ + {"id": "openab", "name": "OpenAB Default Agent"} + ] + } + }), + ) +} + +async fn handle_session_prompt( + state: &Arc, + sessions: &Arc>>, + id: Value, + params: Option<&Value>, + out_tx: &mpsc::UnboundedSender, +) { + // Extract sessionId and prompt from params + let (session_id, prompt_text) = match extract_prompt_params(params) { + Ok(v) => v, + Err(e) => { + let resp = JsonRpcResponse::error(id, -32602, e); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + return; + } + }; + + // Look up session + let (channel_id, reply_tx) = { + let guard = sessions.lock().await; + match guard.get(&session_id) { + Some(s) => (s.channel_id.clone(), s.reply_tx.clone()), + None => { + let resp = JsonRpcResponse::error( + id, + -32602, + format!("Unknown session: {session_id}"), + ); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + return; + } + } + }; + + // Create a new reply receiver (re-register to get fresh channel) + let (new_tx, mut reply_rx) = mpsc::unbounded_channel::(); + if let Some(ref registry) = state.acp_reply_registry { + registry.lock().await.insert(channel_id.clone(), new_tx.clone()); + } + // Update session's reply_tx + { + let mut guard = sessions.lock().await; + if let Some(s) = guard.get_mut(&session_id) { + s.reply_tx = new_tx; + } + } + + // Convert to GatewayEvent and dispatch + let event = GatewayEvent::new( + "acp", + ChannelInfo { + id: channel_id.clone(), + channel_type: "dm".into(), + thread_id: None, + }, + SenderInfo { + id: "acp_client".into(), + name: "acp_client".into(), + display_name: "ACP Client".into(), + is_bot: false, + }, + &prompt_text, + &format!("acpmsg_{}", Uuid::new_v4()), + Vec::new(), + ); + + // Send event through the broadcast channel + match serde_json::to_string(&event) { + Ok(json) => { + let _ = state.event_tx.send(json); + } + Err(e) => { + warn!("ACP: failed to serialize event: {e}"); + let resp = JsonRpcResponse::error(id, -32603, "Internal error"); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + return; + } + } + + info!(session = %session_id, channel = %channel_id, "ACP: prompt dispatched"); + + // Stream replies back as ACP notifications + let mut sent_len = 0usize; + let timeout = tokio::time::Duration::from_secs(180); + + loop { + match tokio::time::timeout(timeout, reply_rx.recv()).await { + Ok(Some(ReplyChunk::Text(full_text))) => { + // Send delta as AgentMessageChunk notification + let delta = if full_text.len() > sent_len { + &full_text[sent_len..] + } else { + continue; + }; + sent_len = full_text.len(); + + let notification = JsonRpcNotification { + jsonrpc: "2.0", + method: "session/notification".into(), + params: json!({ + "sessionId": session_id, + "update": { + "sessionUpdate": "agentMessageChunk", + "chunk": { + "content": {"type": "text", "text": delta} + } + } + }), + }; + let _ = out_tx.send(serde_json::to_string(¬ification).unwrap()); + } + Ok(Some(ReplyChunk::Done)) | Ok(None) => { + break; + } + Err(_) => { + warn!(session = %session_id, "ACP: prompt timed out waiting for reply"); + break; + } + } + } + + // Send the final response + let resp = JsonRpcResponse::success( + id, + json!({ + "sessionId": session_id, + "stopReason": "endTurn" + }), + ); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); +} + +fn extract_prompt_params(params: Option<&Value>) -> Result<(String, String), String> { + let params = params.ok_or("Missing params")?; + let session_id = params + .get("sessionId") + .and_then(|v| v.as_str()) + .ok_or("Missing sessionId")? + .to_string(); + let prompt = params.get("prompt").ok_or("Missing prompt")?; + + // Prompt can be an array of content blocks or a simple string + let text = if let Some(arr) = prompt.as_array() { + arr.iter() + .filter_map(|block| { + if block.get("type").and_then(|t| t.as_str()) == Some("text") { + block.get("text").and_then(|t| t.as_str()) + } else { + None + } + }) + .collect::>() + .join("\n") + } else if let Some(s) = prompt.as_str() { + s.to_string() + } else { + return Err("Invalid prompt format".into()); + }; + + if text.trim().is_empty() { + return Err("Empty prompt".into()); + } + + Ok((session_id, text)) +} + +// --------------------------------------------------------------------------- +// Reply handler: called when GatewayReply arrives for an ACP session +// --------------------------------------------------------------------------- + +/// Process a GatewayReply destined for an ACP session. +/// Called from the unified bridge's reply dispatch logic. +pub async fn handle_reply(reply: &GatewayReply, registry: &AcpReplyRegistry) { + let key = reply.channel.id.as_str(); + if !key.starts_with("acp_") { + return; + } + + let full_text = reply.content.text.clone(); + // Skip placeholder/draft messages + if full_text == "…" || full_text == "draft" { + return; + } + + let mut map = registry.lock().await; + let Some(tx) = map.get(key) else { + return; + }; + + match reply.command.as_deref() { + Some("edit_message") => { + // Streaming update — send as text snapshot + if tx.send(ReplyChunk::Text(full_text)).is_err() { + map.remove(key); + } + } + None | Some("send_message") => { + // Final message + let _ = tx.send(ReplyChunk::Text(full_text)); + let _ = tx.send(ReplyChunk::Done); + map.remove(key); + } + Some("add_reaction") | Some("remove_reaction") => { + // Reactions are agent state indicators — could map to notifications later + } + _ => {} + } +} diff --git a/crates/openab-gateway/src/adapters/mod.rs b/crates/openab-gateway/src/adapters/mod.rs index f58f870a2..ac7867766 100644 --- a/crates/openab-gateway/src/adapters/mod.rs +++ b/crates/openab-gateway/src/adapters/mod.rs @@ -12,3 +12,5 @@ pub mod googlechat; pub mod wecom; #[cfg(feature = "teams")] pub mod teams; +#[cfg(feature = "acp")] +pub mod acp_server; diff --git a/crates/openab-gateway/src/lib.rs b/crates/openab-gateway/src/lib.rs index 92317175b..7bd7b03d8 100644 --- a/crates/openab-gateway/src/lib.rs +++ b/crates/openab-gateway/src/lib.rs @@ -40,6 +40,10 @@ pub struct AppState { pub google_chat: Option, #[cfg(feature = "wecom")] pub wecom: Option, + #[cfg(feature = "acp")] + pub acp: Option, + #[cfg(feature = "acp")] + pub acp_reply_registry: Option, pub ws_token: Option, pub event_tx: broadcast::Sender, pub reply_token_cache: ReplyTokenCache, @@ -75,6 +79,10 @@ impl AppState { google_chat: None, #[cfg(feature = "wecom")] wecom: None, + #[cfg(feature = "acp")] + acp: None, + #[cfg(feature = "acp")] + acp_reply_registry: None, ws_token: None, event_tx, reply_token_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), @@ -156,6 +164,12 @@ impl AppState { let wecom = adapters::wecom::WecomConfig::from_env() .map(adapters::wecom::WecomAdapter::new); + // ACP Server + #[cfg(feature = "acp")] + let acp = adapters::acp_server::AcpConfig::from_env(); + #[cfg(feature = "acp")] + let acp_reply_registry = acp.as_ref().map(|_| adapters::acp_server::new_reply_registry()); + let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .build() @@ -177,6 +191,10 @@ impl AppState { google_chat, #[cfg(feature = "wecom")] wecom, + #[cfg(feature = "acp")] + acp, + #[cfg(feature = "acp")] + acp_reply_registry, ws_token, event_tx, reply_token_cache: Arc::new(std::sync::Mutex::new(HashMap::new())), @@ -222,6 +240,13 @@ pub async fn serve(config: ServeConfig) -> anyhow::Result<()> { .route("/ws", get(ws_handler)) .route("/health", get(health)); + // ACP Server adapter + #[cfg(feature = "acp")] + if std::env::var("OPENAB_ACP_ENABLED").map(|v| v == "true" || v == "1").unwrap_or(false) { + info!("ACP server endpoint enabled at /acp"); + app = app.route("/acp", get(adapters::acp_server::ws_upgrade)); + } + // Telegram adapter #[cfg(feature = "telegram")] let telegram_bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok(); @@ -413,6 +438,16 @@ pub async fn serve(config: ServeConfig) -> anyhow::Result<()> { google_chat, #[cfg(feature = "wecom")] wecom, + #[cfg(feature = "acp")] + acp: adapters::acp_server::AcpConfig::from_env(), + #[cfg(feature = "acp")] + acp_reply_registry: { + if std::env::var("OPENAB_ACP_ENABLED").map(|v| v == "true" || v == "1").unwrap_or(false) { + Some(adapters::acp_server::new_reply_registry()) + } else { + None + } + }, ws_token, event_tx, reply_token_cache, @@ -620,6 +655,12 @@ async fn handle_oab_connection(state: Arc, socket: axum::extract::ws:: warn!("reply for wecom but adapter not configured"); } } + #[cfg(feature = "acp")] + "acp" => { + if let Some(ref registry) = state_for_recv.acp_reply_registry { + adapters::acp_server::handle_reply(&reply, registry).await; + } + } other => warn!(platform = other, "unknown reply platform"), } } From 453d9958462c07269fc09fd660272d79a058cc0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B6=85=E6=B8=A1=E6=B3=95=E5=B8=AB?= Date: Tue, 30 Jun 2026 04:03:54 +0000 Subject: [PATCH 2/3] =?UTF-8?q?fix(acp):=20address=20review=20feedback=20f?= =?UTF-8?q?rom=20=E6=A0=B8=E6=B8=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit F1: Use subtle::ConstantTimeEq for auth token comparison (timing attack) F2: Remove duplicate AcpConfig/registry construction in serve() F3: Validate jsonrpc == "2.0" per spec (reject with -32600) F4: Defer reply channel creation to session/prompt (no dead _reply_rx) + Add tracing::debug for reply send failure (race condition visibility) --- .../openab-gateway/src/adapters/acp_server.rs | 54 ++++++++++--------- crates/openab-gateway/src/lib.rs | 9 +--- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/crates/openab-gateway/src/adapters/acp_server.rs b/crates/openab-gateway/src/adapters/acp_server.rs index 3d6107114..f853ec362 100644 --- a/crates/openab-gateway/src/adapters/acp_server.rs +++ b/crates/openab-gateway/src/adapters/acp_server.rs @@ -58,8 +58,6 @@ impl AcpConfig { struct AcpSession { /// Channel ID used in GatewayEvent (maps replies back to this session) channel_id: String, - /// Sender for streaming reply chunks back to the WebSocket handler - reply_tx: mpsc::UnboundedSender, } pub enum ReplyChunk { @@ -155,7 +153,15 @@ pub async fn ws_upgrade( let expected = state.acp.as_ref().and_then(|c| c.auth_key.as_ref()); if let Some(expected) = expected { - if token != Some(expected.as_str()) { + let valid = match token { + Some(t) => { + // Constant-time comparison to prevent timing attacks + use subtle::ConstantTimeEq; + t.as_bytes().ct_eq(expected.as_bytes()).into() + } + None => false, + }; + if !valid { warn!("ACP WebSocket rejected: invalid or missing token"); return StatusCode::UNAUTHORIZED.into_response(); } @@ -209,6 +215,17 @@ async fn handle_acp_connection(state: Arc, socket: WebSocket) { } }; + // Validate JSON-RPC version (spec requires "2.0") + if req.jsonrpc != "2.0" { + let err_resp = JsonRpcResponse::error( + req.id.clone().unwrap_or(Value::Null), + -32600, + "Invalid Request: jsonrpc must be \"2.0\"", + ); + let _ = out_tx.send(serde_json::to_string(&err_resp).unwrap()); + continue; + } + let id = req.id.clone().unwrap_or(Value::Null); match req.method.as_str() { @@ -310,7 +327,7 @@ fn handle_initialize(connection_id: &str, _req: &JsonRpcRequest) -> JsonRpcRespo } async fn handle_session_new( - state: &Arc, + _state: &Arc, sessions: &Arc>>, id: Value, _params: Option<&Value>, @@ -318,20 +335,11 @@ async fn handle_session_new( let session_id = format!("sess_{}", Uuid::new_v4()); let channel_id = format!("acp_{}", Uuid::new_v4()); - // Create reply channel for this session - let (reply_tx, _reply_rx) = mpsc::unbounded_channel::(); - - // Register in the global reply registry so handle_reply can find it - if let Some(ref registry) = state.acp_reply_registry { - registry.lock().await.insert(channel_id.clone(), reply_tx.clone()); - } - - // Store session locally + // Store session locally (reply channel is created lazily in session/prompt) sessions.lock().await.insert( session_id.clone(), AcpSession { channel_id, - reply_tx, }, ); @@ -369,10 +377,10 @@ async fn handle_session_prompt( }; // Look up session - let (channel_id, reply_tx) = { + let channel_id = { let guard = sessions.lock().await; match guard.get(&session_id) { - Some(s) => (s.channel_id.clone(), s.reply_tx.clone()), + Some(s) => s.channel_id.clone(), None => { let resp = JsonRpcResponse::error( id, @@ -385,17 +393,10 @@ async fn handle_session_prompt( } }; - // Create a new reply receiver (re-register to get fresh channel) - let (new_tx, mut reply_rx) = mpsc::unbounded_channel::(); + // Create reply channel for this prompt + let (reply_tx, mut reply_rx) = mpsc::unbounded_channel::(); if let Some(ref registry) = state.acp_reply_registry { - registry.lock().await.insert(channel_id.clone(), new_tx.clone()); - } - // Update session's reply_tx - { - let mut guard = sessions.lock().await; - if let Some(s) = guard.get_mut(&session_id) { - s.reply_tx = new_tx; - } + registry.lock().await.insert(channel_id.clone(), reply_tx); } // Convert to GatewayEvent and dispatch @@ -544,6 +545,7 @@ pub async fn handle_reply(reply: &GatewayReply, registry: &AcpReplyRegistry) { Some("edit_message") => { // Streaming update — send as text snapshot if tx.send(ReplyChunk::Text(full_text)).is_err() { + tracing::debug!(channel = key, "ACP reply send failed (client likely disconnected)"); map.remove(key); } } diff --git a/crates/openab-gateway/src/lib.rs b/crates/openab-gateway/src/lib.rs index 7bd7b03d8..85873e096 100644 --- a/crates/openab-gateway/src/lib.rs +++ b/crates/openab-gateway/src/lib.rs @@ -441,13 +441,8 @@ pub async fn serve(config: ServeConfig) -> anyhow::Result<()> { #[cfg(feature = "acp")] acp: adapters::acp_server::AcpConfig::from_env(), #[cfg(feature = "acp")] - acp_reply_registry: { - if std::env::var("OPENAB_ACP_ENABLED").map(|v| v == "true" || v == "1").unwrap_or(false) { - Some(adapters::acp_server::new_reply_registry()) - } else { - None - } - }, + acp_reply_registry: adapters::acp_server::AcpConfig::from_env() + .map(|_| adapters::acp_server::new_reply_registry()), ws_token, event_tx, reply_token_cache, From 83fc4613e18b0358e759be092ff38869d7971950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B6=85=E6=B8=A1=E6=B3=95=E5=B8=AB?= Date: Tue, 30 Jun 2026 04:06:12 +0000 Subject: [PATCH 3/3] =?UTF-8?q?fix(acp):=20address=20review=20feedback=20f?= =?UTF-8?q?rom=20=E6=93=BA=E6=B8=A1=20and=20Z=E6=B8=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical fixes: - 🔴 Reject concurrent prompts per session (-32001 'Session busy') - 🔴 Abort prompt tasks on disconnect to prevent registry leaks - 🔴 Clean up registry entries after prompt completes (not just on Done) Important fixes: - 🟡 Check event_tx.send() result — return JSON-RPC error if no agent connected - 🟡 Remove unused agent-client-protocol dep (Phase 1 is manual JSON-RPC) - 🟡 Add 'acp' to unified feature list in root Cargo.toml - 🟡 Switch AcpReplyRegistry from tokio::sync::Mutex to std::sync::Mutex (operations are CPU-bound, never held across .await) Also: - Rewrite acp_server.rs for clarity after multiple incremental fixes - Add busy flag per session to track in-flight prompts - Prompt handler now properly cleans up registry on all exit paths --- Cargo.toml | 2 +- crates/openab-gateway/Cargo.toml | 3 +- .../openab-gateway/src/adapters/acp_server.rs | 142 +++++++++++++----- 3 files changed, 108 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e50e92ae8..fa29d3eda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ serenity = { version = "0.12", default-features = false, features = ["client", " default = ["discord", "slack", "secrets-aws", "agentcore", "config-s3", "pre-seed"] # Opt-in: compile all gateway adapters into a single unified binary -unified = ["telegram", "line", "feishu", "googlechat", "wecom", "teams"] +unified = ["telegram", "line", "feishu", "googlechat", "wecom", "teams", "acp"] # Core adapters discord = ["dep:serenity", "openab-core/discord"] diff --git a/crates/openab-gateway/Cargo.toml b/crates/openab-gateway/Cargo.toml index f2bfbba5e..8cd91e0ff 100644 --- a/crates/openab-gateway/Cargo.toml +++ b/crates/openab-gateway/Cargo.toml @@ -30,7 +30,6 @@ quick-xml = "0.37" image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } parking_lot = "0.12" urlencoding = "2" -agent-client-protocol = { version = "1.0", optional = true } [dev-dependencies] wiremock = "0.6" @@ -43,4 +42,4 @@ feishu = [] googlechat = [] wecom = [] teams = [] -acp = ["dep:agent-client-protocol"] +acp = [] diff --git a/crates/openab-gateway/src/adapters/acp_server.rs b/crates/openab-gateway/src/adapters/acp_server.rs index f853ec362..69990f56c 100644 --- a/crates/openab-gateway/src/adapters/acp_server.rs +++ b/crates/openab-gateway/src/adapters/acp_server.rs @@ -22,8 +22,8 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tracing::{info, warn}; +use tokio::sync::mpsc; +use tracing::{debug, info, warn}; use uuid::Uuid; // --------------------------------------------------------------------------- @@ -54,10 +54,12 @@ impl AcpConfig { // ACP Session tracking // --------------------------------------------------------------------------- -/// Tracks an active ACP session and its reply channel. +/// Tracks an active ACP session. struct AcpSession { /// Channel ID used in GatewayEvent (maps replies back to this session) channel_id: String, + /// Whether a prompt is currently in-flight for this session + busy: bool, } pub enum ReplyChunk { @@ -67,11 +69,13 @@ pub enum ReplyChunk { Done, } -/// Registry of active ACP sessions: channel_id → reply sender -pub type AcpReplyRegistry = Arc>>>; +/// Registry of active ACP sessions: channel_id → reply sender. +/// Uses std::sync::Mutex because all operations are fast CPU-bound +/// (insert/remove/get) and never hold the lock across .await. +pub type AcpReplyRegistry = Arc>>>; pub fn new_reply_registry() -> AcpReplyRegistry { - Arc::new(Mutex::new(HashMap::new())) + Arc::new(std::sync::Mutex::new(HashMap::new())) } // --------------------------------------------------------------------------- @@ -181,9 +185,13 @@ async fn handle_acp_connection(state: Arc, socket: WebSocket) { info!(connection = %connection_id, "ACP client connected"); // Session state for this connection - let sessions: Arc>> = Arc::new(Mutex::new(HashMap::new())); + let sessions: Arc>> = + Arc::new(tokio::sync::Mutex::new(HashMap::new())); let mut initialized = false; + // Track spawned prompt tasks so we can abort on disconnect + let mut prompt_tasks: Vec> = Vec::new(); + // Channel for sending messages back to the client let (out_tx, mut out_rx) = mpsc::unbounded_channel::(); @@ -240,13 +248,7 @@ async fn handle_acp_connection(state: Arc, socket: WebSocket) { let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); continue; } - let resp = handle_session_new( - &state, - &sessions, - id.clone(), - req.params.as_ref(), - ) - .await; + let resp = handle_session_new(&sessions, id.clone()).await; let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); } "session/prompt" => { @@ -259,7 +261,7 @@ async fn handle_acp_connection(state: Arc, socket: WebSocket) { let state_clone = state.clone(); let sessions_clone = sessions.clone(); let out_tx_clone = out_tx.clone(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { handle_session_prompt( &state_clone, &sessions_clone, @@ -269,6 +271,7 @@ async fn handle_acp_connection(state: Arc, socket: WebSocket) { ) .await; }); + prompt_tasks.push(handle); } "session/cancel" => { // TODO: implement cancellation in Phase 2 @@ -284,15 +287,35 @@ async fn handle_acp_connection(state: Arc, socket: WebSocket) { let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); } } + + // Clean up finished tasks + prompt_tasks.retain(|h| !h.is_finished()); + } + + // --- Disconnect cleanup --- + // Abort any in-flight prompt tasks to prevent registry leaks + for handle in prompt_tasks { + handle.abort(); } - // Cleanup: remove all sessions for this connection from the reply registry + // Remove all sessions for this connection from the reply registry if let Some(ref registry) = state.acp_reply_registry { let sessions_guard = sessions.lock().await; - let mut reg = registry.lock().await; - for (_, session) in sessions_guard.iter() { - reg.remove(&session.channel_id); + let channel_ids: Vec = sessions_guard + .values() + .map(|s| s.channel_id.clone()) + .collect(); + drop(sessions_guard); + + let mut reg = registry.lock().unwrap_or_else(|e| e.into_inner()); + for cid in &channel_ids { + reg.remove(cid); } + debug!( + connection = %connection_id, + sessions_cleaned = channel_ids.len(), + "ACP connection cleanup complete" + ); } send_task.abort(); @@ -327,10 +350,8 @@ fn handle_initialize(connection_id: &str, _req: &JsonRpcRequest) -> JsonRpcRespo } async fn handle_session_new( - _state: &Arc, - sessions: &Arc>>, + sessions: &Arc>>, id: Value, - _params: Option<&Value>, ) -> JsonRpcResponse { let session_id = format!("sess_{}", Uuid::new_v4()); let channel_id = format!("acp_{}", Uuid::new_v4()); @@ -340,6 +361,7 @@ async fn handle_session_new( session_id.clone(), AcpSession { channel_id, + busy: false, }, ); @@ -361,7 +383,7 @@ async fn handle_session_new( async fn handle_session_prompt( state: &Arc, - sessions: &Arc>>, + sessions: &Arc>>, id: Value, params: Option<&Value>, out_tx: &mpsc::UnboundedSender, @@ -376,11 +398,24 @@ async fn handle_session_prompt( } }; - // Look up session + // Look up session and acquire busy lock let channel_id = { - let guard = sessions.lock().await; - match guard.get(&session_id) { - Some(s) => s.channel_id.clone(), + let mut guard = sessions.lock().await; + match guard.get_mut(&session_id) { + Some(s) => { + if s.busy { + // Reject concurrent prompts on the same session + let resp = JsonRpcResponse::error( + id, + -32001, + "Session busy: a prompt is already in progress", + ); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + return; + } + s.busy = true; + s.channel_id.clone() + } None => { let resp = JsonRpcResponse::error( id, @@ -393,10 +428,13 @@ async fn handle_session_prompt( } }; - // Create reply channel for this prompt + // Create reply channel for this prompt and register it let (reply_tx, mut reply_rx) = mpsc::unbounded_channel::(); if let Some(ref registry) = state.acp_reply_registry { - registry.lock().await.insert(channel_id.clone(), reply_tx); + registry + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(channel_id.clone(), reply_tx); } // Convert to GatewayEvent and dispatch @@ -421,12 +459,33 @@ async fn handle_session_prompt( // Send event through the broadcast channel match serde_json::to_string(&event) { Ok(json) => { - let _ = state.event_tx.send(json); + if state.event_tx.send(json).is_err() { + // No receivers — agent/core not connected + warn!("ACP: event_tx send failed — no agent connected"); + let resp = JsonRpcResponse::error( + id, + -32603, + "No agent backend connected", + ); + let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + // Release busy flag + if let Some(s) = sessions.lock().await.get_mut(&session_id) { + s.busy = false; + } + // Cleanup registry + if let Some(ref registry) = state.acp_reply_registry { + registry.lock().unwrap_or_else(|e| e.into_inner()).remove(&channel_id); + } + return; + } } Err(e) => { warn!("ACP: failed to serialize event: {e}"); let resp = JsonRpcResponse::error(id, -32603, "Internal error"); let _ = out_tx.send(serde_json::to_string(&resp).unwrap()); + if let Some(s) = sessions.lock().await.get_mut(&session_id) { + s.busy = false; + } return; } } @@ -473,6 +532,14 @@ async fn handle_session_prompt( } } + // Cleanup: remove from registry and release busy flag + if let Some(ref registry) = state.acp_reply_registry { + registry.lock().unwrap_or_else(|e| e.into_inner()).remove(&channel_id); + } + if let Some(s) = sessions.lock().await.get_mut(&session_id) { + s.busy = false; + } + // Send the final response let resp = JsonRpcResponse::success( id, @@ -536,24 +603,27 @@ pub async fn handle_reply(reply: &GatewayReply, registry: &AcpReplyRegistry) { return; } - let mut map = registry.lock().await; - let Some(tx) = map.get(key) else { - return; + let tx = { + let map = registry.lock().unwrap_or_else(|e| e.into_inner()); + match map.get(key) { + Some(tx) => tx.clone(), + None => return, + } }; match reply.command.as_deref() { Some("edit_message") => { // Streaming update — send as text snapshot if tx.send(ReplyChunk::Text(full_text)).is_err() { - tracing::debug!(channel = key, "ACP reply send failed (client likely disconnected)"); - map.remove(key); + debug!(channel = key, "ACP reply send failed (client likely disconnected)"); + registry.lock().unwrap_or_else(|e| e.into_inner()).remove(key); } } None | Some("send_message") => { // Final message let _ = tx.send(ReplyChunk::Text(full_text)); let _ = tx.send(ReplyChunk::Done); - map.remove(key); + registry.lock().unwrap_or_else(|e| e.into_inner()).remove(key); } Some("add_reaction") | Some("remove_reaction") => { // Reactions are agent state indicators — could map to notifications later