Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,10 @@ max_connections = 5
connect_timeout_ms = 5000
pool_acquire_timeout_ms = 1000
query_timeout_ms = 500

# Optional per-identity data limit for proxied traffic.
# This is intentionally local to one gateway process.
[rate_limit]
enabled = false
window_secs = 60
max_bytes_per_identity = 10485760
44 changes: 44 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use std::time::Duration;
use serde::Deserialize;

use crate::policy;
use crate::rate_limit::RateLimitConfig;

#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub server: ServerConfig,
pub observability: ObservabilityConfig,
pub policy: PolicyConfig,
pub rate_limit: Option<RateLimitSection>,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -40,6 +42,43 @@ pub struct PolicyConfig {
pub query_timeout_ms: Option<u64>,
}

#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RateLimitSection {
pub enabled: Option<bool>,
pub window_secs: Option<u64>,
pub max_bytes_per_identity: Option<u64>,
}

impl RateLimitSection {
#[must_use]
pub fn to_runtime_config(&self) -> RateLimitConfig {
RateLimitConfig {
enabled: self.enabled.unwrap_or(false),
window_secs: self.window_secs.unwrap_or(60),
max_bytes_per_identity: self.max_bytes_per_identity.unwrap_or(10 * 1024 * 1024),
}
}

fn validate(&self) -> anyhow::Result<()> {
if let Some(window_secs) = self.window_secs {
anyhow::ensure!(
window_secs > 0,
"rate_limit.window_secs must be greater than zero"
);
}

if let Some(max_bytes) = self.max_bytes_per_identity {
anyhow::ensure!(
max_bytes > 0,
"rate_limit.max_bytes_per_identity must be greater than zero"
);
}

Ok(())
}
}

impl Config {
/// Load, parse, and validate a TOML config file.
///
Expand All @@ -62,6 +101,11 @@ impl Config {

policy::parse_client_ext_oid(&self.policy.client_ext_oid)?;
self.policy.validate()?;

if let Some(rate_limit) = &self.rate_limit {
rate_limit.validate()?;
}

Ok(())
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ pub mod config;
pub mod observability;
pub mod policy;
pub mod proxy;
pub mod rate_limit;
mod registry;
pub mod tls;
10 changes: 9 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::path::PathBuf;
use std::sync::Arc;

use agent_gateway::proxy::MakeProxyService;
use agent_gateway::rate_limit::{RateLimitConfig, RateLimiter};
use agent_gateway::{config, observability, policy, proxy, tls};
use anyhow::Context;
use clap::Parser;
Expand Down Expand Up @@ -37,7 +38,14 @@ async fn serve(config: config::Config) -> anyhow::Result<()> {
let tls_acceptor = tls::TlsAcceptor::from(server_tls);

let policy_engine = policy::build_engine(&config.policy).await?;
let make_service = Arc::new(MakeProxyService::new(policy_engine));
let rate_limit_config = config
.rate_limit
.as_ref()
.map_or_else(RateLimitConfig::disabled, |section| {
section.to_runtime_config()
});
let rate_limiter = Arc::new(RateLimiter::new(rate_limit_config));
let make_service = Arc::new(MakeProxyService::new(policy_engine, rate_limiter));

let listen_addr: std::net::SocketAddr = config.server.listen_addr.parse()?;
let listener = TcpListener::bind(listen_addr).await?;
Expand Down
180 changes: 166 additions & 14 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ use opentelemetry::global;
use opentelemetry::propagation::Extractor;
use rustls::ServerConnection;
use rustls_pki_types::CertificateDer;
use tokio::io::copy_bidirectional;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{Instrument, error, info, warn};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::policy::{self, PolicyDecision, PolicyEngine, RequestContext};
use crate::rate_limit::{RateLimitDecision, RateLimiter};

type ProxyBody = BoxBody<Bytes, Infallible>;

Expand All @@ -43,18 +44,21 @@ fn extract_trace_context(headers: &HeaderMap) -> opentelemetry::Context {
#[derive(Clone)]
pub struct ProxyService {
policy_engine: Arc<dyn PolicyEngine>,
rate_limiter: Arc<RateLimiter>,
peer_certs: Vec<CertificateDer<'static>>,
source_peer_addr: SocketAddr,
}

impl ProxyService {
fn new(
policy_engine: Arc<dyn PolicyEngine>,
rate_limiter: Arc<RateLimiter>,
peer_certs: Vec<CertificateDer<'static>>,
source_peer_addr: SocketAddr,
) -> Self {
Self {
policy_engine,
rate_limiter,
peer_certs,
source_peer_addr,
}
Expand Down Expand Up @@ -132,6 +136,7 @@ impl ProxyService {
source_identity,
self.source_peer_addr,
dest.authority,
self.rate_limiter.clone(),
);

response(StatusCode::OK, "")
Expand All @@ -154,11 +159,15 @@ impl Service<Request<Incoming>> for ProxyService {

pub struct MakeProxyService {
policy_engine: Arc<dyn PolicyEngine>,
rate_limiter: Arc<RateLimiter>,
}

impl MakeProxyService {
pub fn new(policy_engine: Arc<dyn PolicyEngine>) -> Self {
Self { policy_engine }
pub fn new(policy_engine: Arc<dyn PolicyEngine>, rate_limiter: Arc<RateLimiter>) -> Self {
Self {
policy_engine,
rate_limiter,
}
}

#[must_use]
Expand All @@ -167,7 +176,12 @@ impl MakeProxyService {
peer_certs: Vec<CertificateDer<'static>>,
source_peer_addr: SocketAddr,
) -> ProxyService {
ProxyService::new(self.policy_engine.clone(), peer_certs, source_peer_addr)
ProxyService::new(
self.policy_engine.clone(),
self.rate_limiter.clone(),
peer_certs,
source_peer_addr,
)
}
}

Expand Down Expand Up @@ -210,6 +224,7 @@ fn spawn_tunnel(
source_identity: String,
source_peer_addr: SocketAddr,
dest_authority: String,
rate_limiter: Arc<RateLimiter>,
) {
let tunnel_span = tracing::Span::current();
tokio::spawn(
Expand All @@ -230,16 +245,38 @@ fn spawn_tunnel(

let mut downstream = hyper_util::rt::TokioIo::new(upgraded);

match copy_bidirectional(&mut downstream, &mut upstream).await {
Ok((up, down)) => {
info!(
source_identity = %source_identity,
source_peer_addr = %source_peer_addr,
dest_authority = %dest_authority,
bytes_client_to_dest = up,
bytes_dest_to_client = down,
"tunnel closed"
);
match copy_bidirectional_with_rate_limit(
&mut downstream,
&mut upstream,
&source_identity,
&rate_limiter,
)
.await
{
Ok(TunnelStats {
bytes_client_to_dest,
bytes_dest_to_client,
limited,
}) => {
if limited {
warn!(
source_identity = %source_identity,
source_peer_addr = %source_peer_addr,
dest_authority = %dest_authority,
bytes_client_to_dest = bytes_client_to_dest,
bytes_dest_to_client = bytes_dest_to_client,
"tunnel closed after rate limit"
);
} else {
info!(
source_identity = %source_identity,
source_peer_addr = %source_peer_addr,
dest_authority = %dest_authority,
bytes_client_to_dest = bytes_client_to_dest,
bytes_dest_to_client = bytes_dest_to_client,
"tunnel closed"
);
}
}
Err(e) => {
error!(
Expand All @@ -256,6 +293,121 @@ fn spawn_tunnel(
);
}

struct TunnelStats {
bytes_client_to_dest: u64,
bytes_dest_to_client: u64,
limited: bool,
}

async fn copy_bidirectional_with_rate_limit<D, U>(
downstream: &mut D,
upstream: &mut U,
source_identity: &str,
rate_limiter: &RateLimiter,
) -> std::io::Result<TunnelStats>
where
D: AsyncRead + AsyncWrite + Unpin,
U: AsyncRead + AsyncWrite + Unpin,
{
let mut downstream_buf = vec![0_u8; 16 * 1024];
let mut upstream_buf = vec![0_u8; 16 * 1024];

let mut bytes_client_to_dest = 0_u64;
let mut bytes_dest_to_client = 0_u64;

loop {
tokio::select! {
read_result = downstream.read(&mut downstream_buf) => {
let n = read_result?;

if n == 0 {
upstream.shutdown().await?;
return Ok(TunnelStats {
bytes_client_to_dest,
bytes_dest_to_client,
limited: false,
});
}

match rate_limiter.check_and_record(source_identity, n as u64) {
RateLimitDecision::Allowed => {
upstream.write_all(&downstream_buf[..n]).await?;
bytes_client_to_dest += n as u64;
}
RateLimitDecision::Limited {
bytes_used,
attempted_bytes,
max_bytes,
window_secs,
} => {
warn!(
source_identity = %source_identity,
bytes_used = bytes_used,
attempted_bytes = attempted_bytes,
max_bytes = max_bytes,
window_secs = window_secs,
"rate limit exceeded while proxying client-to-destination data"
);

let _ = upstream.shutdown().await;
let _ = downstream.shutdown().await;

return Ok(TunnelStats {
bytes_client_to_dest,
bytes_dest_to_client,
limited: true,
});
}
}
}

read_result = upstream.read(&mut upstream_buf) => {
let n = read_result?;

if n == 0 {
downstream.shutdown().await?;
return Ok(TunnelStats {
bytes_client_to_dest,
bytes_dest_to_client,
limited: false,
});
}

match rate_limiter.check_and_record(source_identity, n as u64) {
RateLimitDecision::Allowed => {
downstream.write_all(&upstream_buf[..n]).await?;
bytes_dest_to_client += n as u64;
}
RateLimitDecision::Limited {
bytes_used,
attempted_bytes,
max_bytes,
window_secs,
} => {
warn!(
source_identity = %source_identity,
bytes_used = bytes_used,
attempted_bytes = attempted_bytes,
max_bytes = max_bytes,
window_secs = window_secs,
"rate limit exceeded while proxying destination-to-client data"
);

let _ = upstream.shutdown().await;
let _ = downstream.shutdown().await;

return Ok(TunnelStats {
bytes_client_to_dest,
bytes_dest_to_client,
limited: true,
});
}
}
}
}
}
}

fn response(status: StatusCode, message: &str) -> Response<ProxyBody> {
let body: ProxyBody = if message.is_empty() {
Empty::<Bytes>::new().boxed()
Expand Down
Loading