diff --git a/Cargo.lock b/Cargo.lock index 7ac41bd..a6369c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,8 @@ dependencies = [ "bytes", "chrono", "clap", + "dashmap", + "governor", "http", "http-body-util", "hyper", @@ -430,6 +432,19 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.11.0" @@ -640,6 +655,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.32" @@ -707,12 +737,19 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -757,6 +794,26 @@ dependencies = [ "wasip2", ] +[[package]] +name = "governor" +version = "0.6.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.8.6", + "smallvec", + "spinning_top", +] + [[package]] name = "group" version = "0.13.0" @@ -787,6 +844,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -1238,6 +1301,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -1248,6 +1317,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1525,6 +1600,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.5" @@ -1590,6 +1671,21 @@ dependencies = [ "syn", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quote" version = "1.0.45" @@ -1664,6 +1760,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "rcgen" version = "0.14.7" @@ -1998,6 +2103,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -2776,6 +2890,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "web-sys" +version = "0.3.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "web-time" version = "1.1.0" @@ -2814,6 +2938,28 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
 [[package]]
 name = "windows-core"
 version = "0.62.2"
diff --git a/Cargo.toml b/Cargo.toml
index aec46e3..ad6f985 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,6 +33,8 @@ anyhow = "1"
 sqlx = { version = "0.8.6", default-features = false, features = ["runtime-tokio-rustls", "postgres", "chrono", "macros"] }
 chrono = "0.4.44"
 p256 = { version = "0.13.2", features = ["ecdsa", "pkcs8"] }
+governor = "0.6"
+dashmap = "5"
 
 [dev-dependencies]
 rcgen = { version = "0.14", features = ["x509-parser"] }
diff --git a/config.example.toml b/config.example.toml
index c5927b8..f1b425f 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -16,3 +16,12 @@ max_connections = 5
 connect_timeout_ms = 5000
 pool_acquire_timeout_ms = 1000
 query_timeout_ms = 500
+
+[rate_limit]
+# Per-identity byte rate limit. Set to 0 to disable (default).
+# Each agent identity (from the X.509 certificate extension) gets
+# an independent token bucket. Limits sustained exfiltration throughput.
+# Example: 10_485_760 = 10 MiB/s
+bytes_per_second = 0
+# Token bucket capacity in bytes (largest burst allowed at once). Default: 1 MiB.
+burst_bytes = 1048576
diff --git a/src/config.rs b/src/config.rs
index 28d4936..7878922 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -5,12 +5,44 @@ use serde::Deserialize;
 
 use crate::policy;
 
+/// Per-identity byte-rate limiting configuration.
+///
+/// Rate limiting is disabled when `bytes_per_second` is `0`, which is the default.
+#[derive(Debug, Deserialize, Clone, Copy)]
+#[serde(deny_unknown_fields)]
+pub struct RateLimitConfig {
+    /// Sustained byte rate allowed per identity, in bytes per second (`0` disables).
+    #[serde(default)]
+    pub bytes_per_second: u64,
+    /// Token bucket capacity in bytes; defaults to 1 MiB when omitted.
+    #[serde(default = "default_burst_bytes")]
+    pub burst_bytes: u64,
+}
+
+/// Default burst size used when `burst_bytes` is omitted.
+fn default_burst_bytes() -> u64 {
+    1_048_576 // 1 MiB
+}
+
+impl Default for RateLimitConfig {
+    fn default() -> Self {
+        Self {
+            bytes_per_second: 0,
+            burst_bytes: default_burst_bytes(),
+        }
+    }
+}
+
 #[derive(Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct Config {
     pub server: ServerConfig,
     pub observability: ObservabilityConfig,
     pub policy: PolicyConfig,
+    /// Optional `[rate_limit]` section. `#[serde(default)]` lets existing config files omit
+    /// it while `deny_unknown_fields` keeps rejecting misspelled top-level sections.
+    #[serde(default)]
+    pub rate_limit: RateLimitConfig,
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/src/lib.rs b/src/lib.rs
index e8ac142..96504de 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,5 +2,6 @@ pub mod config;
 pub mod observability;
 pub mod policy;
 pub mod proxy;
+pub mod rate_limit;
 mod registry;
 pub mod tls;
diff --git a/src/main.rs b/src/main.rs
index bfb0ddc..3d9de3d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -9,10 +9,14 @@ use hyper_util::rt::TokioExecutor;
 use tokio::net::TcpListener;
 use tracing::{error, info};
 
+// NOTE(review): `rate_limit` is already exported by the library crate (`src/lib.rs`).
+// Confirm the binary really needs its own copy of the module; otherwise drop this line.
+mod rate_limit;
+
 #[derive(Parser)]
 #[command(name = "agent_gateway", about = "mTLS HTTP/2 CONNECT proxy")]
 struct Cli {
     /// Path to the TOML configuration file
     #[arg(short, long, default_value = "config.toml")]
     config: PathBuf,
 }
@@ -37,7 +41,9 @@ async fn serve(config: config::Config) -> anyhow::Result<()> {
     let tls_acceptor = tls::TlsAcceptor::from(server_tls);
 
     let policy_engine = policy::build_engine(&config.policy).await?;
-    let make_service = Arc::new(MakeProxyService::new(policy_engine));
+
+    // Pass rate_limit config into the service — used in spawn_tunnel
+    let make_service = Arc::new(MakeProxyService::new(policy_engine, config.rate_limit));
 
     let listen_addr: std::net::SocketAddr = config.server.listen_addr.parse()?;
     let listener = TcpListener::bind(listen_addr).await?;
diff --git a/src/proxy.rs b/src/proxy.rs
index ef3beba..265d303 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -20,7 +20,9 @@ use tokio::net::TcpStream;
 use tracing::{Instrument, error, info, warn};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 
+use crate::config::RateLimitConfig;
 use crate::policy::{self, PolicyDecision, PolicyEngine, RequestContext};
+use crate::rate_limit::{RateLimiterRegistry, copy_with_rate_limit, get_or_create, new_registry};
 
 type ProxyBody = BoxBody;
 
@@ -45,6 +47,8 @@ pub struct ProxyService {
     policy_engine: Arc,
     peer_certs:
Vec>,
     source_peer_addr: SocketAddr,
+    rate_limiters: RateLimiterRegistry,
+    rate_limit_config: RateLimitConfig,
 }
 
 impl ProxyService {
@@ -52,11 +56,15 @@ impl ProxyService {
         policy_engine: Arc,
         peer_certs: Vec>,
         source_peer_addr: SocketAddr,
+        rate_limiters: RateLimiterRegistry,
+        rate_limit_config: RateLimitConfig,
     ) -> Self {
         Self {
             policy_engine,
             peer_certs,
             source_peer_addr,
+            rate_limiters,
+            rate_limit_config,
         }
     }
 
@@ -132,6 +140,8 @@ impl ProxyService {
             source_identity,
             self.source_peer_addr,
             dest.authority,
+            self.rate_limiters.clone(),
+            self.rate_limit_config,
         );
 
         response(StatusCode::OK, "")
@@ -154,11 +164,17 @@ impl Service> for ProxyService {
 
 pub struct MakeProxyService {
     policy_engine: Arc,
+    rate_limiters: RateLimiterRegistry,
+    rate_limit_config: RateLimitConfig,
 }
 
 impl MakeProxyService {
-    pub fn new(policy_engine: Arc) -> Self {
-        Self { policy_engine }
+    pub fn new(policy_engine: Arc, rate_limit_config: RateLimitConfig) -> Self {
+        Self {
+            policy_engine,
+            rate_limiters: new_registry(),
+            rate_limit_config,
+        }
     }
 
     #[must_use]
@@ -167,7 +183,13 @@ impl MakeProxyService {
         peer_certs: Vec>,
         source_peer_addr: SocketAddr,
     ) -> ProxyService {
-        ProxyService::new(self.policy_engine.clone(), peer_certs, source_peer_addr)
+        ProxyService::new(
+            self.policy_engine.clone(),
+            peer_certs,
+            source_peer_addr,
+            self.rate_limiters.clone(),
+            self.rate_limit_config,
+        )
     }
 }
 
@@ -210,6 +232,8 @@ fn spawn_tunnel(
     source_identity: String,
     source_peer_addr: SocketAddr,
     dest_authority: String,
+    rate_limiters: RateLimiterRegistry,
+    rate_limit_config: RateLimitConfig,
 ) {
     let tunnel_span = tracing::Span::current();
     tokio::spawn(
@@ -230,7 +254,25 @@ fn spawn_tunnel(
 
             let mut downstream = hyper_util::rt::TokioIo::new(upgraded);
 
-            match copy_bidirectional(&mut downstream, &mut upstream).await {
+            let result = if rate_limit_config.bytes_per_second > 0 {
+                // Saturate rather than truncate: the config is u64, the limiter takes u32.
+                let bytes_per_second =
+                    u32::try_from(rate_limit_config.bytes_per_second).unwrap_or(u32::MAX);
+                let burst_bytes = u32::try_from(rate_limit_config.burst_bytes).unwrap_or(u32::MAX);
+                let limiter =
+                    get_or_create(&rate_limiters, &source_identity, bytes_per_second, burst_bytes);
+                tracing::debug!(
+                    source_identity = %source_identity,
+                    bytes_per_second = rate_limit_config.bytes_per_second,
+                    "rate limiting enabled for connection"
+                );
+                copy_with_rate_limit(&mut downstream, &mut upstream, limiter, &source_identity)
+                    .await
+            } else {
+                copy_bidirectional(&mut downstream, &mut upstream).await
+            };
+
+            match result {
                 Ok((up, down)) => {
                     info!(
                         source_identity = %source_identity,
diff --git a/src/rate_limit.rs b/src/rate_limit.rs
new file mode 100644
index 0000000..5ae206e
--- /dev/null
+++ b/src/rate_limit.rs
@@ -0,0 +1,280 @@
+//! Per-identity byte-rate limiting for the agent gateway.
+//!
+//! Design: token bucket (governor crate) keyed by identity string.
+//! State is in-memory — appropriate for single-node weight enclave
+//! deployment. See PR description for multi-node tradeoffs.
+
+// NOTE(review): this blanket allow looks like it is only needed because `src/main.rs`
+// also declares `mod rate_limit;`, which would compile a second, unused copy of this
+// module into the binary — confirm, then drop the allow together with that declaration.
+#![allow(dead_code)]
+
+use dashmap::DashMap;
+use governor::{DefaultDirectRateLimiter, InsufficientCapacity, Quota, RateLimiter};
+use std::num::NonZeroU32;
+use std::sync::Arc;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+use tracing::debug;
+
+/// Size of the scratch buffer for the rate-limited (agent→destination) direction.
+const COPY_BUF_SIZE: usize = 64 * 1024;
+
+/// Registry of per-identity rate limiters.
+///
+/// `Arc<DashMap<..>>` so it can be shared across tokio tasks without a `Mutex`.
+/// Entries are never evicted: the map grows with the number of distinct identities.
+pub type RateLimiterRegistry = Arc<DashMap<String, Arc<DefaultDirectRateLimiter>>>;
+
+/// Create a new empty registry.
+#[must_use]
+pub fn new_registry() -> RateLimiterRegistry {
+    Arc::new(DashMap::new())
+}
+
+/// Get the rate limiter for an identity, creating one if it doesn't exist.
+///
+/// The quota arguments only take effect when the limiter is created; an existing
+/// limiter is returned unchanged. A `burst_bytes` of zero is treated as one.
+///
+/// # Panics
+///
+/// Panics if `bytes_per_second` is zero. The call site is responsible for
+/// checking the config before calling this (0 = disabled = don't call).
+#[must_use]
+pub fn get_or_create(
+    registry: &DashMap<String, Arc<DefaultDirectRateLimiter>>,
+    identity: &str,
+    bytes_per_second: u32,
+    burst_bytes: u32,
+) -> Arc<DefaultDirectRateLimiter> {
+    // Fast path: lookup by `&str`, so no key is allocated for known identities.
+    if let Some(existing) = registry.get(identity) {
+        return Arc::clone(existing.value());
+    }
+
+    let bps = NonZeroU32::new(bytes_per_second).expect("bytes_per_second must be > 0");
+    let burst = NonZeroU32::new(burst_bytes).unwrap_or(NonZeroU32::MIN);
+    let quota = Quota::per_second(bps).allow_burst(burst);
+
+    // `entry` makes check-and-insert a single step. With a separate `get` + `insert`, two
+    // connections racing on a new identity could each end up with a private bucket.
+    let slot = registry
+        .entry(identity.to_owned())
+        .or_insert_with(|| Arc::new(RateLimiter::direct(quota)));
+    Arc::clone(slot.value())
+}
+
+/// Write `buf` to `dst`, charging every byte against `limiter` first.
+///
+/// The buffer is written in pieces no larger than the bucket capacity. For each piece the
+/// call waits until the limiter grants that many tokens and only then writes it to `dst`.
+///
+/// # Errors
+///
+/// Returns an `io::Error` if a write to `dst` fails, or if the limiter rejects a batch
+/// that is within the capacity it reported.
+pub async fn rate_limited_write(
+    limiter: &DefaultDirectRateLimiter,
+    identity: &str,
+    dst: &mut (impl AsyncWrite + Unpin),
+    buf: &[u8],
+) -> std::io::Result<()> {
+    let mut remaining = buf;
+    // Largest batch requested at once. Starts at the `u32` cell limit and shrinks to the
+    // bucket capacity the first time the limiter reports a batch as too large.
+    let mut max_chunk = usize::try_from(u32::MAX).unwrap_or(usize::MAX);
+
+    while !remaining.is_empty() {
+        let take = remaining.len().min(max_chunk);
+        let Some(cells) = u32::try_from(take).ok().and_then(NonZeroU32::new) else {
+            return Err(std::io::Error::other("rate limiter batch size out of range"));
+        };
+
+        let granted = match limiter.check_n(cells) {
+            // Tokens were available immediately.
+            Ok(Ok(())) => Ok(()),
+            // The batch fits the bucket, but not yet: wait for it to refill. Treating this
+            // case as success would let the write through unthrottled.
+            Ok(Err(_)) => {
+                debug!(identity = %identity, bytes = take, "rate limit reached, throttling");
+                limiter.until_n_ready(cells).await
+            }
+            // The batch can never fit because it exceeds the burst size.
+            Err(too_large) => Err(too_large),
+        };
+
+        match granted {
+            Ok(()) => {
+                let (chunk, rest) = remaining.split_at(take);
+                dst.write_all(chunk).await?;
+                remaining = rest;
+            }
+            Err(InsufficientCapacity(capacity)) => {
+                let capacity = usize::try_from(capacity).unwrap_or(usize::MAX);
+                if capacity == 0 || capacity >= take {
+                    // Shrinking would make no progress; fail instead of spinning.
+                    return Err(std::io::Error::other(
+                        "rate limiter rejected a batch within its capacity",
+                    ));
+                }
+                max_chunk = capacity;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Full bidirectional copy with rate limiting on client→destination direction.
+///
+/// Rate limiting is applied only to outbound (agent→destination) traffic
+/// since that is the exfiltration direction. Inbound (response) traffic
+/// is copied without rate limiting.
+///
+/// Both directions run concurrently, so a throttled upload does not stall responses.
+/// When one side reaches EOF the peer's write half is shut down and the opposite
+/// direction keeps flowing until it reaches EOF too.
+///
+/// Returns `(bytes_from_client, bytes_from_dest)`.
+///
+/// # Errors
+///
+/// Returns an `io::Error` if any read, write, flush or shutdown on either stream fails.
+pub async fn copy_with_rate_limit<C, D>(
+    client: &mut C,
+    dest: &mut D,
+    limiter: Arc<DefaultDirectRateLimiter>,
+    identity: &str,
+) -> std::io::Result<(u64, u64)>
+where
+    C: AsyncRead + AsyncWrite + Unpin,
+    D: AsyncRead + AsyncWrite + Unpin,
+{
+    let (mut client_rd, mut client_wr) = tokio::io::split(client);
+    let (mut dest_rd, mut dest_wr) = tokio::io::split(dest);
+
+    // agent → destination: every chunk is charged against the identity's bucket.
+    let outbound = async {
+        let mut buf = vec![0u8; COPY_BUF_SIZE];
+        let mut total: u64 = 0;
+        loop {
+            let n = client_rd.read(&mut buf).await?;
+            if n == 0 {
+                break;
+            }
+            rate_limited_write(&limiter, identity, &mut dest_wr, &buf[..n]).await?;
+            dest_wr.flush().await?;
+            total += n as u64;
+        }
+        // Propagate the half-close so the destination sees the end of the upload.
+        dest_wr.shutdown().await?;
+        Ok::<u64, std::io::Error>(total)
+    };
+
+    // destination → agent: not rate limited.
+    let inbound = async {
+        let total = tokio::io::copy(&mut dest_rd, &mut client_wr).await?;
+        client_wr.shutdown().await?;
+        Ok::<u64, std::io::Error>(total)
+    };
+
+    tokio::try_join!(outbound, inbound)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::time::{Duration, Instant};
+
+    #[test]
+    fn test_get_or_create_returns_same_limiter() {
+        let registry = DashMap::new();
+        let l1 = get_or_create(&registry, "agent-1", 1000, 1000);
+        let l2 = get_or_create(&registry, "agent-1", 1000, 1000);
+        assert!(Arc::ptr_eq(&l1, &l2));
+    }
+
+    #[test]
+    fn test_different_identities_have_different_limiters() {
+        let registry = DashMap::new();
+        let l1 = get_or_create(&registry, "agent-1", 1000, 1000);
+        let l2 = get_or_create(&registry, "agent-2", 1000, 1000);
+        assert!(!Arc::ptr_eq(&l1, &l2));
+    }
+
+    #[tokio::test]
+    async fn test_rate_limit_slows_high_throughput() {
+        let registry = DashMap::new();
+        // 10 kB/s with a 1000-byte burst: the first 1000 bytes pass at once, the
+        // remaining 2000 need about 200 ms of refill.
+        let limiter = get_or_create(&registry, "test-agent", 10_000, 1_000);
+
+        let mut sink = Vec::new();
+        let data = vec![0u8; 3_000];
+
+        let start = Instant::now();
+        rate_limited_write(&limiter, "test-agent", &mut sink, &data)
+            .await
+            .unwrap();
+        let elapsed = start.elapsed();
+
+        assert!(
+            elapsed >= Duration::from_millis(150),
+            "rate limit not applied: elapsed {elapsed:?}",
+        );
+        assert_eq!(sink.len(), 3_000);
+    }
+
+    #[tokio::test]
+    async fn test_write_within_burst_waits_for_refill() {
+        let registry = DashMap::new();
+        let limiter = get_or_create(&registry, "test-agent", 10_000, 1_000);
+        let mut sink = Vec::new();
+        let chunk = vec![0u8; 1_000];
+
+        // Drains the bucket without waiting.
+        rate_limited_write(&limiter, "test-agent", &mut sink, &chunk)
+            .await
+            .unwrap();
+
+        // Fits the burst size, but the bucket is empty: needs about 100 ms of refill.
+        let start = Instant::now();
+        rate_limited_write(&limiter, "test-agent", &mut sink, &chunk)
+            .await
+            .unwrap();
+        let elapsed = start.elapsed();
+
+        assert!(
+            elapsed >= Duration::from_millis(80),
+            "depleted bucket was not enforced: elapsed {elapsed:?}",
+        );
+        assert_eq!(sink.len(), 2_000);
+    }
+
+    #[tokio::test]
+    async fn test_copy_relays_response_after_client_half_close() {
+        let (mut agent, mut relay_client) = tokio::io::duplex(1024);
+        let (mut relay_dest, mut server) = tokio::io::duplex(1024);
+        let registry = DashMap::new();
+        let limiter = get_or_create(&registry, "agent-1", 1_000_000, 1_000_000);
+
+        let relay = tokio::spawn(async move {
+            copy_with_rate_limit(&mut relay_client, &mut relay_dest, limiter, "agent-1").await
+        });
+
+        agent.write_all(b"ping").await.unwrap();
+        agent.shutdown().await.unwrap();
+
+        let mut request = Vec::new();
+        server.read_to_end(&mut request).await.unwrap();
+        assert_eq!(request, b"ping");
+
+        server.write_all(b"pong").await.unwrap();
+        server.shutdown().await.unwrap();
+
+        let mut reply = Vec::new();
+        agent.read_to_end(&mut reply).await.unwrap();
+        assert_eq!(reply, b"pong");
+        assert_eq!(relay.await.unwrap().unwrap(), (4, 4));
+    }
+}
diff --git a/tests/common/mod.rs b/tests/common/mod.rs
index 4b33d8b..44a1fea 100644
--- a/tests/common/mod.rs
+++ b/tests/common/mod.rs
@@ -19,6 +19,7 @@ use x509_parser::prelude::*;
 
+use agent_gateway::config::RateLimitConfig;
 use agent_gateway::policy::PolicyEngine;
 use agent_gateway::proxy::MakeProxyService;
 
 const CLIENT_EXTENSION_OID: &[u64] = &[1, 3, 6, 1, 4, 1, 57264, 1, 1];
 static TEST_ID: AtomicU64 = AtomicU64::new(1);
@@ -711,7 +712,7 @@ pub async fn start_proxy(
 
     let tls_acceptor =
tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); - let make_service = Arc::new(MakeProxyService::new(policy_engine)); + let make_service = Arc::new(MakeProxyService::new(policy_engine, RateLimitConfig::default())); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap();