Add an optional, process-local, per-identity byte limit for proxied tunnels.

Changes relative to the reviewed revision of this patch:
  * src/proxy.rs: copy_bidirectional_with_rate_limit no longer returns on the
    first EOF. It shuts down the peer's write half and keeps relaying the
    other direction until that side reaches EOF too, which is what the
    copy_bidirectional call it replaces did. Returning early dropped data
    that was still in flight after a half-close.
  * src/rate_limit.rs: expired entries are pruned when a new identity shows
    up, the key is allocated only for new identities instead of on every
    chunk, and a poisoned mutex is recovered instead of panicking. Doc
    comments and a pruning test were added.

Reconstruction notes (NOTE(review): confirm against the base tree):
  * This patch was rebuilt from a copy whose line breaks were collapsed and
    whose generic type arguments were stripped. Blank context lines and the
    indentation of context lines follow rustfmt conventions and the hunk
    line counts; they could not be read from the damaged copy.
  * Type arguments on pre-existing (context and removed) lines are inferred
    and may differ from the base: Arc<dyn PolicyEngine>,
    Vec<CertificateDer<'static>>, BoxBody<Bytes, Infallible>,
    Service<Request<Incoming>>, Response<ProxyBody>, Empty::<Bytes> and
    Option<u64> for query_timeout_ms.
  * The "index" lines of src/proxy.rs and src/rate_limit.rs were dropped
    because their blob ids no longer describe the patched content.

diff --git a/config.example.toml b/config.example.toml
index c5927b8..19267dd 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -16,3 +16,10 @@ max_connections = 5
 connect_timeout_ms = 5000
 pool_acquire_timeout_ms = 1000
 query_timeout_ms = 500
+
+# Optional per-identity data limit for proxied traffic.
+# This is intentionally local to one gateway process.
+[rate_limit]
+enabled = false
+window_secs = 60
+max_bytes_per_identity = 10485760
diff --git a/src/config.rs b/src/config.rs
index 28d4936..748282c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -4,6 +4,7 @@ use std::time::Duration;
 use serde::Deserialize;
 
 use crate::policy;
+use crate::rate_limit::RateLimitConfig;
 
 #[derive(Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
@@ -11,6 +12,7 @@ pub struct Config {
     pub server: ServerConfig,
     pub observability: ObservabilityConfig,
     pub policy: PolicyConfig,
+    pub rate_limit: Option<RateLimitSection>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -40,6 +42,43 @@ pub struct PolicyConfig {
     pub query_timeout_ms: Option<u64>,
 }
 
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct RateLimitSection {
+    pub enabled: Option<bool>,
+    pub window_secs: Option<u64>,
+    pub max_bytes_per_identity: Option<u64>,
+}
+
+impl RateLimitSection {
+    #[must_use]
+    pub fn to_runtime_config(&self) -> RateLimitConfig {
+        RateLimitConfig {
+            enabled: self.enabled.unwrap_or(false),
+            window_secs: self.window_secs.unwrap_or(60),
+            max_bytes_per_identity: self.max_bytes_per_identity.unwrap_or(10 * 1024 * 1024),
+        }
+    }
+
+    fn validate(&self) -> anyhow::Result<()> {
+        if let Some(window_secs) = self.window_secs {
+            anyhow::ensure!(
+                window_secs > 0,
+                "rate_limit.window_secs must be greater than zero"
+            );
+        }
+
+        if let Some(max_bytes) = self.max_bytes_per_identity {
+            anyhow::ensure!(
+                max_bytes > 0,
+                "rate_limit.max_bytes_per_identity must be greater than zero"
+            );
+        }
+
+        Ok(())
+    }
+}
+
 impl Config {
     /// Load, parse, and validate a TOML config file.
     ///
@@ -62,6 +101,11 @@ impl Config {
 
         policy::parse_client_ext_oid(&self.policy.client_ext_oid)?;
         self.policy.validate()?;
+
+        if let Some(rate_limit) = &self.rate_limit {
+            rate_limit.validate()?;
+        }
+
         Ok(())
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index e8ac142..96504de 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,5 +2,6 @@ pub mod config;
 pub mod observability;
 pub mod policy;
 pub mod proxy;
+pub mod rate_limit;
 mod registry;
 pub mod tls;
diff --git a/src/main.rs b/src/main.rs
index bfb0ddc..b616079 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,6 +2,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use agent_gateway::proxy::MakeProxyService;
+use agent_gateway::rate_limit::{RateLimitConfig, RateLimiter};
 use agent_gateway::{config, observability, policy, proxy, tls};
 use anyhow::Context;
 use clap::Parser;
@@ -37,7 +38,14 @@ async fn serve(config: config::Config) -> anyhow::Result<()> {
     let tls_acceptor = tls::TlsAcceptor::from(server_tls);
 
     let policy_engine = policy::build_engine(&config.policy).await?;
-    let make_service = Arc::new(MakeProxyService::new(policy_engine));
+    let rate_limit_config = config
+        .rate_limit
+        .as_ref()
+        .map_or_else(RateLimitConfig::disabled, |section| {
+            section.to_runtime_config()
+        });
+    let rate_limiter = Arc::new(RateLimiter::new(rate_limit_config));
+    let make_service = Arc::new(MakeProxyService::new(policy_engine, rate_limiter));
 
     let listen_addr: std::net::SocketAddr = config.server.listen_addr.parse()?;
     let listener = TcpListener::bind(listen_addr).await?;
diff --git a/src/proxy.rs b/src/proxy.rs
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -15,12 +15,13 @@ use opentelemetry::global;
 use opentelemetry::propagation::Extractor;
 use rustls::ServerConnection;
 use rustls_pki_types::CertificateDer;
-use tokio::io::copy_bidirectional;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use tokio::net::TcpStream;
 use tracing::{Instrument, error, info, warn};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 
 use crate::policy::{self, PolicyDecision, PolicyEngine, RequestContext};
+use crate::rate_limit::{RateLimitDecision, RateLimiter};
 
 type ProxyBody = BoxBody<Bytes, Infallible>;
 
@@ -43,6 +44,7 @@ fn extract_trace_context(headers: &HeaderMap) -> opentelemetry::Context {
 #[derive(Clone)]
 pub struct ProxyService {
     policy_engine: Arc<dyn PolicyEngine>,
+    rate_limiter: Arc<RateLimiter>,
     peer_certs: Vec<CertificateDer<'static>>,
     source_peer_addr: SocketAddr,
 }
@@ -50,11 +52,13 @@ pub struct ProxyService {
 impl ProxyService {
     fn new(
         policy_engine: Arc<dyn PolicyEngine>,
+        rate_limiter: Arc<RateLimiter>,
         peer_certs: Vec<CertificateDer<'static>>,
         source_peer_addr: SocketAddr,
     ) -> Self {
         Self {
             policy_engine,
+            rate_limiter,
             peer_certs,
             source_peer_addr,
         }
@@ -132,6 +136,7 @@ impl ProxyService {
                     source_identity,
                     self.source_peer_addr,
                     dest.authority,
+                    self.rate_limiter.clone(),
                 );
 
                 response(StatusCode::OK, "")
@@ -154,11 +159,15 @@ impl Service<Request<Incoming>> for ProxyService {
 
 pub struct MakeProxyService {
     policy_engine: Arc<dyn PolicyEngine>,
+    rate_limiter: Arc<RateLimiter>,
 }
 
 impl MakeProxyService {
-    pub fn new(policy_engine: Arc<dyn PolicyEngine>) -> Self {
-        Self { policy_engine }
+    pub fn new(policy_engine: Arc<dyn PolicyEngine>, rate_limiter: Arc<RateLimiter>) -> Self {
+        Self {
+            policy_engine,
+            rate_limiter,
+        }
     }
 
     #[must_use]
@@ -167,7 +176,12 @@ impl MakeProxyService {
         peer_certs: Vec<CertificateDer<'static>>,
         source_peer_addr: SocketAddr,
     ) -> ProxyService {
-        ProxyService::new(self.policy_engine.clone(), peer_certs, source_peer_addr)
+        ProxyService::new(
+            self.policy_engine.clone(),
+            self.rate_limiter.clone(),
+            peer_certs,
+            source_peer_addr,
+        )
     }
 }
 
@@ -210,6 +224,7 @@ fn spawn_tunnel(
     source_identity: String,
     source_peer_addr: SocketAddr,
     dest_authority: String,
+    rate_limiter: Arc<RateLimiter>,
 ) {
     let tunnel_span = tracing::Span::current();
     tokio::spawn(
@@ -230,16 +245,38 @@ fn spawn_tunnel(
 
             let mut downstream = hyper_util::rt::TokioIo::new(upgraded);
 
-            match copy_bidirectional(&mut downstream, &mut upstream).await {
-                Ok((up, down)) => {
-                    info!(
-                        source_identity = %source_identity,
-                        source_peer_addr = %source_peer_addr,
-                        dest_authority = %dest_authority,
-                        bytes_client_to_dest = up,
-                        bytes_dest_to_client = down,
-                        "tunnel closed"
-                    );
+            match copy_bidirectional_with_rate_limit(
+                &mut downstream,
+                &mut upstream,
+                &source_identity,
+                &rate_limiter,
+            )
+            .await
+            {
+                Ok(TunnelStats {
+                    bytes_client_to_dest,
+                    bytes_dest_to_client,
+                    limited,
+                }) => {
+                    if limited {
+                        warn!(
+                            source_identity = %source_identity,
+                            source_peer_addr = %source_peer_addr,
+                            dest_authority = %dest_authority,
+                            bytes_client_to_dest = bytes_client_to_dest,
+                            bytes_dest_to_client = bytes_dest_to_client,
+                            "tunnel closed after rate limit"
+                        );
+                    } else {
+                        info!(
+                            source_identity = %source_identity,
+                            source_peer_addr = %source_peer_addr,
+                            dest_authority = %dest_authority,
+                            bytes_client_to_dest = bytes_client_to_dest,
+                            bytes_dest_to_client = bytes_dest_to_client,
+                            "tunnel closed"
+                        );
+                    }
                 }
                 Err(e) => {
                     error!(
@@ -256,6 +293,142 @@ fn spawn_tunnel(
     );
 }
 
+/// Byte totals for a finished tunnel and whether the rate limit ended it.
+struct TunnelStats {
+    bytes_client_to_dest: u64,
+    bytes_dest_to_client: u64,
+    limited: bool,
+}
+
+/// Relays bytes in both directions between `downstream` (the client) and
+/// `upstream` (the destination), charging every chunk to `source_identity`.
+///
+/// When one side reaches EOF, the write half of the other side is shut down
+/// and the opposite direction keeps flowing until it reaches EOF as well, as
+/// `tokio::io::copy_bidirectional` does. Returning on the first EOF would
+/// drop a response that is still in flight after the client half-closes.
+///
+/// As soon as the limiter rejects a chunk, both sides are shut down and the
+/// stats are returned with `limited: true`; the rejected chunk is not forwarded.
+///
+/// # Errors
+///
+/// Returns the first I/O error from a read, a write, or an EOF-driven shutdown.
+async fn copy_bidirectional_with_rate_limit<D, U>(
+    downstream: &mut D,
+    upstream: &mut U,
+    source_identity: &str,
+    rate_limiter: &RateLimiter,
+) -> std::io::Result<TunnelStats>
+where
+    D: AsyncRead + AsyncWrite + Unpin,
+    U: AsyncRead + AsyncWrite + Unpin,
+{
+    let mut downstream_buf = vec![0_u8; 16 * 1024];
+    let mut upstream_buf = vec![0_u8; 16 * 1024];
+
+    let mut bytes_client_to_dest = 0_u64;
+    let mut bytes_dest_to_client = 0_u64;
+
+    // Set once the respective reader has reported EOF; that branch is then
+    // disabled so a closed side is never polled again.
+    let mut client_eof = false;
+    let mut dest_eof = false;
+
+    // The loop condition keeps at least one `select!` branch enabled.
+    while !(client_eof && dest_eof) {
+        tokio::select! {
+            read_result = downstream.read(&mut downstream_buf), if !client_eof => {
+                let n = read_result?;
+
+                if n == 0 {
+                    client_eof = true;
+                    upstream.shutdown().await?;
+                    continue;
+                }
+
+                match rate_limiter.check_and_record(source_identity, n as u64) {
+                    RateLimitDecision::Allowed => {
+                        upstream.write_all(&downstream_buf[..n]).await?;
+                        bytes_client_to_dest += n as u64;
+                    }
+                    RateLimitDecision::Limited {
+                        bytes_used,
+                        attempted_bytes,
+                        max_bytes,
+                        window_secs,
+                    } => {
+                        warn!(
+                            source_identity = %source_identity,
+                            bytes_used = bytes_used,
+                            attempted_bytes = attempted_bytes,
+                            max_bytes = max_bytes,
+                            window_secs = window_secs,
+                            "rate limit exceeded while proxying client-to-destination data"
+                        );
+
+                        let _ = upstream.shutdown().await;
+                        let _ = downstream.shutdown().await;
+
+                        return Ok(TunnelStats {
+                            bytes_client_to_dest,
+                            bytes_dest_to_client,
+                            limited: true,
+                        });
+                    }
+                }
+            }
+
+            read_result = upstream.read(&mut upstream_buf), if !dest_eof => {
+                let n = read_result?;
+
+                if n == 0 {
+                    dest_eof = true;
+                    downstream.shutdown().await?;
+                    continue;
+                }
+
+                match rate_limiter.check_and_record(source_identity, n as u64) {
+                    RateLimitDecision::Allowed => {
+                        downstream.write_all(&upstream_buf[..n]).await?;
+                        bytes_dest_to_client += n as u64;
+                    }
+                    RateLimitDecision::Limited {
+                        bytes_used,
+                        attempted_bytes,
+                        max_bytes,
+                        window_secs,
+                    } => {
+                        warn!(
+                            source_identity = %source_identity,
+                            bytes_used = bytes_used,
+                            attempted_bytes = attempted_bytes,
+                            max_bytes = max_bytes,
+                            window_secs = window_secs,
+                            "rate limit exceeded while proxying destination-to-client data"
+                        );
+
+                        let _ = upstream.shutdown().await;
+                        let _ = downstream.shutdown().await;
+
+                        return Ok(TunnelStats {
+                            bytes_client_to_dest,
+                            bytes_dest_to_client,
+                            limited: true,
+                        });
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(TunnelStats {
+        bytes_client_to_dest,
+        bytes_dest_to_client,
+        limited: false,
+    })
+}
+
 fn response(status: StatusCode, message: &str) -> Response<ProxyBody> {
     let body: ProxyBody = if message.is_empty() {
         Empty::<Bytes>::new().boxed()
diff --git a/src/rate_limit.rs b/src/rate_limit.rs
new file mode 100644
--- /dev/null
+++ b/src/rate_limit.rs
@@ -0,0 +1,236 @@
+//! Per-identity byte budget for proxied tunnel traffic.
+//!
+//! All state lives in the memory of the current process, so every gateway
+//! process enforces its own independent budget.
+
+use std::collections::HashMap;
+use std::sync::{Mutex, PoisonError};
+use std::time::{Duration, Instant};
+
+/// Runtime settings for a [`RateLimiter`].
+#[derive(Debug, Clone)]
+pub struct RateLimitConfig {
+    /// When `false`, [`RateLimiter::check_and_record`] allows everything.
+    pub enabled: bool,
+    /// Length of one accounting window, in seconds.
+    pub window_secs: u64,
+    /// Bytes a single identity may transfer within one window.
+    pub max_bytes_per_identity: u64,
+}
+
+impl RateLimitConfig {
+    /// Returns a configuration with limiting switched off.
+    #[must_use]
+    pub const fn disabled() -> Self {
+        Self {
+            enabled: false,
+            window_secs: 60,
+            max_bytes_per_identity: 10 * 1024 * 1024,
+        }
+    }
+
+    /// Returns the accounting window as a [`Duration`].
+    #[must_use]
+    pub fn window(&self) -> Duration {
+        Duration::from_secs(self.window_secs)
+    }
+}
+
+/// Bytes recorded for one identity during its current window.
+#[derive(Debug)]
+struct IdentityUsage {
+    window_start: Instant,
+    bytes_used: u64,
+}
+
+/// Fixed-window byte counter keyed by source identity.
+#[derive(Debug)]
+pub struct RateLimiter {
+    config: RateLimitConfig,
+    usage: Mutex<HashMap<String, IdentityUsage>>,
+}
+
+impl RateLimiter {
+    /// Creates a limiter with no recorded usage.
+    #[must_use]
+    pub fn new(config: RateLimitConfig) -> Self {
+        Self {
+            config,
+            usage: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Charges `bytes` to `identity` if they fit in its current window.
+    ///
+    /// Returns [`RateLimitDecision::Limited`], recording nothing, when the bytes
+    /// would push the identity past `max_bytes_per_identity`. A window starts
+    /// with an identity's first call and is restarted by the first call made
+    /// at least `window_secs` after that start.
+    pub fn check_and_record(&self, identity: &str, bytes: u64) -> RateLimitDecision {
+        if !self.config.enabled {
+            return RateLimitDecision::Allowed;
+        }
+
+        let now = Instant::now();
+
+        // The map only holds plain counters, so state left behind by a holder
+        // that panicked is still usable. Recover the guard instead of
+        // panicking in every tunnel task from then on.
+        let mut usage = self.usage.lock().unwrap_or_else(PoisonError::into_inner);
+
+        // Hot path: a known identity is charged without allocating a key.
+        if let Some(entry) = usage.get_mut(identity) {
+            return self.charge(entry, now, bytes);
+        }
+
+        // New identity: first drop entries whose window has expired (their
+        // next use would reset them anyway) so the map cannot grow without
+        // bound as identities come and go.
+        let window = self.config.window();
+        usage.retain(|_, entry| now.duration_since(entry.window_start) < window);
+
+        let mut entry = IdentityUsage {
+            window_start: now,
+            bytes_used: 0,
+        };
+        let decision = self.charge(&mut entry, now, bytes);
+        usage.insert(identity.to_owned(), entry);
+        decision
+    }
+
+    /// Applies the window reset and budget check to a single entry.
+    fn charge(&self, entry: &mut IdentityUsage, now: Instant, bytes: u64) -> RateLimitDecision {
+        if now.duration_since(entry.window_start) >= self.config.window() {
+            entry.window_start = now;
+            entry.bytes_used = 0;
+        }
+
+        let next_total = entry.bytes_used.saturating_add(bytes);
+
+        if next_total > self.config.max_bytes_per_identity {
+            return RateLimitDecision::Limited {
+                bytes_used: entry.bytes_used,
+                attempted_bytes: bytes,
+                max_bytes: self.config.max_bytes_per_identity,
+                window_secs: self.config.window_secs,
+            };
+        }
+
+        entry.bytes_used = next_total;
+        RateLimitDecision::Allowed
+    }
+}
+
+/// Outcome of [`RateLimiter::check_and_record`].
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum RateLimitDecision {
+    /// The bytes were recorded and may be forwarded.
+    Allowed,
+    /// The bytes were not recorded because they exceed the remaining budget.
+    Limited {
+        /// Bytes already recorded for the identity in the current window.
+        bytes_used: u64,
+        /// Size of the rejected request.
+        attempted_bytes: u64,
+        /// Configured per-window budget.
+        max_bytes: u64,
+        /// Configured window length, in seconds.
+        window_secs: u64,
+    },
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn allows_bytes_under_limit() {
+        let limiter = RateLimiter::new(RateLimitConfig {
+            enabled: true,
+            window_secs: 60,
+            max_bytes_per_identity: 100,
+        });
+
+        assert_eq!(
+            limiter.check_and_record("agent-alpha", 40),
+            RateLimitDecision::Allowed
+        );
+        assert_eq!(
+            limiter.check_and_record("agent-alpha", 50),
+            RateLimitDecision::Allowed
+        );
+    }
+
+    #[test]
+    fn blocks_bytes_over_limit() {
+        let limiter = RateLimiter::new(RateLimitConfig {
+            enabled: true,
+            window_secs: 60,
+            max_bytes_per_identity: 100,
+        });
+
+        assert_eq!(
+            limiter.check_and_record("agent-alpha", 80),
+            RateLimitDecision::Allowed
+        );
+
+        match limiter.check_and_record("agent-alpha", 30) {
+            RateLimitDecision::Limited { max_bytes, .. } => assert_eq!(max_bytes, 100),
+            RateLimitDecision::Allowed => panic!("expected rate limit"),
+        }
+    }
+
+    #[test]
+    fn keeps_identities_separate() {
+        let limiter = RateLimiter::new(RateLimitConfig {
+            enabled: true,
+            window_secs: 60,
+            max_bytes_per_identity: 100,
+        });
+
+        assert_eq!(
+            limiter.check_and_record("agent-alpha", 100),
+            RateLimitDecision::Allowed
+        );
+        assert_eq!(
+            limiter.check_and_record("agent-beta", 100),
+            RateLimitDecision::Allowed
+        );
+
+        assert!(matches!(
+            limiter.check_and_record("agent-alpha", 1),
+            RateLimitDecision::Limited { .. }
+        ));
+    }
+
+    #[test]
+    fn disabled_limiter_allows_without_tracking() {
+        let limiter = RateLimiter::new(RateLimitConfig {
+            enabled: false,
+            window_secs: 60,
+            max_bytes_per_identity: 1,
+        });
+
+        assert_eq!(
+            limiter.check_and_record("agent-alpha", 1_000_000),
+            RateLimitDecision::Allowed
+        );
+    }
+
+    #[test]
+    fn prunes_expired_identities() {
+        // A zero-length window makes every earlier entry expired at once.
+        let limiter = RateLimiter::new(RateLimitConfig {
+            enabled: true,
+            window_secs: 0,
+            max_bytes_per_identity: 100,
+        });
+
+        limiter.check_and_record("agent-alpha", 10);
+        limiter.check_and_record("agent-beta", 10);
+
+        let usage = limiter.usage.lock().expect("mutex is not poisoned");
+        assert_eq!(usage.len(), 1);
+        assert!(usage.contains_key("agent-beta"));
+    }
+}
diff --git a/tests/common/mod.rs b/tests/common/mod.rs
index 4b33d8b..63e14c6 100644
--- a/tests/common/mod.rs
+++ b/tests/common/mod.rs
@@ -19,6 +19,7 @@ use x509_parser::prelude::*;
 
 use agent_gateway::policy::PolicyEngine;
 use agent_gateway::proxy::MakeProxyService;
+use agent_gateway::rate_limit::{RateLimitConfig, RateLimiter};
 
 const CLIENT_EXTENSION_OID: &[u64] = &[1, 3, 6, 1, 4, 1, 57264, 1, 1];
 static TEST_ID: AtomicU64 = AtomicU64::new(1);
@@ -711,7 +712,8 @@ pub async fn start_proxy(
 
     let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
 
-    let make_service = Arc::new(MakeProxyService::new(policy_engine));
+    let rate_limiter = Arc::new(RateLimiter::new(RateLimitConfig::disabled()));
+    let make_service = Arc::new(MakeProxyService::new(policy_engine, rate_limiter));
 
     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
     let addr = listener.local_addr().unwrap();