diff --git a/Cargo.lock b/Cargo.lock index 87adc5e2b..d7ccd89b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3683,6 +3683,7 @@ version = "0.0.0" dependencies = [ "anyhow", "axum 0.8.9", + "base64 0.22.1", "bytes", "clap", "futures", @@ -3738,6 +3739,7 @@ dependencies = [ "tower-http 0.6.8", "tracing", "tracing-subscriber", + "url", "uuid", "wiremock", "x509-parser", diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index bd4262b31..313ebda30 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -155,6 +155,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 7102ed9b6..9c9c977e9 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -130,6 +130,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 49b933e67..809538c51 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -239,6 +239,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 1ad00dd6e..26227f3de 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -175,6 +175,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, request: tonic::Request, diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index 531599dcf..3cf941fe3 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -152,6 +152,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-providers/src/lib.rs b/crates/openshell-providers/src/lib.rs index 3b28030ca..03534f646 100644 --- a/crates/openshell-providers/src/lib.rs +++ b/crates/openshell-providers/src/lib.rs @@ -91,6 +91,7 @@ impl ProviderRegistry { registry.register(providers::nvidia::NvidiaProvider); registry.register(providers::gitlab::GitlabProvider); registry.register(providers::github::GithubProvider); + registry.register(providers::microsoft_agent_s2s::MicrosoftAgentS2sProvider); registry.register(providers::outlook::OutlookProvider); registry } @@ -153,6 +154,7 @@ pub fn normalize_provider_type(input: &str) -> Option<&'static str> { "nvidia" => Some("nvidia"), "gitlab" | "glab" => Some("gitlab"), "github" | "gh" => Some("github"), + "microsoft-agent-s2s" => Some("microsoft-agent-s2s"), "outlook" => Some("outlook"), _ => None, } @@ -183,6 +185,10 @@ mod tests { assert_eq!(normalize_provider_type("anthropic"), Some("anthropic")); assert_eq!(normalize_provider_type("nvidia"), Some("nvidia")); assert_eq!(normalize_provider_type("copilot"), Some("copilot")); + assert_eq!( + normalize_provider_type("microsoft-agent-s2s"), + Some("microsoft-agent-s2s") + ); assert_eq!(normalize_provider_type("unknown"), None); } diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 588e77702..dbbe13934 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -21,6 +21,7 @@ const BUILT_IN_PROFILE_YAMLS: &[&str] = &[ include_str!("../../../providers/copilot.yaml"), include_str!("../../../providers/github.yaml"), include_str!("../../../providers/gitlab.yaml"), + include_str!("../../../providers/microsoft-agent-s2s.yaml"), include_str!("../../../providers/nvidia.yaml"), include_str!("../../../providers/openai.yaml"), include_str!("../../../providers/opencode.yaml"), @@ -883,6 +884,23 @@ mod tests { assert_eq!(proto.binaries.len(), 4); } + #[test] + fn microsoft_agent_s2s_profile_is_available() { + let profile = + get_default_profile("microsoft-agent-s2s").expect("microsoft-agent-s2s profile"); + let proto = profile.to_proto(); + + assert_eq!(proto.id, "microsoft-agent-s2s"); + assert_eq!(proto.category, ProviderProfileCategory::Agent as i32); + assert_eq!(proto.credentials.len(), 1); + assert_eq!(proto.credentials[0].name, "blueprint_client_secret"); + assert_eq!( + proto.credentials[0].env_vars, + vec!["A365_BLUEPRINT_CLIENT_SECRET"] + ); + assert!(proto.endpoints.is_empty()); + } + #[test] fn credential_env_vars_are_deduplicated_in_profile_order() { let profile = get_default_profile("copilot").expect("copilot profile"); diff --git a/crates/openshell-providers/src/providers/microsoft_agent_s2s.rs b/crates/openshell-providers/src/providers/microsoft_agent_s2s.rs new file mode 100644 index 000000000..2e9a19a9c --- /dev/null +++ b/crates/openshell-providers/src/providers/microsoft_agent_s2s.rs @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. +// SPDX-License-Identifier: Apache-2.0 + +use crate::DiscoveryContext; +use crate::{DiscoveredProvider, ProviderError, ProviderPlugin, RealDiscoveryContext}; + +pub struct MicrosoftAgentS2sProvider; + +const CREDENTIAL_ENV_VARS: &[&str] = &["A365_BLUEPRINT_CLIENT_SECRET"]; +const CONFIG_ENV_VARS: &[&str] = &[ + "AZURE_TENANT_ID", + "A365_BLUEPRINT_CLIENT_ID", + "A365_RUNTIME_AGENT_ID", + "A365_ALLOWED_AUDIENCES", + "A365_OBSERVABILITY_RESOURCE", + "A365_REQUIRED_ROLES", +]; + +impl ProviderPlugin for MicrosoftAgentS2sProvider { + fn id(&self) -> &'static str { + "microsoft-agent-s2s" + } + + fn discover_existing(&self) -> Result, ProviderError> { + discover_microsoft_agent_s2s(&RealDiscoveryContext) + } + + fn credential_env_vars(&self) -> &'static [&'static str] { + CREDENTIAL_ENV_VARS + } +} + +fn discover_microsoft_agent_s2s( + context: &dyn DiscoveryContext, +) -> Result, ProviderError> { + let mut discovered = DiscoveredProvider::default(); + + for key in CREDENTIAL_ENV_VARS { + if let Some(value) = context.env_var(key) + && !value.trim().is_empty() + { + discovered + .credentials + .entry((*key).to_string()) + .or_insert(value); + } + } + + for key in CONFIG_ENV_VARS { + if let Some(value) = context.env_var(key) + && !value.trim().is_empty() + { + discovered.config.entry((*key).to_string()).or_insert(value); + } + } + + if discovered.is_empty() { + Ok(None) + } else { + Ok(Some(discovered)) + } +} + +#[cfg(test)] +mod tests { + use super::discover_microsoft_agent_s2s; + use crate::test_helpers::MockDiscoveryContext; + + #[test] + fn discovers_microsoft_agent_s2s_env_state() { + let ctx = MockDiscoveryContext::new() + .with_env("AZURE_TENANT_ID", "tenant-id") + .with_env("A365_BLUEPRINT_CLIENT_ID", "blueprint-client-id") + .with_env("A365_BLUEPRINT_CLIENT_SECRET", "blueprint-secret") + .with_env("A365_RUNTIME_AGENT_ID", "runtime-agent-id") + .with_env("A365_ALLOWED_AUDIENCES", "api://aud-a,api://aud-b") + .with_env("A365_OBSERVABILITY_RESOURCE", "observability-resource") + .with_env("A365_REQUIRED_ROLES", "Agent365.Observability.OtelWrite"); + let discovered = discover_microsoft_agent_s2s(&ctx) + .expect("discovery") + .expect("provider"); + assert_eq!( + discovered.credentials.get("A365_BLUEPRINT_CLIENT_SECRET"), + Some(&"blueprint-secret".to_string()) + ); + assert_eq!( + discovered.config.get("AZURE_TENANT_ID"), + Some(&"tenant-id".to_string()) + ); + assert_eq!( + discovered.config.get("A365_BLUEPRINT_CLIENT_ID"), + Some(&"blueprint-client-id".to_string()) + ); + assert_eq!( + discovered.config.get("A365_RUNTIME_AGENT_ID"), + Some(&"runtime-agent-id".to_string()) + ); + assert_eq!( + discovered.config.get("A365_ALLOWED_AUDIENCES"), + Some(&"api://aud-a,api://aud-b".to_string()) + ); + assert_eq!( + discovered.config.get("A365_OBSERVABILITY_RESOURCE"), + Some(&"observability-resource".to_string()) + ); + assert_eq!( + discovered.config.get("A365_REQUIRED_ROLES"), + Some(&"Agent365.Observability.OtelWrite".to_string()) + ); + } +} diff --git a/crates/openshell-providers/src/providers/mod.rs b/crates/openshell-providers/src/providers/mod.rs index 6fe395135..966c7058b 100644 --- a/crates/openshell-providers/src/providers/mod.rs +++ b/crates/openshell-providers/src/providers/mod.rs @@ -8,6 +8,7 @@ pub mod copilot; pub mod generic; pub mod github; pub mod gitlab; +pub mod microsoft_agent_s2s; pub mod nvidia; pub mod openai; pub mod opencode; diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 28492b543..7d7c9deb0 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -10,7 +10,8 @@ use std::time::Duration; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ DenialSummary, GetDraftPolicyRequest, GetInferenceBundleRequest, GetInferenceBundleResponse, - GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, PolicyChunk, PolicySource, + GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, MintSandboxProviderTokenRequest, + PolicyChunk, PolicySource, PolicyStatus, ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, UpdateConfigRequest, inference_client::InferenceClient, open_shell_client::OpenShellClient, @@ -230,6 +231,39 @@ pub async fn fetch_provider_environment( }) } +pub async fn mint_provider_token( + endpoint: &str, + sandbox_id: &str, + provider_name: &str, + audience: &str, +) -> Result { + debug!( + endpoint = %endpoint, + sandbox_id = %sandbox_id, + provider_name = %provider_name, + audience = %audience, + "Minting sandbox provider token" + ); + + let mut client = connect(endpoint).await?; + let response = client + .mint_sandbox_provider_token(MintSandboxProviderTokenRequest { + sandbox_id: sandbox_id.to_string(), + provider_name: provider_name.to_string(), + audience: audience.to_string(), + }) + .await + .into_diagnostic()?; + + let inner = response.into_inner(); + Ok(MintedProviderToken { + access_token: inner.access_token, + token_type: inner.token_type, + expires_at_unix: inner.expires_at_unix, + cache_hit: inner.cache_hit, + }) +} + /// A reusable gRPC client for the `OpenShell` service. /// /// Wraps a tonic channel connected once and reused for policy polling @@ -258,6 +292,13 @@ pub struct ProviderEnvironmentResult { pub provider_env_revision: u64, } +pub struct MintedProviderToken { + pub access_token: String, + pub token_type: String, + pub expires_at_unix: u64, + pub cache_hit: bool, +} + impl CachedOpenShellClient { pub async fn connect(endpoint: &str) -> Result { debug!(endpoint = %endpoint, "Connecting openshell gRPC client for policy polling"); diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index e297b9262..4c3a12a8b 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -19,6 +19,7 @@ mod policy_local; mod process; pub mod procfs; mod provider_credentials; +mod provider_tokens; pub mod proxy; mod sandbox; mod secrets; @@ -351,7 +352,7 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. // This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let (provider_env_revision, provider_env) = + let (provider_env_revision, mut provider_env) = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { match grpc_client::fetch_provider_environment(endpoint, id).await { Ok(result) => { @@ -385,12 +386,8 @@ pub async fn run_sandbox( } else { (0, std::collections::HashMap::new()) }; - - let provider_credentials = provider_credentials::ProviderCredentialState::from_environment( - provider_env_revision, - provider_env, - ); - let provider_env = provider_credentials.snapshot().child_env.clone(); + let provider_token_resolver_port = + provider_tokens::microsoft_agent_s2s_resolver_port(&provider_env); // Create identity cache for SHA256 TOFU when OPA is active let identity_cache = opa_engine @@ -518,7 +515,11 @@ pub async fn run_sandbox( .as_ref() .and_then(|p| p.http_addr) .map_or(3128, |addr| addr.port()); - if let Err(e) = ns.install_bypass_rules(proxy_port) { + let provider_token_ports = provider_token_resolver_port + .iter() + .copied() + .collect::>(); + if let Err(e) = ns.install_bypass_rules(proxy_port, &provider_token_ports) { ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) .severity(SeverityId::Medium) @@ -549,6 +550,56 @@ pub async fn run_sandbox( #[allow(clippy::no_effect_underscore_binding)] let _netns: Option<()> = None; + #[cfg(target_os = "linux")] + let provider_token_resolver_bind_addr = { + let ip = netns.as_ref().map_or( + std::net::IpAddr::from([127, 0, 0, 1]), + NetworkNamespace::host_ip, + ); + SocketAddr::new(ip, provider_token_resolver_port.unwrap_or(0)) + }; + + #[cfg(not(target_os = "linux"))] + let provider_token_resolver_bind_addr = + SocketAddr::from(([127, 0, 0, 1], provider_token_resolver_port.unwrap_or(0))); + + let prepared_provider_tokens = if let (Some(endpoint), Some(id)) = + (openshell_endpoint.as_deref(), sandbox_id.as_deref()) + { + provider_tokens::prepare_microsoft_agent_s2s( + &mut provider_env, + provider_token_resolver_bind_addr, + endpoint, + id, + ) + .await? + } else { + provider_tokens::strip_microsoft_agent_s2s_inputs(&mut provider_env); + provider_tokens::PreparedProviderTokenResolver { + environment: std::collections::HashMap::new(), + handle: None, + } + }; + if !prepared_provider_tokens.environment.is_empty() { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "enabled") + .message("Started microsoft-agent-s2s provider token resolver") + .build() + ); + } + let provider_runtime_env = prepared_provider_tokens.environment; + let _provider_token_resolver = prepared_provider_tokens.handle; + + let provider_credentials = provider_credentials::ProviderCredentialState::from_environment( + provider_env_revision, + provider_env, + ); + let mut provider_env = provider_credentials.snapshot().child_env.clone(); + provider_env.extend(provider_runtime_env.clone()); + // Install the supervisor seccomp prelude after privileged startup helpers // (network namespace setup, iptables probes) complete, but before the SSH // listener and workload process are exposed. @@ -748,6 +799,7 @@ pub async fn run_sandbox( let netns_fd = ssh_netns_fd; let ca_paths = ca_file_paths.clone(); let provider_credentials_clone = provider_credentials.clone(); + let provider_runtime_env_clone = provider_runtime_env.clone(); let (ssh_ready_tx, ssh_ready_rx) = tokio::sync::oneshot::channel(); @@ -761,6 +813,7 @@ pub async fn run_sandbox( proxy_url, ca_paths, provider_credentials_clone, + provider_runtime_env_clone, ) .await { @@ -925,6 +978,7 @@ pub async fn run_sandbox( let poll_ocsf_enabled = ocsf_enabled.clone(); let poll_pid = entrypoint_pid.clone(); let poll_provider_credentials = provider_credentials.clone(); + let poll_provider_runtime_env = provider_runtime_env.clone(); let poll_policy_local = policy_local_ctx.clone(); let poll_interval_secs: u64 = std::env::var("OPENSHELL_POLICY_POLL_INTERVAL_SECS") .ok() @@ -938,6 +992,7 @@ pub async fn run_sandbox( interval_secs: poll_interval_secs, ocsf_enabled: poll_ocsf_enabled, provider_credentials: poll_provider_credentials, + provider_runtime_env: poll_provider_runtime_env, policy_local_ctx: Some(poll_policy_local), }; @@ -2284,6 +2339,7 @@ struct PolicyPollLoopContext { interval_secs: u64, ocsf_enabled: Arc, provider_credentials: provider_credentials::ProviderCredentialState, + provider_runtime_env: std::collections::HashMap, policy_local_ctx: Option>, } @@ -2355,11 +2411,12 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { if provider_env_changed { match grpc_client::fetch_provider_environment(&ctx.endpoint, &ctx.sandbox_id).await { - Ok(env_result) => { + Ok(mut env_result) => { + provider_tokens::strip_microsoft_agent_s2s_inputs(&mut env_result.environment); let env_count = ctx.provider_credentials.install_environment( env_result.provider_env_revision, env_result.environment, - ); + ) + ctx.provider_runtime_env.len(); current_provider_env_revision = env_result.provider_env_revision; ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) diff --git a/crates/openshell-sandbox/src/provider_tokens.rs b/crates/openshell-sandbox/src/provider_tokens.rs new file mode 100644 index 000000000..563665f99 --- /dev/null +++ b/crates/openshell-sandbox/src/provider_tokens.rs @@ -0,0 +1,576 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Sandbox-local provider token resolvers. + +use crate::grpc_client; +use miette::{IntoDiagnostic, Result, WrapErr}; +use serde::Deserialize; +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; + +const MAX_REQUEST_HEADER_BYTES: usize = 8192; +const MAX_REQUEST_BODY_BYTES: usize = 16 * 1024; +const MICROSOFT_AGENT_S2S_TOKEN_PATH: &str = "/v1/microsoft-agent-s2s/token"; +const MICROSOFT_AGENT_S2S_RESOLVER_PORT: u16 = 3130; +const TOKEN_URL_ENV: &str = "OPENSHELL_MICROSOFT_AGENT_S2S_TOKEN_URL"; +const TOKEN_PROVIDER_URL_ENV: &str = "OPENSHELL_MICROSOFT_AGENT_S2S_TOKEN_PROVIDER_URL"; +const DEFAULT_AUDIENCE_ENV: &str = "OPENSHELL_MICROSOFT_AGENT_S2S_DEFAULT_AUDIENCE"; +const A365_TOKEN_PROVIDER_URL_ENV: &str = "A365_TOKEN_PROVIDER_URL"; +const PROVIDER_NAME_ENV: &str = "OPENSHELL_MICROSOFT_AGENT_S2S_PROVIDER_NAME"; +const SANDBOX_HOST_BYPASS_IP: &str = "10.200.0.1"; + +const MICROSOFT_AGENT_S2S_KEYS: &[&str] = &[ + "AZURE_TENANT_ID", + "A365_BLUEPRINT_CLIENT_ID", + "A365_BLUEPRINT_CLIENT_SECRET", + "A365_RUNTIME_AGENT_ID", + "A365_ALLOWED_AUDIENCES", + "A365_OBSERVABILITY_RESOURCE", + "A365_REQUIRED_ROLES", + PROVIDER_NAME_ENV, +]; + +const MICROSOFT_AGENT_S2S_MARKER_KEYS: &[&str] = &[ + "A365_BLUEPRINT_CLIENT_ID", + "A365_BLUEPRINT_CLIENT_SECRET", + "A365_RUNTIME_AGENT_ID", + PROVIDER_NAME_ENV, +]; + +pub(crate) struct PreparedProviderTokenResolver { + pub environment: HashMap, + pub handle: Option, +} + +pub(crate) fn microsoft_agent_s2s_resolver_port( + provider_env: &HashMap, +) -> Option { + contains_microsoft_agent_s2s_inputs(provider_env).then_some(MICROSOFT_AGENT_S2S_RESOLVER_PORT) +} + +pub(crate) fn strip_microsoft_agent_s2s_inputs( + provider_env: &mut HashMap, +) -> Option { + if !contains_microsoft_agent_s2s_inputs(provider_env) { + return None; + } + + let provider_name = provider_env.get(PROVIDER_NAME_ENV).cloned(); + for key in MICROSOFT_AGENT_S2S_KEYS { + provider_env.remove(*key); + } + provider_name +} + +#[derive(Debug)] +pub(crate) struct ProviderTokenResolverHandle { + local_addr: SocketAddr, + token_path: String, + join: JoinHandle<()>, +} + +impl ProviderTokenResolverHandle { + fn url(&self) -> String { + format!("http://{}{}", self.local_addr, self.token_path) + } +} + +impl Drop for ProviderTokenResolverHandle { + fn drop(&mut self) { + self.join.abort(); + } +} + +pub(crate) async fn prepare_microsoft_agent_s2s( + raw_provider_env: &mut HashMap, + bind_addr: SocketAddr, + endpoint: &str, + sandbox_id: &str, +) -> Result { + if !contains_microsoft_agent_s2s_inputs(raw_provider_env) { + return Ok(PreparedProviderTokenResolver { + environment: HashMap::new(), + handle: None, + }); + } + + let provider_name = raw_provider_env + .get(PROVIDER_NAME_ENV) + .cloned() + .filter(|name| !name.trim().is_empty()) + .ok_or_else(|| miette::miette!("missing microsoft-agent-s2s provider name"))?; + let default_audience = default_audience(raw_provider_env); + let handle = start_microsoft_agent_s2s_resolver( + endpoint.to_string(), + sandbox_id.to_string(), + provider_name, + default_audience.clone(), + bind_addr, + ) + .await?; + + strip_microsoft_agent_s2s_inputs(raw_provider_env); + let environment = resolver_environment(handle.url(), default_audience); + + Ok(PreparedProviderTokenResolver { + environment, + handle: Some(handle), + }) +} + +fn contains_microsoft_agent_s2s_inputs(provider_env: &HashMap) -> bool { + MICROSOFT_AGENT_S2S_MARKER_KEYS + .iter() + .any(|key| provider_env.contains_key(*key)) +} + +fn default_audience(provider_env: &HashMap) -> Option { + provider_env + .get("A365_OBSERVABILITY_RESOURCE") + .cloned() + .or_else(|| { + provider_env + .get("A365_ALLOWED_AUDIENCES") + .map(|value| split_csv(value)) + .and_then(|values| match values.as_slice() { + [only] => Some(only.clone()), + _ => None, + }) + }) +} + +fn split_csv(value: &str) -> Vec { + value + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToString::to_string) + .collect() +} + +fn resolver_environment( + resolver_url: String, + default_audience: Option, +) -> HashMap { + let mut environment = HashMap::from([ + (TOKEN_URL_ENV.to_string(), resolver_url.clone()), + (TOKEN_PROVIDER_URL_ENV.to_string(), resolver_url.clone()), + (A365_TOKEN_PROVIDER_URL_ENV.to_string(), resolver_url), + ("NO_PROXY".to_string(), SANDBOX_HOST_BYPASS_IP.to_string()), + ("no_proxy".to_string(), SANDBOX_HOST_BYPASS_IP.to_string()), + ]); + if let Some(audience) = default_audience { + environment.insert(DEFAULT_AUDIENCE_ENV.to_string(), audience); + } + environment +} + +async fn start_microsoft_agent_s2s_resolver( + endpoint: String, + sandbox_id: String, + provider_name: String, + default_audience: Option, + bind_addr: SocketAddr, +) -> Result { + let listener = TcpListener::bind(bind_addr).await.into_diagnostic()?; + let local_addr = listener.local_addr().into_diagnostic()?; + let token_path = format!("{MICROSOFT_AGENT_S2S_TOKEN_PATH}/{}", uuid::Uuid::new_v4()); + let token_path_for_task = token_path.clone(); + + let join = tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((stream, _peer)) => { + let endpoint = endpoint.clone(); + let sandbox_id = sandbox_id.clone(); + let provider_name = provider_name.clone(); + let default_audience = default_audience.clone(); + let token_path = token_path_for_task.clone(); + tokio::spawn(async move { + if let Err(err) = handle_microsoft_agent_s2s_connection( + stream, + &endpoint, + &sandbox_id, + &provider_name, + default_audience, + token_path, + ) + .await + { + warn!(error = %err, "microsoft-agent-s2s token resolver request failed"); + } + }); + } + Err(err) => { + warn!(error = %err, "microsoft-agent-s2s token resolver accept failed"); + break; + } + } + } + }); + + Ok(ProviderTokenResolverHandle { + local_addr, + token_path, + join, + }) +} + +async fn handle_microsoft_agent_s2s_connection( + mut stream: TcpStream, + endpoint: &str, + sandbox_id: &str, + provider_name: &str, + default_audience: Option, + token_path: String, +) -> Result<()> { + let request = read_http_request(&mut stream).await?; + let response = match parse_token_request(&request, default_audience.as_deref(), &token_path) { + Ok(audience) => { + match grpc_client::mint_provider_token(endpoint, sandbox_id, provider_name, &audience) + .await + { + Ok(token) => json_response( + 200, + "OK", + serde_json::json!({ + "access_token": token.access_token, + "token_type": token.token_type, + "expires_at_unix": token.expires_at_unix, + "cache_hit": token.cache_hit, + }), + ), + Err(err) => json_response( + 502, + "Bad Gateway", + serde_json::json!({ "error": err.to_string() }), + ), + } + } + Err(err) => err.into_response(), + }; + stream + .write_all(response.as_bytes()) + .await + .into_diagnostic()?; + Ok(()) +} + +async fn read_http_request(stream: &mut TcpStream) -> Result { + let mut buffer = Vec::new(); + let mut chunk = [0_u8; 1024]; + loop { + let read = stream.read(&mut chunk).await.into_diagnostic()?; + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if buffer.windows(4).any(|window| window == b"\r\n\r\n") { + break; + } + if buffer.len() > MAX_REQUEST_HEADER_BYTES { + return Err(miette::miette!("token resolver request headers too large")); + } + } + + let header_end = buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|index| index + 4) + .ok_or_else(|| miette::miette!("incomplete HTTP request"))?; + + let content_length = content_length(&buffer[..header_end])?; + if content_length > MAX_REQUEST_BODY_BYTES { + return Err(miette::miette!("token resolver request body too large")); + } + + while buffer.len().saturating_sub(header_end) < content_length { + let read = stream.read(&mut chunk).await.into_diagnostic()?; + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if buffer.len().saturating_sub(header_end) > MAX_REQUEST_BODY_BYTES { + return Err(miette::miette!("token resolver request body too large")); + } + } + + if buffer.len().saturating_sub(header_end) < content_length { + return Err(miette::miette!("incomplete HTTP request body")); + } + + String::from_utf8(buffer).into_diagnostic() +} + +fn content_length(headers: &[u8]) -> Result { + let headers = std::str::from_utf8(headers).into_diagnostic()?; + for line in headers.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if name.eq_ignore_ascii_case("content-length") { + return value + .trim() + .parse::() + .into_diagnostic() + .wrap_err("invalid content-length header"); + } + } + Ok(0) +} + +fn parse_token_request( + request: &str, + default_audience: Option<&str>, + expected_path: &str, +) -> std::result::Result { + let parsed = ParsedHttpRequest::parse(request)?; + let method = parsed.method.as_str(); + let target = parsed.target.as_str(); + + if method != "GET" && method != "POST" { + return Err(HttpError::new( + 405, + "Method Not Allowed", + "method not allowed", + )); + } + + let (path, query) = target + .split_once('?') + .map_or((target, ""), |(path, query)| (path, query)); + if path != expected_path { + return Err(HttpError::new(404, "Not Found", "token endpoint not found")); + } + + let audience = if method == "GET" { + audience_from_query(query) + } else { + audience_from_json_body(parsed.body)? + } + .or_else(|| default_audience.map(ToOwned::to_owned)) + .ok_or_else(|| HttpError::new(400, "Bad Request", "audience is required"))?; + + if audience.trim().is_empty() { + return Err(HttpError::new( + 400, + "Bad Request", + "audience must not be empty", + )); + } + + debug!(audience = %audience, "microsoft-agent-s2s token resolver request accepted"); + Ok(audience) +} + +fn audience_from_query(query: &str) -> Option { + query.split('&').find_map(|entry| { + let (key, value) = entry.split_once('=')?; + (key == "audience").then(|| value.to_string()) + }) +} + +fn audience_from_json_body(body: &str) -> std::result::Result, HttpError> { + if body.trim().is_empty() { + return Ok(None); + } + + #[derive(Deserialize)] + struct TokenRequestBody { + audience: Option, + } + + serde_json::from_str::(body) + .map(|payload| payload.audience) + .map_err(|_| HttpError::new(400, "Bad Request", "invalid JSON request body")) +} + +struct ParsedHttpRequest<'a> { + method: String, + target: String, + body: &'a str, +} + +impl<'a> ParsedHttpRequest<'a> { + fn parse(request: &'a str) -> std::result::Result { + let Some((head, body)) = request.split_once("\r\n\r\n") else { + return Err(HttpError::new( + 400, + "Bad Request", + "missing HTTP request separator", + )); + }; + let request_line = head + .lines() + .next() + .ok_or_else(|| HttpError::new(400, "Bad Request", "missing HTTP request line"))?; + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or_default().to_string(); + let target = parts.next().unwrap_or_default().to_string(); + Ok(Self { + method, + target, + body, + }) + } +} + +#[derive(Debug)] +struct HttpError { + status: u16, + reason: &'static str, + message: &'static str, +} + +impl HttpError { + const fn new(status: u16, reason: &'static str, message: &'static str) -> Self { + Self { + status, + reason, + message, + } + } + + fn into_response(self) -> String { + json_response( + self.status, + self.reason, + serde_json::json!({ "error": self.message }), + ) + } +} + +fn json_response(status: u16, reason: &str, body: serde_json::Value) -> String { + let body = body.to_string(); + format!( + "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", + body.len() + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn strip_inputs_removes_microsoft_broker_material() { + let mut provider_env = HashMap::from([ + (PROVIDER_NAME_ENV.to_string(), "work-microsoft".to_string()), + ( + "A365_BLUEPRINT_CLIENT_SECRET".to_string(), + "secret".to_string(), + ), + ( + "A365_ALLOWED_AUDIENCES".to_string(), + "api://allowed".to_string(), + ), + ]); + + let provider_name = strip_microsoft_agent_s2s_inputs(&mut provider_env); + assert_eq!(provider_name.as_deref(), Some("work-microsoft")); + assert!(provider_env.is_empty()); + } + + #[test] + fn resolver_environment_exposes_only_local_token_metadata() { + let environment = resolver_environment( + "http://127.0.0.1:3130/v1/microsoft-agent-s2s/token/capability".to_string(), + Some("api://resource".to_string()), + ); + + assert_eq!( + environment.get(TOKEN_URL_ENV), + Some(&"http://127.0.0.1:3130/v1/microsoft-agent-s2s/token/capability".to_string()) + ); + assert_eq!( + environment.get(TOKEN_PROVIDER_URL_ENV), + environment.get(TOKEN_URL_ENV) + ); + assert_eq!( + environment.get(A365_TOKEN_PROVIDER_URL_ENV), + environment.get(TOKEN_URL_ENV) + ); + assert_eq!( + environment.get("NO_PROXY"), + Some(&SANDBOX_HOST_BYPASS_IP.to_string()) + ); + assert_eq!( + environment.get("no_proxy"), + Some(&SANDBOX_HOST_BYPASS_IP.to_string()) + ); + assert_eq!( + environment.get(DEFAULT_AUDIENCE_ENV), + Some(&"api://resource".to_string()) + ); + assert!(!environment.contains_key("A365_BLUEPRINT_CLIENT_SECRET")); + assert!(!environment.contains_key("A365_BLUEPRINT_CLIENT_ID")); + } + + #[test] + fn parse_token_request_accepts_post_json_body() { + let audience = parse_token_request( + "POST /v1/microsoft-agent-s2s/token/test HTTP/1.1\r\nContent-Length: 31\r\n\r\n{\"audience\":\"api://resource\"}", + None, + "/v1/microsoft-agent-s2s/token/test", + ) + .expect("audience"); + + assert_eq!(audience, "api://resource"); + } + + #[test] + fn parse_token_request_uses_default_audience_for_empty_post_body() { + let audience = parse_token_request( + "POST /v1/microsoft-agent-s2s/token/test HTTP/1.1\r\nContent-Length: 0\r\n\r\n", + Some("api://default"), + "/v1/microsoft-agent-s2s/token/test", + ) + .expect("audience"); + + assert_eq!(audience, "api://default"); + } + + #[test] + fn parse_token_request_rejects_invalid_json_body() { + let err = parse_token_request( + "POST /v1/microsoft-agent-s2s/token/test HTTP/1.1\r\nContent-Length: 9\r\n\r\nnot-json!", + None, + "/v1/microsoft-agent-s2s/token/test", + ) + .expect_err("invalid request should fail"); + + assert_eq!(err.status, 400); + assert_eq!(err.message, "invalid JSON request body"); + } + + #[test] + fn parse_token_request_rejects_unknown_path() { + let err = parse_token_request( + "GET /v1/microsoft-agent-s2s/debug HTTP/1.1\r\n\r\n", + Some("api://default"), + "/v1/microsoft-agent-s2s/token/test", + ) + .expect_err("unknown path should fail"); + + assert_eq!(err.status, 404); + assert_eq!(err.message, "token endpoint not found"); + } + + #[test] + fn parse_token_request_rejects_unsupported_method() { + let err = parse_token_request( + "DELETE /v1/microsoft-agent-s2s/token/test HTTP/1.1\r\n\r\n", + Some("api://default"), + "/v1/microsoft-agent-s2s/token/test", + ) + .expect_err("unsupported method should fail"); + + assert_eq!(err.status, 405); + assert_eq!(err.message, "method not allowed"); + } +} diff --git a/crates/openshell-sandbox/src/sandbox/linux/netns.rs b/crates/openshell-sandbox/src/sandbox/linux/netns.rs index 019036e53..5af33669a 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/netns.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/netns.rs @@ -259,7 +259,7 @@ impl NetworkNamespace { /// Degrades gracefully if `iptables` is not available — the namespace /// still provides isolation via routing, just without fast-fail and /// diagnostic logging. - pub fn install_bypass_rules(&self, proxy_port: u16) -> Result<()> { + pub fn install_bypass_rules(&self, proxy_port: u16, extra_allowed_ports: &[u16]) -> Result<()> { // Check if iptables is available before attempting to install rules. let Some(iptables_path) = find_iptables() else { openshell_ocsf::ocsf_emit!( @@ -289,6 +289,7 @@ impl NetworkNamespace { &host_ip_str, &proxy_port_str, &log_prefix, + extra_allowed_ports, ) { openshell_ocsf::ocsf_emit!( openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) @@ -344,6 +345,7 @@ impl NetworkNamespace { host_ip: &str, proxy_port: &str, log_prefix: &str, + extra_allowed_ports: &[u16], ) -> Result<()> { // Rule 1: ACCEPT traffic to the proxy run_iptables_netns( @@ -363,6 +365,25 @@ impl NetworkNamespace { ], )?; + for port in extra_allowed_ports { + run_iptables_netns( + &self.name, + iptables_cmd, + &[ + "-A", + "OUTPUT", + "-d", + &format!("{host_ip}/32"), + "-p", + "tcp", + "--dport", + &port.to_string(), + "-j", + "ACCEPT", + ], + )?; + } + // Rule 2: ACCEPT loopback traffic run_iptables_netns( &self.name, diff --git a/crates/openshell-sandbox/src/ssh.rs b/crates/openshell-sandbox/src/ssh.rs index c92180748..ddc245654 100644 --- a/crates/openshell-sandbox/src/ssh.rs +++ b/crates/openshell-sandbox/src/ssh.rs @@ -107,6 +107,7 @@ pub async fn run_ssh_server( proxy_url: Option, ca_file_paths: Option<(PathBuf, PathBuf)>, provider_credentials: ProviderCredentialState, + provider_runtime_env: HashMap, ) -> Result<()> { let (listener, config, ca_paths) = match ssh_server_init(&listen_path, &ca_file_paths) { Ok(v) => { @@ -131,6 +132,7 @@ pub async fn run_ssh_server( let proxy_url = proxy_url.clone(); let ca_paths = ca_paths.clone(); let provider_credentials = provider_credentials.clone(); + let provider_runtime_env = provider_runtime_env.clone(); tokio::spawn(async move { if let Err(err) = handle_connection( @@ -142,6 +144,7 @@ pub async fn run_ssh_server( proxy_url, ca_paths, provider_credentials, + provider_runtime_env, ) .await { @@ -168,6 +171,7 @@ async fn handle_connection( proxy_url: Option, ca_file_paths: Option>, provider_credentials: ProviderCredentialState, + provider_runtime_env: HashMap, ) -> Result<()> { // Access is gated by the Unix-socket filesystem permissions (root-only), // not by an application-level preface. The supervisor bridges the @@ -190,6 +194,7 @@ async fn handle_connection( proxy_url, ca_file_paths, provider_credentials, + provider_runtime_env, ); russh::server::run_stream(config, stream, handler) .await @@ -217,6 +222,7 @@ struct SshHandler { proxy_url: Option, ca_file_paths: Option>, provider_credentials: ProviderCredentialState, + provider_runtime_env: HashMap, channels: HashMap, } @@ -228,6 +234,7 @@ impl SshHandler { proxy_url: Option, ca_file_paths: Option>, provider_credentials: ProviderCredentialState, + provider_runtime_env: HashMap, ) -> Self { Self { policy, @@ -236,6 +243,7 @@ impl SshHandler { proxy_url, ca_file_paths, provider_credentials, + provider_runtime_env, channels: HashMap::new(), } } @@ -535,6 +543,8 @@ impl SshHandler { command: Option, ) -> anyhow::Result<()> { let provider_snapshot = self.provider_credentials.snapshot(); + let mut provider_env = provider_snapshot.child_env.clone(); + provider_env.extend(self.provider_runtime_env.clone()); let state = self .channels .get_mut(&channel) @@ -552,7 +562,7 @@ impl SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &provider_snapshot.child_env, + &provider_env, )?; state.pty_master = Some(pty_master); state.input_sender = Some(input_sender); @@ -569,7 +579,7 @@ impl SshHandler { self.netns_fd, self.proxy_url.clone(), self.ca_file_paths.clone(), - &provider_snapshot.child_env, + &provider_env, )?; state.input_sender = Some(input_sender); } diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 4bbfe24fc..6a5f6ba3e 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -78,6 +78,8 @@ toml = { workspace = true } tokio-stream = { workspace = true } sqlx = { workspace = true } reqwest = { workspace = true } +base64 = { workspace = true } +url = { workspace = true } uuid = { workspace = true } hmac = "0.12" sha2 = { workspace = true } diff --git a/crates/openshell-server/src/auth/oidc.rs b/crates/openshell-server/src/auth/oidc.rs index 92298579e..a8643e668 100644 --- a/crates/openshell-server/src/auth/oidc.rs +++ b/crates/openshell-server/src/auth/oidc.rs @@ -48,6 +48,7 @@ const SANDBOX_METHODS: &[&str] = &[ "/openshell.v1.OpenShell/ReportPolicyStatus", "/openshell.v1.OpenShell/PushSandboxLogs", "/openshell.v1.OpenShell/GetSandboxProviderEnvironment", + "/openshell.v1.OpenShell/MintSandboxProviderToken", "/openshell.v1.OpenShell/SubmitPolicyAnalysis", "/openshell.sandbox.v1.SandboxService/GetSandboxConfig", "/openshell.inference.v1.Inference/GetInferenceBundle", diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 9ea8d7ece..a60bb87f3 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -29,7 +29,8 @@ use openshell_core::proto::{ ListProviderProfilesRequest, ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, - ListSandboxesResponse, ListServicesRequest, ListServicesResponse, ProviderProfileResponse, + ListSandboxesResponse, ListServicesRequest, ListServicesResponse, + MintSandboxProviderTokenRequest, MintSandboxProviderTokenResponse, ProviderProfileResponse, ProviderResponse, PushSandboxLogsRequest, PushSandboxLogsResponse, RejectDraftChunkRequest, RejectDraftChunkResponse, RelayFrame, ReportPolicyStatusRequest, ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, @@ -401,6 +402,13 @@ impl OpenShell for OpenShellService { policy::handle_get_sandbox_provider_environment(&self.state, request).await } + async fn mint_sandbox_provider_token( + &self, + request: Request, + ) -> Result, Status> { + policy::handle_mint_sandbox_provider_token(&self.state, request).await + } + async fn update_config( &self, request: Request, diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 315b06f3c..cdc3b402a 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -25,12 +25,13 @@ use openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, - ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, PolicyChunk, PolicyMergeOperation, - PolicySource, PolicyStatus, PushSandboxLogsRequest, PushSandboxLogsResponse, - RejectDraftChunkRequest, RejectDraftChunkResponse, ReportPolicyStatusRequest, - ReportPolicyStatusResponse, SandboxLogLine, SandboxPolicyRevision, SettingScope, SettingValue, - SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, UndoDraftChunkRequest, - UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, + ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, MintSandboxProviderTokenRequest, + MintSandboxProviderTokenResponse, PolicyChunk, PolicyMergeOperation, PolicySource, + PolicyStatus, PushSandboxLogsRequest, PushSandboxLogsResponse, RejectDraftChunkRequest, + RejectDraftChunkResponse, ReportPolicyStatusRequest, ReportPolicyStatusResponse, + SandboxLogLine, SandboxPolicyRevision, SettingScope, SettingValue, SubmitPolicyAnalysisRequest, + SubmitPolicyAnalysisResponse, UndoDraftChunkRequest, UndoDraftChunkResponse, + UpdateConfigRequest, UpdateConfigResponse, }; use openshell_core::proto::{ L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, Provider, Sandbox, @@ -643,6 +644,87 @@ pub(super) async fn handle_get_sandbox_provider_environment( })) } +pub(super) async fn handle_mint_sandbox_provider_token( + state: &Arc, + request: Request, +) -> Result, Status> { + let request = request.into_inner(); + let sandbox_id = request.sandbox_id.trim(); + let provider_name = request.provider_name.trim(); + + if sandbox_id.is_empty() { + return Err(Status::invalid_argument("sandbox_id is required")); + } + if provider_name.is_empty() { + return Err(Status::invalid_argument("provider_name is required")); + } + + let sandbox = state + .store + .get_message::(sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found"))?; + + let spec = sandbox + .spec + .ok_or_else(|| Status::internal("sandbox has no spec"))?; + if !spec.providers.iter().any(|name| name == provider_name) { + return Err(Status::failed_precondition(format!( + "provider '{provider_name}' is not attached to sandbox" + ))); + } + + let provider = state + .store + .get_message_by_name::(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider '{provider_name}' failed: {e}")))? + .ok_or_else(|| { + Status::failed_precondition(format!("provider '{provider_name}' not found")) + })?; + + if provider.r#type.trim() != "microsoft-agent-s2s" { + return Err(Status::failed_precondition(format!( + "provider '{provider_name}' is not a microsoft-agent-s2s provider" + ))); + } + + let broker = state + .microsoft_s2s_brokers + .broker_for_provider(provider_name, &provider) + .await + .map_err(|e| Status::failed_precondition(e.to_string()))?; + + let audience = if request.audience.trim().is_empty() { + broker + .default_audience() + .ok_or_else(|| Status::invalid_argument("audience is required"))? + } else { + request.audience.trim().to_string() + }; + + let token = broker + .access_token(&audience) + .await + .map_err(|e| Status::failed_precondition(e.to_string()))?; + + info!( + sandbox_id, + provider_name, + audience = %audience, + cache_hit = token.cache_hit, + "MintSandboxProviderToken request completed successfully" + ); + + Ok(Response::new(MintSandboxProviderTokenResponse { + access_token: token.access_token, + token_type: "Bearer".to_string(), + expires_at_unix: token.expires_at_unix.unwrap_or_default(), + cache_hit: token.cache_hit, + })) +} + // --------------------------------------------------------------------------- // Update config handler (policy + settings mutations) // --------------------------------------------------------------------------- @@ -2890,6 +2972,38 @@ mod tests { } } + fn test_microsoft_provider(name: &str) -> Provider { + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: format!("provider-{name}"), + name: name.to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + }), + r#type: "microsoft-agent-s2s".to_string(), + credentials: std::iter::once(( + "A365_BLUEPRINT_CLIENT_SECRET".to_string(), + "secret".to_string(), + )) + .collect(), + config: HashMap::from([ + ("AZURE_TENANT_ID".to_string(), "tenant-id".to_string()), + ( + "A365_BLUEPRINT_CLIENT_ID".to_string(), + "blueprint-client-id".to_string(), + ), + ( + "A365_RUNTIME_AGENT_ID".to_string(), + "runtime-agent-id".to_string(), + ), + ( + "A365_ALLOWED_AUDIENCES".to_string(), + "api://allowed".to_string(), + ), + ]), + } + } + fn test_policy_with_rule(rule_name: &str, host: &str) -> ProtoSandboxPolicy { ProtoSandboxPolicy { network_policies: std::iter::once(( @@ -3416,6 +3530,108 @@ mod tests { assert_eq!(v2_env.get("GITHUB_TOKEN"), Some(&"ghp-test".to_string())); } + #[tokio::test] + async fn mint_sandbox_provider_token_rejects_unattached_provider() { + let state = test_server_state().await; + state + .store + .put_message(&test_microsoft_provider("work-microsoft")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-microsoft", + "microsoft", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + Vec::new(), + )) + .await + .unwrap(); + + let err = handle_mint_sandbox_provider_token( + &state, + Request::new(MintSandboxProviderTokenRequest { + sandbox_id: "sb-microsoft".to_string(), + provider_name: "work-microsoft".to_string(), + audience: "api://allowed".to_string(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("not attached")); + } + + #[tokio::test] + async fn mint_sandbox_provider_token_rejects_non_microsoft_provider_type() { + let state = test_server_state().await; + state + .store + .put_message(&test_provider("work-github", "github")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-provider-token", + "provider-token", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + vec!["work-github".to_string()], + )) + .await + .unwrap(); + + let err = handle_mint_sandbox_provider_token( + &state, + Request::new(MintSandboxProviderTokenRequest { + sandbox_id: "sb-provider-token".to_string(), + provider_name: "work-github".to_string(), + audience: "api://allowed".to_string(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("microsoft-agent-s2s")); + } + + #[tokio::test] + async fn mint_sandbox_provider_token_rejects_unallowed_audience() { + let state = test_server_state().await; + state + .store + .put_message(&test_microsoft_provider("work-microsoft")) + .await + .unwrap(); + state + .store + .put_message(&test_sandbox( + "sb-microsoft", + "microsoft", + test_policy_with_rule("sandbox_only", "sandbox.example.com"), + vec!["work-microsoft".to_string()], + )) + .await + .unwrap(); + + let err = handle_mint_sandbox_provider_token( + &state, + Request::new(MintSandboxProviderTokenRequest { + sandbox_id: "sb-microsoft".to_string(), + provider_name: "work-microsoft".to_string(), + audience: "api://not-allowed".to_string(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("not allowed")); + } + #[tokio::test] async fn provider_env_revision_changes_when_attached_provider_record_changes() { use openshell_core::proto::GetSandboxProviderEnvironmentRequest; diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 88b3ea743..b8794f300 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -274,22 +274,59 @@ pub(super) async fn resolve_provider_environment( .map_err(|e| Status::internal(format!("failed to fetch provider '{name}': {e}")))? .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; - for (key, value) in &provider.credentials { - if is_valid_env_key(key) { - env.entry(key.clone()).or_insert_with(|| value.clone()); - } else { - warn!( - provider_name = %name, - key = %key, - "skipping credential with invalid env var key" - ); + if provider.r#type.trim() == "microsoft-agent-s2s" { + insert_supervisor_env( + &mut env, + name, + "OPENSHELL_MICROSOFT_AGENT_S2S_PROVIDER_NAME", + name, + ); + for key in [ + "AZURE_TENANT_ID", + "A365_BLUEPRINT_CLIENT_ID", + "A365_BLUEPRINT_CLIENT_SECRET", + "A365_RUNTIME_AGENT_ID", + "A365_ALLOWED_AUDIENCES", + "A365_OBSERVABILITY_RESOURCE", + "A365_REQUIRED_ROLES", + ] { + if let Some(value) = provider + .credentials + .get(key) + .or_else(|| provider.config.get(key)) + { + insert_supervisor_env(&mut env, name, key, value); + } } + continue; + } + + for (key, value) in &provider.credentials { + insert_supervisor_env(&mut env, name, key, value); } } Ok(env) } +fn insert_supervisor_env( + env: &mut std::collections::HashMap, + provider_name: &str, + key: &str, + value: &str, +) { + if is_valid_env_key(key) { + env.entry(key.to_string()) + .or_insert_with(|| value.to_string()); + } else { + warn!( + provider_name = %provider_name, + key = %key, + "skipping credential with invalid env var key" + ); + } +} + pub(super) fn is_valid_env_key(key: &str) -> bool { let mut bytes = key.bytes(); let Some(first) = bytes.next() else { diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index a6e337dec..56f15bf5e 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -30,6 +30,7 @@ mod inference; mod multiplex; mod persistence; pub(crate) mod policy_store; +mod provider_auth; mod sandbox_index; mod sandbox_watch; mod service_routing; @@ -103,6 +104,9 @@ pub struct ServerState { /// OIDC JWKS cache for JWT validation. `None` when OIDC is not configured. pub oidc_cache: Option>, + + /// Gateway-owned Microsoft S2S broker registry keyed by provider record. + pub microsoft_s2s_brokers: provider_auth::microsoft_s2s::BrokerRegistry, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -147,6 +151,7 @@ impl ServerState { settings_mutex: tokio::sync::Mutex::new(()), supervisor_sessions, oidc_cache, + microsoft_s2s_brokers: provider_auth::microsoft_s2s::BrokerRegistry::default(), } } } diff --git a/crates/openshell-server/src/provider_auth/microsoft_s2s.rs b/crates/openshell-server/src/provider_auth/microsoft_s2s.rs new file mode 100644 index 000000000..d7be2bcc6 --- /dev/null +++ b/crates/openshell-server/src/provider_auth/microsoft_s2s.rs @@ -0,0 +1,1086 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![allow(dead_code)] + +use base64::Engine as _; +use openshell_core::proto::Provider; +use reqwest::StatusCode; +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use tokio::sync::Mutex; +use url::Url; + +const AZURE_TOKEN_EXCHANGE_SCOPE: &str = "api://AzureADTokenExchange/.default"; +const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str = + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; +const DEFAULT_AUTHORITY_HOST: &str = "https://login.microsoftonline.com"; +const DEFAULT_REFRESH_SKEW: Duration = Duration::from_secs(300); + +#[derive(Debug, thiserror::Error)] +pub enum MicrosoftS2sError { + #[error("invalid Microsoft S2S provider config: {0}")] + InvalidConfig(String), + #[error("audience '{0}' is not allowed by provider config")] + AudienceDenied(String), + #[error("failed to build token endpoint URL: {0}")] + Url(String), + #[error("Microsoft token request failed with HTTP {status}: {body}")] + TokenHttp { status: StatusCode, body: String }, + #[error("Microsoft token request failed: {0}")] + TokenTransport(String), + #[error("Microsoft token response did not include an access token")] + MissingAccessToken, + #[error("Microsoft token claim validation failed: {0}")] + ClaimValidation(String), +} + +#[derive(Debug, Clone, Default)] +pub struct BrokerRegistry { + inner: Arc>>, +} + +#[derive(Debug, Clone)] +struct BrokerRegistryEntry { + fingerprint: String, + broker: MicrosoftS2sBroker, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MicrosoftS2sConfig { + pub tenant_id: String, + pub blueprint_client_id: String, + pub blueprint_client_secret: String, + pub runtime_agent_id: String, + pub allowed_audiences: Vec, + pub observability_resource: Option, + pub required_roles: Vec, +} + +impl MicrosoftS2sConfig { + pub fn from_provider_maps( + credentials: &HashMap, + config: &HashMap, + ) -> Result { + let provider_value = |key: &str| { + credentials + .get(key) + .or_else(|| config.get(key)) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + }; + + let allowed_audiences = provider_value("A365_ALLOWED_AUDIENCES") + .map(|value| split_csv(&value)) + .unwrap_or_default(); + let required_roles = provider_value("A365_REQUIRED_ROLES") + .map(|value| split_csv(&value)) + .unwrap_or_default(); + + let cfg = Self { + tenant_id: provider_value("AZURE_TENANT_ID").unwrap_or_default(), + blueprint_client_id: provider_value("A365_BLUEPRINT_CLIENT_ID").unwrap_or_default(), + blueprint_client_secret: provider_value("A365_BLUEPRINT_CLIENT_SECRET") + .unwrap_or_default(), + runtime_agent_id: provider_value("A365_RUNTIME_AGENT_ID").unwrap_or_default(), + allowed_audiences, + observability_resource: provider_value("A365_OBSERVABILITY_RESOURCE"), + required_roles, + }; + cfg.validate()?; + Ok(cfg) + } + + pub fn validate(&self) -> Result<(), MicrosoftS2sError> { + require_non_empty("AZURE_TENANT_ID", &self.tenant_id)?; + require_non_empty("A365_BLUEPRINT_CLIENT_ID", &self.blueprint_client_id)?; + require_non_empty( + "A365_BLUEPRINT_CLIENT_SECRET", + &self.blueprint_client_secret, + )?; + require_non_empty("A365_RUNTIME_AGENT_ID", &self.runtime_agent_id)?; + + if self.allowed_audiences.is_empty() && self.observability_resource.is_none() { + return Err(MicrosoftS2sError::InvalidConfig( + "at least one allowed audience or observability resource is required".to_string(), + )); + } + + Ok(()) + } + + fn allowed_audience_set(&self) -> BTreeSet { + let mut allowed = self + .allowed_audiences + .iter() + .map(|audience| normalize_audience(audience)) + .filter(|audience| !audience.is_empty()) + .collect::>(); + if let Some(resource) = &self.observability_resource { + let normalized = normalize_audience(resource); + if !normalized.is_empty() { + allowed.insert(normalized); + } + } + allowed + } + + pub fn default_audience(&self) -> Option { + self.observability_resource + .clone() + .or_else(|| match self.allowed_audiences.as_slice() { + [only] => Some(only.clone()), + _ => None, + }) + } +} + +#[derive(Debug, Clone)] +pub struct MicrosoftS2sBrokerOptions { + pub authority_host: Url, + pub refresh_skew: Duration, +} + +impl Default for MicrosoftS2sBrokerOptions { + fn default() -> Self { + Self { + authority_host: Url::parse(DEFAULT_AUTHORITY_HOST) + .expect("default authority host should parse"), + refresh_skew: DEFAULT_REFRESH_SKEW, + } + } +} + +#[derive(Clone, Debug)] +pub struct MicrosoftS2sBroker { + config: Arc, + client: reqwest::Client, + authority_host: Url, + refresh_skew: Duration, + cache: Arc>>, +} + +impl MicrosoftS2sBroker { + pub fn new(config: MicrosoftS2sConfig) -> Result { + Self::with_options(config, MicrosoftS2sBrokerOptions::default()) + } + + pub fn with_options( + config: MicrosoftS2sConfig, + options: MicrosoftS2sBrokerOptions, + ) -> Result { + config.validate()?; + let client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(30)) + .build() + .map_err(|e| MicrosoftS2sError::TokenTransport(e.to_string()))?; + Ok(Self { + config: Arc::new(config), + client, + authority_host: options.authority_host, + refresh_skew: options.refresh_skew, + cache: Arc::new(Mutex::new(HashMap::new())), + }) + } + + pub async fn authorization_header( + &self, + audience: &str, + ) -> Result { + let token = self.access_token(audience).await?; + Ok(AuthorizationHeader { + value: format!("Bearer {}", token.access_token), + expires_at_unix: token.expires_at_unix, + cache_hit: token.cache_hit, + }) + } + + pub async fn access_token( + &self, + audience: &str, + ) -> Result { + let audience = normalize_audience(audience); + self.ensure_allowed_audience(&audience)?; + + let cache_key = CacheKey { + tenant_id: self.config.tenant_id.clone(), + runtime_agent_id: self.config.runtime_agent_id.clone(), + audience: audience.clone(), + }; + + if let Some(cached) = self.cached_token(&cache_key).await { + return Ok(BrokeredAccessToken { + access_token: cached.access_token, + expires_at_unix: cached.expires_at_unix, + cache_hit: true, + }); + } + + let assertion = self.fetch_blueprint_assertion().await?; + let token = self + .fetch_runtime_agent_token(&audience, &assertion) + .await?; + self.validate_runtime_token_claims(&audience, &token.access_token)?; + + let expires_at = token.expires_at(self.refresh_skew); + let expires_at_unix = token.expires_at_unix(); + let cached = CachedToken { + access_token: token.access_token, + expires_at, + expires_at_unix, + }; + self.cache.lock().await.insert(cache_key, cached.clone()); + + Ok(BrokeredAccessToken { + access_token: cached.access_token, + expires_at_unix: cached.expires_at_unix, + cache_hit: false, + }) + } + + pub fn default_audience(&self) -> Option { + self.config.default_audience() + } + + pub async fn evict(&self, audience: &str) { + let cache_key = CacheKey { + tenant_id: self.config.tenant_id.clone(), + runtime_agent_id: self.config.runtime_agent_id.clone(), + audience: normalize_audience(audience), + }; + self.cache.lock().await.remove(&cache_key); + } + + fn ensure_allowed_audience(&self, audience: &str) -> Result<(), MicrosoftS2sError> { + let allowed = self.config.allowed_audience_set(); + if allowed.contains(audience) { + Ok(()) + } else { + Err(MicrosoftS2sError::AudienceDenied(audience.to_string())) + } + } + + async fn cached_token(&self, cache_key: &CacheKey) -> Option { + let cached = self.cache.lock().await.get(cache_key).cloned()?; + if Instant::now() < cached.expires_at { + Some(cached) + } else { + None + } + } + + async fn fetch_blueprint_assertion(&self) -> Result { + let endpoint = self.token_endpoint()?; + let form = [ + ("grant_type", "client_credentials"), + ("client_id", self.config.blueprint_client_id.as_str()), + ( + "client_secret", + self.config.blueprint_client_secret.as_str(), + ), + ("scope", AZURE_TOKEN_EXCHANGE_SCOPE), + ("fmi_path", self.config.runtime_agent_id.as_str()), + ]; + self.post_token_form(endpoint, &form).await + } + + async fn fetch_runtime_agent_token( + &self, + audience: &str, + assertion: &TokenResponse, + ) -> Result { + let endpoint = self.token_endpoint()?; + let scope = default_scope_for_audience(audience); + let form = [ + ("grant_type", "client_credentials"), + ("client_id", self.config.runtime_agent_id.as_str()), + ("client_assertion", assertion.access_token.as_str()), + ("client_assertion_type", CLIENT_ASSERTION_TYPE_JWT_BEARER), + ("scope", scope.as_str()), + ]; + self.post_token_form(endpoint, &form).await + } + + async fn post_token_form( + &self, + endpoint: Url, + form: &[(&str, &str)], + ) -> Result { + let response = self + .client + .post(endpoint) + .form(form) + .send() + .await + .map_err(|e| MicrosoftS2sError::TokenTransport(e.to_string()))?; + let status = response.status(); + let body = response + .text() + .await + .map_err(|e| MicrosoftS2sError::TokenTransport(e.to_string()))?; + + if !status.is_success() { + return Err(MicrosoftS2sError::TokenHttp { + status, + body: sanitize_error_body(&body), + }); + } + + let parsed = serde_json::from_str::(&body).map_err(|e| { + MicrosoftS2sError::TokenTransport(format!("failed to parse token response: {e}")) + })?; + if parsed.access_token.trim().is_empty() { + return Err(MicrosoftS2sError::MissingAccessToken); + } + Ok(parsed) + } + + fn token_endpoint(&self) -> Result { + self.authority_host + .join(&format!( + "{}/oauth2/v2.0/token", + self.config.tenant_id.trim_matches('/') + )) + .map_err(|e| MicrosoftS2sError::Url(e.to_string())) + } + + fn validate_runtime_token_claims( + &self, + audience: &str, + token: &str, + ) -> Result<(), MicrosoftS2sError> { + let claims = JwtClaims::decode_unverified(token)?; + claims.expect_audience(audience)?; + claims.expect_tenant(&self.config.tenant_id)?; + claims.expect_runtime_agent(&self.config.runtime_agent_id)?; + claims.expect_app_token()?; + claims.expect_roles(&self.config.required_roles)?; + claims.expect_not_expired()?; + Ok(()) + } +} + +impl BrokerRegistry { + pub async fn broker_for_provider( + &self, + provider_name: &str, + provider: &Provider, + ) -> Result { + let fingerprint = provider_fingerprint(provider); + + { + let inner = self.inner.lock().await; + if let Some(entry) = inner.get(provider_name) + && entry.fingerprint == fingerprint + { + return Ok(entry.broker.clone()); + } + } + + let config = + MicrosoftS2sConfig::from_provider_maps(&provider.credentials, &provider.config)?; + let broker = MicrosoftS2sBroker::new(config)?; + + let mut inner = self.inner.lock().await; + inner.insert( + provider_name.to_string(), + BrokerRegistryEntry { + fingerprint, + broker: broker.clone(), + }, + ); + Ok(broker) + } +} + +fn provider_fingerprint(provider: &Provider) -> String { + use sha2::{Digest, Sha256}; + + let mut hasher = Sha256::new(); + hasher.update(provider.r#type.as_bytes()); + + let mut credential_items: Vec<_> = provider.credentials.iter().collect(); + credential_items.sort_by(|(left, _), (right, _)| left.cmp(right)); + for (key, value) in credential_items { + hasher.update(key.as_bytes()); + hasher.update([0]); + hasher.update(value.as_bytes()); + hasher.update([0xff]); + } + + let mut config_items: Vec<_> = provider.config.iter().collect(); + config_items.sort_by(|(left, _), (right, _)| left.cmp(right)); + for (key, value) in config_items { + hasher.update(key.as_bytes()); + hasher.update([0]); + hasher.update(value.as_bytes()); + hasher.update([0xff]); + } + + hex::encode(hasher.finalize()) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthorizationHeader { + pub value: String, + pub expires_at_unix: Option, + pub cache_hit: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BrokeredAccessToken { + pub access_token: String, + pub expires_at_unix: Option, + pub cache_hit: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct CacheKey { + tenant_id: String, + runtime_agent_id: String, + audience: String, +} + +#[derive(Debug, Clone)] +struct CachedToken { + access_token: String, + expires_at: Instant, + expires_at_unix: Option, +} + +#[derive(Debug, Clone, serde::Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: Option, + #[serde(default)] + expires_on: Option, +} + +impl TokenResponse { + fn expires_at(&self, refresh_skew: Duration) -> Instant { + let ttl = self.expires_in.unwrap_or(3600); + let ttl = Duration::from_secs(ttl); + Instant::now() + ttl.saturating_sub(refresh_skew) + } + + fn expires_at_unix(&self) -> Option { + if let Some(expires_on) = &self.expires_on + && let Ok(value) = expires_on.parse::() + { + return Some(value); + } + let expires_in = self.expires_in?; + let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); + Some(now.saturating_add(expires_in)) + } +} + +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +struct JwtClaims { + #[serde(default)] + aud: AudienceClaim, + #[serde(default)] + tid: Option, + #[serde(default)] + azp: Option, + #[serde(default)] + appid: Option, + #[serde(default)] + oid: Option, + #[serde(default)] + sub: Option, + #[serde(default)] + idtyp: Option, + #[serde(default)] + roles: Vec, + #[serde(default)] + nbf: Option, + #[serde(default)] + exp: Option, +} + +impl JwtClaims { + fn decode_unverified(token: &str) -> Result { + let mut parts = token.split('.'); + let _header = parts.next(); + let payload = parts + .next() + .ok_or_else(|| MicrosoftS2sError::ClaimValidation("token is not a JWT".to_string()))?; + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(payload) + .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload)) + .map_err(|e| { + MicrosoftS2sError::ClaimValidation(format!("token payload decode failed: {e}")) + })?; + serde_json::from_slice(&decoded).map_err(|e| { + MicrosoftS2sError::ClaimValidation(format!("token payload parse failed: {e}")) + }) + } + + fn expect_audience(&self, audience: &str) -> Result<(), MicrosoftS2sError> { + if self + .aud + .values() + .iter() + .any(|actual| normalize_audience(actual) == audience) + { + Ok(()) + } else { + Err(MicrosoftS2sError::ClaimValidation(format!( + "audience claim does not include '{audience}'" + ))) + } + } + + fn expect_tenant(&self, tenant_id: &str) -> Result<(), MicrosoftS2sError> { + match self.tid.as_deref() { + Some(actual) if actual.eq_ignore_ascii_case(tenant_id) => Ok(()), + Some(actual) => Err(MicrosoftS2sError::ClaimValidation(format!( + "tenant claim '{actual}' does not match expected tenant" + ))), + None => Err(MicrosoftS2sError::ClaimValidation( + "missing tenant claim".to_string(), + )), + } + } + + fn expect_runtime_agent(&self, runtime_agent_id: &str) -> Result<(), MicrosoftS2sError> { + let expected = runtime_agent_id.to_ascii_lowercase(); + let matches = [&self.azp, &self.appid, &self.oid, &self.sub] + .into_iter() + .flatten() + .any(|value| value.to_ascii_lowercase() == expected); + if matches { + Ok(()) + } else { + Err(MicrosoftS2sError::ClaimValidation( + "token does not represent the runtime agent identity".to_string(), + )) + } + } + + fn expect_app_token(&self) -> Result<(), MicrosoftS2sError> { + match self.idtyp.as_deref() { + Some("app") => Ok(()), + Some(actual) => Err(MicrosoftS2sError::ClaimValidation(format!( + "expected app token, got idtyp='{actual}'" + ))), + None => Err(MicrosoftS2sError::ClaimValidation( + "missing idtyp claim".to_string(), + )), + } + } + + fn expect_roles(&self, required_roles: &[String]) -> Result<(), MicrosoftS2sError> { + for required in required_roles { + if !self.roles.iter().any(|role| role == required) { + return Err(MicrosoftS2sError::ClaimValidation(format!( + "missing required role '{required}'" + ))); + } + } + Ok(()) + } + + fn expect_not_expired(&self) -> Result<(), MicrosoftS2sError> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| MicrosoftS2sError::ClaimValidation(e.to_string()))? + .as_secs(); + if let Some(nbf) = self.nbf + && now.saturating_add(60) < nbf + { + return Err(MicrosoftS2sError::ClaimValidation( + "token is not valid yet".to_string(), + )); + } + if let Some(exp) = self.exp + && exp <= now.saturating_sub(60) + { + return Err(MicrosoftS2sError::ClaimValidation( + "token is expired".to_string(), + )); + } + Ok(()) + } +} + +#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] +#[serde(untagged)] +enum AudienceClaim { + One(String), + Many(Vec), + #[default] + Missing, +} + +impl AudienceClaim { + fn values(&self) -> Vec<&str> { + match self { + Self::One(value) => vec![value.as_str()], + Self::Many(values) => values.iter().map(String::as_str).collect(), + Self::Missing => Vec::new(), + } + } +} + +fn require_non_empty(name: &str, value: &str) -> Result<(), MicrosoftS2sError> { + if value.trim().is_empty() { + Err(MicrosoftS2sError::InvalidConfig(format!( + "{name} is required" + ))) + } else { + Ok(()) + } +} + +fn normalize_audience(input: &str) -> String { + input + .trim() + .trim_end_matches("/.default") + .trim_end_matches('/') + .to_string() +} + +fn split_csv(value: &str) -> Vec { + value + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToString::to_string) + .collect() +} + +fn default_scope_for_audience(audience: &str) -> String { + format!("{}/.default", normalize_audience(audience)) +} + +fn sanitize_error_body(body: &str) -> String { + const MAX_ERROR_BODY: usize = 1024; + body.chars() + .filter(|ch| !ch.is_control() || *ch == '\n' || *ch == '\t') + .take(MAX_ERROR_BODY) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use base64::engine::general_purpose::URL_SAFE_NO_PAD; + use std::net::SocketAddr; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + const TENANT: &str = "11111111-1111-4111-8111-111111111111"; + const BLUEPRINT: &str = "22222222-2222-4222-8222-222222222222"; + const RUNTIME_AGENT: &str = "33333333-3333-4333-8333-333333333333"; + const RESOURCE: &str = "api://44444444-4444-4444-8444-444444444444"; + + fn config() -> MicrosoftS2sConfig { + MicrosoftS2sConfig { + tenant_id: TENANT.to_string(), + blueprint_client_id: BLUEPRINT.to_string(), + blueprint_client_secret: "secret".to_string(), + runtime_agent_id: RUNTIME_AGENT.to_string(), + allowed_audiences: vec![RESOURCE.to_string()], + observability_resource: None, + required_roles: vec!["Agent365.Observability.OtelWrite".to_string()], + } + } + + fn broker(server: &FakeTokenServer) -> MicrosoftS2sBroker { + MicrosoftS2sBroker::with_options( + config(), + MicrosoftS2sBrokerOptions { + authority_host: Url::parse(&server.uri()).expect("fake server URL"), + refresh_skew: Duration::from_secs(60), + }, + ) + .expect("broker") + } + + fn jwt(claims: serde_json::Value) -> String { + let header = serde_json::json!({"alg": "none", "typ": "JWT"}); + let header = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).expect("header json")); + let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&claims).expect("claims json")); + format!("{header}.{payload}.signature") + } + + fn runtime_token() -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock") + .as_secs(); + jwt(serde_json::json!({ + "aud": RESOURCE, + "tid": TENANT, + "azp": RUNTIME_AGENT, + "oid": RUNTIME_AGENT, + "sub": RUNTIME_AGENT, + "idtyp": "app", + "roles": ["Agent365.Observability.OtelWrite"], + "nbf": now.saturating_sub(30), + "exp": now + 3600 + })) + } + + #[test] + fn default_audience_prefers_observability_resource() { + let cfg = MicrosoftS2sConfig { + observability_resource: Some("api://observability".to_string()), + ..config() + }; + assert_eq!( + cfg.default_audience().as_deref(), + Some("api://observability") + ); + } + + #[test] + fn default_audience_uses_only_allowed_audience() { + let cfg = MicrosoftS2sConfig { + allowed_audiences: vec![RESOURCE.to_string()], + observability_resource: None, + ..config() + }; + assert_eq!(cfg.default_audience().as_deref(), Some(RESOURCE)); + } + + #[test] + fn builds_config_from_provider_maps() { + let credentials = HashMap::from([ + ("AZURE_TENANT_ID".to_string(), TENANT.to_string()), + ( + "A365_BLUEPRINT_CLIENT_SECRET".to_string(), + "secret".to_string(), + ), + ]); + let config = HashMap::from([ + ( + "A365_BLUEPRINT_CLIENT_ID".to_string(), + BLUEPRINT.to_string(), + ), + ( + "A365_RUNTIME_AGENT_ID".to_string(), + RUNTIME_AGENT.to_string(), + ), + ( + "A365_ALLOWED_AUDIENCES".to_string(), + format!("{RESOURCE}, api://extra/.default"), + ), + ( + "A365_REQUIRED_ROLES".to_string(), + "Agent365.Observability.OtelWrite".to_string(), + ), + ]); + + let cfg = MicrosoftS2sConfig::from_provider_maps(&credentials, &config) + .expect("provider maps should build config"); + + assert_eq!(cfg.tenant_id, TENANT); + assert_eq!(cfg.blueprint_client_id, BLUEPRINT); + assert_eq!(cfg.blueprint_client_secret, "secret"); + assert_eq!(cfg.runtime_agent_id, RUNTIME_AGENT); + assert_eq!( + cfg.allowed_audiences, + vec![RESOURCE.to_string(), "api://extra/.default".to_string()] + ); + assert_eq!( + cfg.required_roles, + vec!["Agent365.Observability.OtelWrite".to_string()] + ); + } + + #[derive(Debug, Default)] + struct FakeTokenState { + runtime_token: Mutex, + blueprint_requests: Mutex, + runtime_requests: Mutex, + } + + #[derive(Clone)] + struct FakeTokenServer { + addr: SocketAddr, + state: Arc, + } + + impl FakeTokenServer { + async fn start(runtime_token: String) -> Self { + let state = Arc::new(FakeTokenState { + runtime_token: Mutex::new(runtime_token), + blueprint_requests: Mutex::new(0), + runtime_requests: Mutex::new(0), + }); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind fake token server"); + let addr = listener.local_addr().expect("fake token server addr"); + let server_state = state.clone(); + tokio::spawn(async move { + loop { + let Ok((stream, _peer)) = listener.accept().await else { + break; + }; + let state = server_state.clone(); + tokio::spawn(async move { + handle_token_connection(stream, state).await; + }); + } + }); + Self { addr, state } + } + + fn uri(&self) -> String { + format!("http://{}", self.addr) + } + + async fn request_counts(&self) -> (usize, usize) { + ( + *self.state.blueprint_requests.lock().await, + *self.state.runtime_requests.lock().await, + ) + } + } + + async fn handle_token_connection( + mut stream: tokio::net::TcpStream, + state: Arc, + ) { + let mut buffer = Vec::new(); + let mut temp = [0_u8; 1024]; + let mut content_length = None; + let mut header_end = None; + + loop { + let read = stream.read(&mut temp).await.expect("read fake request"); + if read == 0 { + return; + } + buffer.extend_from_slice(&temp[..read]); + if header_end.is_none() + && let Some(pos) = find_header_end(&buffer) + { + header_end = Some(pos); + let headers = String::from_utf8_lossy(&buffer[..pos]); + content_length = parse_content_length(&headers); + } + if let (Some(end), Some(len)) = (header_end, content_length) + && buffer.len() >= end + 4 + len + { + break; + } + } + + let end = header_end.expect("headers should be present"); + let len = content_length.expect("content length should be present"); + let body = &buffer[end + 4..end + 4 + len]; + let form = url::form_urlencoded::parse(body) + .into_owned() + .collect::>(); + let response = token_response_for_form(&state, &form).await; + stream + .write_all(response.as_bytes()) + .await + .expect("write fake response"); + } + + async fn token_response_for_form( + state: &Arc, + form: &HashMap, + ) -> String { + if form + .get("client_id") + .is_some_and(|value| value == BLUEPRINT) + { + assert_eq!( + form.get("grant_type").map(String::as_str), + Some("client_credentials") + ); + assert_eq!( + form.get("scope").map(String::as_str), + Some(AZURE_TOKEN_EXCHANGE_SCOPE) + ); + assert_eq!( + form.get("fmi_path").map(String::as_str), + Some(RUNTIME_AGENT) + ); + *state.blueprint_requests.lock().await += 1; + return json_response( + 200, + serde_json::json!({ + "token_type": "Bearer", + "expires_in": 3600, + "access_token": "blueprint-assertion" + }), + ); + } + + if form + .get("client_id") + .is_some_and(|value| value == RUNTIME_AGENT) + { + assert_eq!( + form.get("grant_type").map(String::as_str), + Some("client_credentials") + ); + assert_eq!( + form.get("client_assertion").map(String::as_str), + Some("blueprint-assertion") + ); + assert_eq!( + form.get("client_assertion_type").map(String::as_str), + Some(CLIENT_ASSERTION_TYPE_JWT_BEARER) + ); + let expected_scope = format!("{RESOURCE}/.default"); + assert_eq!( + form.get("scope").map(String::as_str), + Some(expected_scope.as_str()) + ); + *state.runtime_requests.lock().await += 1; + let runtime_token = state.runtime_token.lock().await.clone(); + return json_response( + 200, + serde_json::json!({ + "token_type": "Bearer", + "expires_in": 3600, + "access_token": runtime_token + }), + ); + } + + json_response( + 400, + serde_json::json!({"error": "unexpected token request"}), + ) + } + + fn find_header_end(buffer: &[u8]) -> Option { + buffer.windows(4).position(|window| window == b"\r\n\r\n") + } + + fn parse_content_length(headers: &str) -> Option { + headers.lines().find_map(|line| { + let (name, value) = line.split_once(':')?; + if name.eq_ignore_ascii_case("content-length") { + value.trim().parse().ok() + } else { + None + } + }) + } + + fn json_response(status: u16, body: serde_json::Value) -> String { + let reason = if status == 200 { "OK" } else { "Bad Request" }; + let body = body.to_string(); + format!( + "HTTP/1.1 {status} {reason}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) + } + + #[tokio::test] + #[ignore = "requires binding a local TCP listener"] + async fn mints_runtime_agent_token_with_two_step_exchange() { + let runtime_token = runtime_token(); + let server = FakeTokenServer::start(runtime_token.clone()).await; + + let token = broker(&server) + .access_token(RESOURCE) + .await + .expect("token should mint"); + + assert_eq!(token.access_token, runtime_token); + assert!(!token.cache_hit); + assert!(token.expires_at_unix.is_some()); + assert_eq!(server.request_counts().await, (1, 1)); + } + + #[tokio::test] + #[ignore = "requires binding a local TCP listener"] + async fn returns_cached_token_for_same_audience() { + let runtime_token = runtime_token(); + let server = FakeTokenServer::start(runtime_token.clone()).await; + let broker = broker(&server); + + let first = broker.access_token(RESOURCE).await.expect("first token"); + let second = broker.access_token(RESOURCE).await.expect("cached token"); + + assert_eq!(first.access_token, second.access_token); + assert!(!first.cache_hit); + assert!(second.cache_hit); + assert_eq!(server.request_counts().await, (1, 1)); + } + + #[tokio::test] + async fn rejects_unallowed_audience_before_network_call() { + let broker = MicrosoftS2sBroker::new(config()).expect("broker"); + let err = broker + .access_token("api://not-allowed") + .await + .expect_err("audience should be denied"); + + assert!(matches!(err, MicrosoftS2sError::AudienceDenied(_))); + } + + #[tokio::test] + #[ignore = "requires binding a local TCP listener"] + async fn validates_runtime_agent_claims() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock") + .as_secs(); + let wrong_agent_token = jwt(serde_json::json!({ + "aud": RESOURCE, + "tid": TENANT, + "azp": "a185cf21-03c8-4bf1-919a-ec8f0782118d", + "idtyp": "app", + "nbf": now.saturating_sub(30), + "exp": now + 3600 + })); + let server = FakeTokenServer::start(wrong_agent_token).await; + + let err = broker(&server) + .access_token(RESOURCE) + .await + .expect_err("wrong runtime agent should fail validation"); + + assert!(matches!(err, MicrosoftS2sError::ClaimValidation(_))); + assert!( + err.to_string().contains("runtime agent identity"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + #[ignore = "requires binding a local TCP listener"] + async fn validates_required_roles() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock") + .as_secs(); + let missing_role_token = jwt(serde_json::json!({ + "aud": RESOURCE, + "tid": TENANT, + "azp": RUNTIME_AGENT, + "oid": RUNTIME_AGENT, + "sub": RUNTIME_AGENT, + "idtyp": "app", + "roles": ["Other.Role"], + "nbf": now.saturating_sub(30), + "exp": now + 3600 + })); + let server = FakeTokenServer::start(missing_role_token).await; + + let err = broker(&server) + .access_token(RESOURCE) + .await + .expect_err("missing required role should fail validation"); + + assert!(matches!(err, MicrosoftS2sError::ClaimValidation(_))); + assert!( + err.to_string().contains("missing required role"), + "unexpected error: {err}" + ); + } +} diff --git a/crates/openshell-server/src/provider_auth/mod.rs b/crates/openshell-server/src/provider_auth/mod.rs new file mode 100644 index 000000000..bbe4de808 --- /dev/null +++ b/crates/openshell-server/src/provider_auth/mod.rs @@ -0,0 +1,4 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod microsoft_s2s; diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index 59c2a23f6..678527897 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -497,6 +497,18 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _: tonic::Request, + ) -> Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _: tonic::Request, diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 73ad0aff0..508340172 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -162,6 +162,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index 14a63c566..c59b15e94 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -121,6 +121,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index 00ed1657f..610266e34 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -134,6 +134,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index d82c9c261..41b028ae6 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -172,6 +172,12 @@ impl OpenShell for RelayGateway { { Err(Status::unimplemented("unused")) } + async fn mint_sandbox_provider_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn create_ssh_session( &self, _: tonic::Request, diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 277cffb51..6db71a593 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -157,6 +157,15 @@ impl OpenShell for TestOpenShell { )) } + async fn mint_sandbox_provider_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + openshell_core::proto::MintSandboxProviderTokenResponse::default(), + )) + } + async fn create_ssh_session( &self, _request: tonic::Request, diff --git a/proto/openshell.proto b/proto/openshell.proto index e4a1b0673..c5ad91f9f 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -137,6 +137,10 @@ service OpenShell { rpc GetSandboxProviderEnvironment(GetSandboxProviderEnvironmentRequest) returns (GetSandboxProviderEnvironmentResponse); + // Mint a short-lived provider-backed token for a running sandbox. + rpc MintSandboxProviderToken(MintSandboxProviderTokenRequest) + returns (MintSandboxProviderTokenResponse); + // Fetch recent sandbox logs (one-shot). rpc GetSandboxLogs(GetSandboxLogsRequest) returns (GetSandboxLogsResponse); @@ -906,6 +910,28 @@ message GetSandboxProviderEnvironmentResponse { uint64 provider_env_revision = 2; } +// Mint a short-lived provider-backed token for a running sandbox. +message MintSandboxProviderTokenRequest { + // The sandbox ID. + string sandbox_id = 1; + // The attached provider name. + string provider_name = 2; + // The requested resource audience. + string audience = 3; +} + +// Brokered provider token response. +message MintSandboxProviderTokenResponse { + // The access token value. + string access_token = 1; + // Token type, typically "Bearer". + string token_type = 2; + // Expiry as Unix epoch seconds. + uint64 expires_at_unix = 3; + // True when the broker cache served the token without a mint call. + bool cache_hit = 4; +} + // --------------------------------------------------------------------------- // Policy update messages // --------------------------------------------------------------------------- diff --git a/providers/microsoft-agent-s2s.yaml b/providers/microsoft-agent-s2s.yaml new file mode 100644 index 000000000..c6ed14ff5 --- /dev/null +++ b/providers/microsoft-agent-s2s.yaml @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +id: microsoft-agent-s2s +display_name: Microsoft Agent S2S +description: Microsoft Agent service-to-service provider record without managed policy defaults +category: agent +credentials: + - name: blueprint_client_secret + description: Microsoft blueprint client secret + env_vars: [A365_BLUEPRINT_CLIENT_SECRET] + required: true