diff --git a/Cargo.lock b/Cargo.lock index 50d61a82..83128d89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,6 +393,61 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.8.1", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backon" version = "1.6.0" @@ -2574,6 +2629,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "memchr" version = "2.7.6" @@ -3207,6 +3268,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "axum", "backon", "bytes", "cacache", @@ -3234,6 +3296,7 @@ dependencies = [ "tempfile", "thiserror 2.0.17", "tokio", + "tower-http 0.5.2", "tracing", "url", ] @@ -3268,6 +3331,7 @@ dependencies = [ "paws_app", "paws_common", "paws_domain", + "paws_infra", "paws_services", "pretty_assertions", "reedline", @@ -3929,7 +3993,7 @@ dependencies = [ "tokio-rustls", "tokio-util", "tower", - "tower-http", + "tower-http 0.6.8", "tower-service", "url", "wasm-bindgen", @@ -4966,6 +5030,24 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", ] [[package]] @@ -5004,6 +5086,7 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index 91ff0fbf..0ac806ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,9 @@ strip = true anyhow = "1.0.99" async-recursion = "1.1.1" async-trait = "0.1.89" + +axum = "0.7.5" +tower-http = "0.5.2" aws-config = { version = "1.8.12", features = ["behavior-version-latest"], default-features = false } aws-sdk-bedrockruntime = { version = "1.120.0", features = ["behavior-version-latest"], default-features = false } aws-smithy-types = "1.3" diff --git a/crates/paws_app/src/agent_protocol_service.rs b/crates/paws_app/src/agent_protocol_service.rs new file mode 100644 index 00000000..cd4f916c --- /dev/null +++ b/crates/paws_app/src/agent_protocol_service.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use anyhow::Result; +use paws_domain::{Step, StepInput, Task}; +use tokio::sync::RwLock; + +pub struct AgentProtocolService { + tasks: Arc>>, + steps: Arc>>>, +} + +impl AgentProtocolService { + pub fn new() -> Self { + Self { + tasks: Arc::new(RwLock::new(HashMap::new())), + steps: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn create_task(&self, input: String) -> Task { + let task = Task::new(input); + let mut tasks = self.tasks.write().await; + tasks.insert(task.task_id.clone(), task.clone()); + task + } + + pub async fn list_tasks(&self) -> Vec { + let tasks = self.tasks.read().await; + tasks.values().cloned().collect() + } + + pub async fn get_task(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks.get(task_id).cloned() + } + + pub async fn list_steps(&self, task_id: &str) -> Vec { + let steps = self.steps.read().await; + steps.get(task_id).cloned().unwrap_or_default() + } + + pub async fn create_step(&self, task_id: &str, input: StepInput) -> Result { + let mut steps_guard = self.steps.write().await; + let task_steps = steps_guard.entry(task_id.to_string()).or_default(); + + let step = Step::new(task_id.to_string(), input.input, true); + task_steps.push(step.clone()); + + Ok(step) + } + + pub async fn get_step(&self, task_id: &str, step_id: &str) -> Option { + let steps = self.steps.read().await; + steps + .get(task_id) + .and_then(|steps| steps.iter().find(|s| s.step_id == step_id).cloned()) + } +} + +impl Default for AgentProtocolService { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/paws_app/src/lib.rs b/crates/paws_app/src/lib.rs index 66e77d6e..04a13889 100644 --- a/crates/paws_app/src/lib.rs +++ b/crates/paws_app/src/lib.rs @@ -1,5 +1,6 @@ mod agent; mod agent_executor; +pub mod agent_protocol_service; mod agent_provider_resolver; mod app; mod apply_tunable_parameters; @@ -39,6 +40,8 @@ pub mod utils; mod walker; pub use agent::*; +pub use agent_executor::*; +pub use agent_protocol_service::*; pub use agent_provider_resolver::*; pub use app::*; pub use command_generator::*; diff --git a/crates/paws_domain/src/agent_protocol.rs b/crates/paws_domain/src/agent_protocol.rs new file mode 100644 index 00000000..6a58ff9a --- /dev/null +++ b/crates/paws_domain/src/agent_protocol.rs @@ -0,0 +1,73 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + pub task_id: String, + pub input: String, + pub artifacts: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskInput { + pub input: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Step { + pub task_id: String, + pub step_id: String, + pub name: Option, + pub input: Option, + pub output: Option, + pub status: StepStatus, + pub is_last: bool, + pub artifacts: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum StepStatus { + Created, + Running, + Completed, + Failed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StepInput { + pub name: Option, + pub input: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Artifact { + pub artifact_id: String, + pub file_name: String, + pub relative_path: Option, +} + +impl Task { + pub fn new(input: String) -> Self { + Self { + task_id: Uuid::new_v4().to_string(), + input, + artifacts: Vec::new(), + } + } +} + +impl Step { + pub fn new(task_id: String, input: Option, is_last: bool) -> Self { + Self { + task_id, + step_id: Uuid::new_v4().to_string(), + name: None, + input, + output: None, + status: StepStatus::Created, + is_last, + artifacts: Vec::new(), + } + } +} diff --git a/crates/paws_domain/src/lib.rs b/crates/paws_domain/src/lib.rs index 6435df52..bbf8353b 100644 --- a/crates/paws_domain/src/lib.rs +++ b/crates/paws_domain/src/lib.rs @@ -1,5 +1,6 @@ mod agent; mod agent_definition; +pub mod agent_protocol; mod app_config; mod attachment; mod auth; @@ -53,6 +54,7 @@ mod xml; pub use agent::*; pub use agent_definition::*; +pub use agent_protocol::*; pub use attachment::*; pub use chat_request::*; pub use chat_response::*; diff --git a/crates/paws_infra/Cargo.toml b/crates/paws_infra/Cargo.toml index f2fe42cd..6e3d7e1c 100644 --- a/crates/paws_infra/Cargo.toml +++ b/crates/paws_infra/Cargo.toml @@ -100,6 +100,13 @@ workspace = true [dependencies.paws_common] workspace = true +[dependencies.axum] +workspace = true + +[dependencies.tower-http] +workspace = true +features = ["cors", "trace"] + [dev-dependencies.tokio] workspace = true features = [ "macros", "rt", "time", "test-util",] diff --git a/crates/paws_infra/src/agent_protocol_server.rs b/crates/paws_infra/src/agent_protocol_server.rs new file mode 100644 index 00000000..7df4d95a --- /dev/null +++ b/crates/paws_infra/src/agent_protocol_server.rs @@ -0,0 +1,94 @@ +use std::sync::Arc; + +use axum::extract::{Path, State}; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use paws_app::AgentProtocolService; +use paws_domain::{Step, StepInput, Task, TaskInput}; +use tower_http::cors::CorsLayer; + +pub struct AgentProtocolServer { + service: Arc, +} + +impl AgentProtocolServer { + pub fn new() -> Self { + Self { service: Arc::new(AgentProtocolService::new()) } + } + + pub async fn serve(&self, host: &str, port: u16) -> anyhow::Result<()> { + let app = Router::new() + .route("/agent/tasks", post(create_task).get(list_tasks)) + .route("/agent/tasks/:task_id", get(get_task)) + .route( + "/agent/tasks/:task_id/steps", + post(execute_step).get(list_steps), + ) + .route("/agent/tasks/:task_id/steps/:step_id", get(get_step)) + .layer(CorsLayer::permissive()) + .with_state(self.service.clone()); + + let listener = tokio::net::TcpListener::bind(format!("{}:{}", host, port)).await?; + tracing::info!("Agent Protocol server listening on {}:{}", host, port); + axum::serve(listener, app).await?; + Ok(()) + } +} + +async fn create_task( + State(service): State>, + Json(input): Json, +) -> Json { + Json(service.create_task(input.input).await) +} + +async fn list_tasks(State(service): State>) -> Json> { + Json(service.list_tasks().await) +} + +async fn get_task( + State(service): State>, + Path(task_id): Path, +) -> Result, axum::http::StatusCode> { + service + .get_task(&task_id) + .await + .map(Json) + .ok_or(axum::http::StatusCode::NOT_FOUND) +} + +async fn list_steps( + State(service): State>, + Path(task_id): Path, +) -> Json> { + Json(service.list_steps(&task_id).await) +} + +async fn execute_step( + State(service): State>, + Path(task_id): Path, + Json(input): Json, +) -> Result, axum::http::StatusCode> { + service + .create_step(&task_id, input) + .await + .map(Json) + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR) +} + +async fn get_step( + State(service): State>, + Path((task_id, step_id)): Path<(String, String)>, +) -> Result, axum::http::StatusCode> { + service + .get_step(&task_id, &step_id) + .await + .map(Json) + .ok_or(axum::http::StatusCode::NOT_FOUND) +} + +impl Default for AgentProtocolServer { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/paws_infra/src/lib.rs b/crates/paws_infra/src/lib.rs index 7b9fce9b..dfd28aed 100644 --- a/crates/paws_infra/src/lib.rs +++ b/crates/paws_infra/src/lib.rs @@ -1,5 +1,6 @@ pub mod executor; +pub mod agent_protocol_server; mod auth; mod env; mod error; diff --git a/crates/paws_main/Cargo.toml b/crates/paws_main/Cargo.toml index 5a85dbaf..79a841fc 100644 --- a/crates/paws_main/Cargo.toml +++ b/crates/paws_main/Cargo.toml @@ -110,6 +110,9 @@ workspace = true [dependencies.async-recursion] workspace = true +[dependencies.paws_infra] +workspace = true + [dependencies.paws_common] workspace = true diff --git a/crates/paws_main/src/cli.rs b/crates/paws_main/src/cli.rs index cd1457cd..cb15a6b7 100644 --- a/crates/paws_main/src/cli.rs +++ b/crates/paws_main/src/cli.rs @@ -133,6 +133,17 @@ pub enum TopLevelCommand { /// Manage API provider authentication. Provider(ProviderCommandGroup), + /// Start Agent Protocol server. + Serve { + /// Port to listen on. + #[arg(long, short = 'p', default_value = "8000")] + port: u16, + + /// Host to bind to. + #[arg(long, short = 'H', default_value = "127.0.0.1")] + host: String, + }, + /// Run or list custom commands. Cmd(CmdCommandGroup), @@ -1267,6 +1278,28 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_serve_command_args() { + let fixture = Cli::parse_from(["paws", "serve", "--port", "9000", "--host", "0.0.0.0"]); + let (port, host) = match fixture.subcommands { + Some(TopLevelCommand::Serve { port, host }) => (port, host), + _ => panic!("Expected TopLevelCommand::Serve"), + }; + assert_eq!(port, 9000); + assert_eq!(host, "0.0.0.0"); + } + + #[test] + fn test_serve_command_defaults() { + let fixture = Cli::parse_from(["paws", "serve"]); + let (port, host) = match fixture.subcommands { + Some(TopLevelCommand::Serve { port, host }) => (port, host), + _ => panic!("Expected TopLevelCommand::Serve"), + }; + assert_eq!(port, 8000); + assert_eq!(host, "127.0.0.1"); + } + #[test] fn test_prompt_with_leading_hyphen() { let fixture = Cli::parse_from(["paws", "-p", "- hi"]); diff --git a/crates/paws_main/src/ui.rs b/crates/paws_main/src/ui.rs index 3eaa1221..552b81da 100644 --- a/crates/paws_main/src/ui.rs +++ b/crates/paws_main/src/ui.rs @@ -640,6 +640,12 @@ impl A + Send + Sync> UI { self.writeln(data?)?; } } + + TopLevelCommand::Serve { port, host } => { + use paws_infra::agent_protocol_server::AgentProtocolServer; + let server = AgentProtocolServer::new(); + server.serve(&host, port).await?; + } } Ok(()) }