diff --git a/.github/workflows/fixture-freshness.yml b/.github/workflows/fixture-freshness.yml index 3bc0b9e7..ceb642a9 100644 --- a/.github/workflows/fixture-freshness.yml +++ b/.github/workflows/fixture-freshness.yml @@ -54,7 +54,7 @@ jobs: # Cache only the standalone tool's target dir; the parent # workspace's cache is owned by other workflows. - - uses: Swatinem/rust-cache@v2.7.3 + - uses: Swatinem/rust-cache@v2 with: shared-key: fixture-freshness workspaces: test/fixtures/refresh-tool diff --git a/.github/workflows/licensing-conformance.yml b/.github/workflows/licensing-conformance.yml index 9b5257b3..df526b9d 100644 --- a/.github/workflows/licensing-conformance.yml +++ b/.github/workflows/licensing-conformance.yml @@ -176,7 +176,7 @@ jobs: - name: open PR with refreshed vendored copies if: steps.diff.outputs.changed == '1' - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: token: ${{ secrets.GITHUB_TOKEN }} branch: chore/licensing-schemas-refresh diff --git a/.github/workflows/perf-regression.yml b/.github/workflows/perf-regression.yml index 6bb72784..c59129db 100644 --- a/.github/workflows/perf-regression.yml +++ b/.github/workflows/perf-regression.yml @@ -93,13 +93,13 @@ jobs: # would let the second build trample the first's lockfile state # and skew the bench. - name: cache (candidate) - uses: Swatinem/rust-cache@v2.7.3 + uses: Swatinem/rust-cache@v2 with: shared-key: perf-regression-candidate workspaces: candidate-src - name: cache (baseline) - uses: Swatinem/rust-cache@v2.7.3 + uses: Swatinem/rust-cache@v2 with: shared-key: perf-regression-baseline workspaces: baseline-src diff --git a/.github/workflows/synthetic.yml b/.github/workflows/synthetic.yml index da1e97f7..8899447b 100644 --- a/.github/workflows/synthetic.yml +++ b/.github/workflows/synthetic.yml @@ -54,7 +54,7 @@ jobs: # `bench-synthetic` is its own workspace, so it gets a dedicated # cache key. That keeps the proxy workspace cache from being # invalidated by churn on the synthetic harness. - - uses: Swatinem/rust-cache@v2.7.3 + - uses: Swatinem/rust-cache@v2 with: shared-key: synthetic workspaces: bench-synthetic @@ -189,7 +189,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2.7.3 + - uses: Swatinem/rust-cache@v2 with: shared-key: chaos-hot-reload diff --git a/Cargo.lock b/Cargo.lock index f9cb5218..261b6977 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4581,9 +4581,11 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6e05acbfada5ec79023c85368af14abd0b307c015e9064d249b2a950ef459a6" dependencies = [ + "hex", "opentelemetry", "opentelemetry_sdk", "prost 0.13.5", + "serde", "tonic", ] @@ -6255,6 +6257,7 @@ dependencies = [ "regex", "reqwest", "sbproxy-cache", + "sbproxy-classifier-client", "sbproxy-config", "sbproxy-platform", "sbproxy-plugin", @@ -6467,11 +6470,13 @@ dependencies = [ "hmac", "http 1.4.0", "jsonwebtoken", + "opentelemetry-proto", "prost 0.13.5", "rand 0.8.6", "rcgen", "reqwest", "rustls", + "sbproxy-ai", "sbproxy-cache", "sbproxy-config", "sbproxy-middleware", @@ -6482,6 +6487,7 @@ dependencies = [ "sha2 0.11.0", "tempfile", "tokio", + "tokio-stream", "tokio-tungstenite", "tonic", "tonic-build", diff --git a/crates/sbproxy-ai/Cargo.toml b/crates/sbproxy-ai/Cargo.toml index 9951d300..213086d4 100644 --- a/crates/sbproxy-ai/Cargo.toml +++ b/crates/sbproxy-ai/Cargo.toml @@ -11,6 +11,10 @@ sbproxy-config.workspace = true sbproxy-cache.workspace = true sbproxy-security.workspace = true sbproxy-platform.workspace = true +# WOR-1223: local embedding source for the semantic cache. Only the +# sidecar client is used here; the in-process embedder lives in sbproxy-core +# because sbproxy-classifiers depends on sbproxy-ai (would be a cycle). +sbproxy-classifier-client.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/sbproxy-ai/src/handler.rs b/crates/sbproxy-ai/src/handler.rs index 7d84b7f1..af2b1c31 100644 --- a/crates/sbproxy-ai/src/handler.rs +++ b/crates/sbproxy-ai/src/handler.rs @@ -71,6 +71,15 @@ pub struct AiHandlerConfig { /// See `sbproxy_security::pii::PiiConfig` for the rule schema. #[serde(default)] pub pii: Option, + /// WOR-1228: when `true`, emit the prompt text as the OpenInference + /// `input.value` span attribute so trace backends (Phoenix, Langfuse) + /// show the actual conversation, not just token counts. Off by default + /// because prompt content is sensitive: when on, the text is routed + /// through the configured `pii` redactor (if any) and the always-on + /// secret redactor before it lands on the span. Enable only with `pii` + /// configured and a trace backend inside your trust boundary. + #[serde(default)] + pub trace_content: bool, /// Opaque semantic-cache configuration block. The OSS proxy /// stores this verbatim and surfaces it through the stream cache /// recorder hook so the enterprise implementation can read its @@ -857,6 +866,7 @@ mod tests { resilience: None, shadow: None, pii: None, + trace_content: false, semantic_cache: None, prompts: None, usage_parser: "auto".to_string(), @@ -885,6 +895,7 @@ mod tests { resilience: None, shadow: None, pii: None, + trace_content: false, semantic_cache: None, prompts: None, usage_parser: "auto".to_string(), @@ -913,6 +924,7 @@ mod tests { resilience: None, shadow: None, pii: None, + trace_content: false, semantic_cache: None, prompts: None, usage_parser: "auto".to_string(), @@ -942,6 +954,7 @@ mod tests { resilience: None, shadow: None, pii: None, + trace_content: false, semantic_cache: None, prompts: None, usage_parser: "auto".to_string(), diff --git a/crates/sbproxy-ai/src/semantic_cache.rs b/crates/sbproxy-ai/src/semantic_cache.rs index a5fbf7de..5859e808 100644 --- a/crates/sbproxy-ai/src/semantic_cache.rs +++ b/crates/sbproxy-ai/src/semantic_cache.rs @@ -110,9 +110,35 @@ pub struct EmbeddingCacheConfig { /// Maximum cached entries (LRU eviction). Defaults to 1024. #[serde(default = "default_max_entries")] pub max_entries: usize, - /// Embedding provider + model used to vectorize prompts. + /// Where prompt embeddings come from. Defaults to `provider` so + /// existing configs are unchanged. + #[serde(default)] + pub source: EmbeddingSource, + /// Embedding provider + model used to vectorize prompts (for + /// `source: provider`). #[serde(default)] pub embedding: Option, + /// Local classifier-sidecar embedding endpoint (for `source: sidecar`). + #[serde(default)] + pub sidecar: Option, + /// In-process embedder (for `source: inprocess`). + #[serde(default)] + pub inprocess: Option, +} + +/// Where the semantic cache gets prompt embeddings. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum EmbeddingSource { + /// Call an AI embedding provider's `/v1/embeddings` API (default, + /// back-compat). Costs money and egresses the prompt. + #[default] + Provider, + /// Call the local classifier sidecar's `Embed` RPC. Free, no egress. + Sidecar, + /// Run an in-process tract embedder. Single-binary, but loads a model + /// into the proxy address space (opt-in). + Inprocess, } /// Which provider + model computes prompt embeddings. @@ -124,6 +150,44 @@ pub struct EmbeddingProviderConfig { pub model: String, } +/// Sidecar embedding endpoint config (for `source: sidecar`). +#[derive(Debug, Clone, serde::Deserialize)] +pub struct SidecarEmbeddingConfig { + /// gRPC endpoint, e.g. `http://127.0.0.1:9440`. + pub endpoint: String, + /// Embedding model id (empty selects the sidecar default). + #[serde(default)] + pub model: String, + /// Per-call timeout in milliseconds. Defaults to 500. + #[serde(default = "default_sidecar_timeout_ms")] + pub timeout_ms: u64, +} + +fn default_sidecar_timeout_ms() -> u64 { + 500 +} + +/// In-process embedder config (for `source: inprocess`). +/// +/// Provide explicit `model_path` + `tokenizer_path`. Known-model +/// auto-download for the in-process embedder is a follow-up; until then +/// the operator points at on-disk ONNX + tokenizer files. +#[derive(Debug, Clone, serde::Deserialize)] +pub struct InprocessEmbeddingConfig { + /// Logical model id (informational; e.g. `all-MiniLM-L6-v2`). + #[serde(default)] + pub model: String, + /// Path to the ONNX model file. + #[serde(default)] + pub model_path: Option, + /// Path to the tokenizer.json file. + #[serde(default)] + pub tokenizer_path: Option, + /// Max model size in bytes (guard). None uses the engine default. + #[serde(default)] + pub max_model_bytes: Option, +} + fn default_threshold() -> f32 { 0.85 } @@ -170,8 +234,16 @@ struct EmbeddingEntry { pub struct EmbeddingCache { threshold: f32, ttl_secs: u64, + /// Where prompt embeddings come from. + source: EmbeddingSource, + /// Embedding provider name (for `source: provider`; empty otherwise). provider: String, + /// Embedding model id (for `source: provider`; empty otherwise). model: String, + /// Sidecar endpoint config (for `source: sidecar`). + sidecar: Option, + /// In-process embedder config (for `source: inprocess`). + inprocess: Option, entries: Mutex>, } @@ -193,22 +265,57 @@ impl EmbeddingCache { if !cfg.enabled { return None; } - let embedding = cfg.embedding.as_ref()?; + // Each source needs its own config block to be usable. A missing + // block means there is nothing to vectorize with, so the cache + // stays inert (None) rather than half-built. + let (provider, model, sidecar, inprocess) = match cfg.source { + EmbeddingSource::Provider => { + let e = cfg.embedding.as_ref()?; + (e.provider.clone(), e.model.clone(), None, None) + } + EmbeddingSource::Sidecar => { + let s = cfg.sidecar.as_ref()?; + (String::new(), s.model.clone(), Some(s.clone()), None) + } + EmbeddingSource::Inprocess => { + // The embedder is built and held by sbproxy-core (which can + // depend on the tract engine without a dependency cycle). The + // cache carries the config so core can load it. Require the + // block so a typo'd config does not silently fall back. + let p = cfg.inprocess.as_ref()?; + (String::new(), p.model.clone(), None, Some(p.clone())) + } + }; let cap = NonZeroUsize::new(cfg.max_entries.max(1)).expect("max_entries clamped to >= 1"); Some(Self { threshold: cfg.threshold, ttl_secs: cfg.ttl_secs, - provider: embedding.provider.clone(), - model: embedding.model.clone(), + source: cfg.source, + provider, + model, + sidecar, + inprocess, entries: Mutex::new(LruCache::new(cap)), }) } - /// Embedding provider name to vectorize prompts with. + /// Where this cache gets prompt embeddings. + pub fn source(&self) -> EmbeddingSource { + self.source + } + /// Sidecar endpoint config, when `source` is `sidecar`. + pub fn sidecar_config(&self) -> Option<&SidecarEmbeddingConfig> { + self.sidecar.as_ref() + } + /// In-process embedder config, when `source` is `inprocess`. + pub fn inprocess_config(&self) -> Option<&InprocessEmbeddingConfig> { + self.inprocess.as_ref() + } + /// Embedding provider name to vectorize prompts with (provider source). pub fn provider(&self) -> &str { &self.provider } - /// Embedding model id. + /// Embedding model id (provider source). pub fn model(&self) -> &str { &self.model } @@ -318,6 +425,26 @@ impl EmbeddingCache { } } +/// Compute an embedding via the local classifier sidecar's `Embed` RPC. +/// +/// Used when `source: sidecar`. No provider API call, no prompt egress. +pub async fn compute_embedding_sidecar( + cfg: &SidecarEmbeddingConfig, + text: &str, +) -> anyhow::Result> { + let client = sbproxy_classifier_client::ClassifierClient::connect_lazy( + &cfg.endpoint, + std::time::Duration::from_millis(cfg.timeout_ms), + ) + .map_err(|e| anyhow::anyhow!("sidecar connect: {e}"))?; + let mut out = client + .embed(&cfg.model, &[text.to_string()]) + .await + .map_err(|e| anyhow::anyhow!("sidecar embed: {e}"))?; + out.pop() + .ok_or_else(|| anyhow::anyhow!("sidecar returned no embedding")) +} + /// Compute an embedding vector for `text` by POSTing `/v1/embeddings` /// to `provider` with `model` (WOR-796). Used by the dispatcher to /// vectorize a prompt for the semantic-cache lookup. Returns the first @@ -437,10 +564,13 @@ mod tests { threshold, ttl_secs: ttl, max_entries: max, + source: EmbeddingSource::Provider, embedding: Some(EmbeddingProviderConfig { provider: "openai".to_string(), model: "text-embedding-3-small".to_string(), }), + sidecar: None, + inprocess: None, }) .expect("enabled config builds") } @@ -460,10 +590,13 @@ mod tests { threshold: 0.85, ttl_secs: 60, max_entries: 8, + source: EmbeddingSource::Provider, embedding: Some(EmbeddingProviderConfig { provider: "openai".to_string(), model: "m".to_string(), }), + sidecar: None, + inprocess: None, }; assert!(EmbeddingCache::from_config(&cfg).is_none()); } @@ -475,7 +608,10 @@ mod tests { threshold: 0.85, ttl_secs: 60, max_entries: 8, + source: EmbeddingSource::Provider, embedding: None, + sidecar: None, + inprocess: None, }; assert!(EmbeddingCache::from_config(&cfg).is_none()); } @@ -590,6 +726,46 @@ mod tests { let cache = EmbeddingCache::from_config(&cfg).unwrap(); assert_eq!(cache.provider(), "openai"); assert_eq!(cache.model(), "text-embedding-3-small"); + // Default source is provider so existing configs are unchanged. + assert_eq!(cache.source(), EmbeddingSource::Provider); + } + + #[test] + fn source_defaults_to_provider() { + let cfg: EmbeddingCacheConfig = serde_json::from_value(serde_json::json!({ + "enabled": true, + "embedding": { "provider": "openai", "model": "text-embedding-3-small" } + })) + .unwrap(); + assert_eq!(cfg.source, EmbeddingSource::Provider); + } + + #[test] + fn sidecar_source_parses_and_builds() { + let cfg: EmbeddingCacheConfig = serde_json::from_value(serde_json::json!({ + "enabled": true, + "source": "sidecar", + "sidecar": { "endpoint": "http://127.0.0.1:9440", "model": "all-MiniLM-L6-v2", "timeout_ms": 750 } + })) + .unwrap(); + assert_eq!(cfg.source, EmbeddingSource::Sidecar); + let cache = + EmbeddingCache::from_config(&cfg).expect("sidecar cache builds without a provider"); + assert_eq!(cache.source(), EmbeddingSource::Sidecar); + let sc = cache.sidecar_config().expect("sidecar config present"); + assert_eq!(sc.endpoint, "http://127.0.0.1:9440"); + assert_eq!(sc.timeout_ms, 750); + } + + #[test] + fn sidecar_source_without_block_is_inert() { + let cfg: EmbeddingCacheConfig = serde_json::from_value(serde_json::json!({ + "enabled": true, + "source": "sidecar" + })) + .unwrap(); + // No sidecar block: nothing to vectorize with, so the cache stays inert. + assert!(EmbeddingCache::from_config(&cfg).is_none()); } #[test] diff --git a/crates/sbproxy-ai/src/tracing_spans.rs b/crates/sbproxy-ai/src/tracing_spans.rs index 4adee112..3e40a39f 100644 --- a/crates/sbproxy-ai/src/tracing_spans.rs +++ b/crates/sbproxy-ai/src/tracing_spans.rs @@ -115,6 +115,26 @@ pub fn ai_request_span(surface: &str, method: &str) -> Span { // original token snapshot; without it, the spend record // is not reproducible past a pricing-table edit. "sbproxy.ai.pricing_version" = Empty, + // WOR-1229: derived USD cost for the request, so trace backends + // (Phoenix, Langfuse, Tempo) show spend per generation alongside + // tokens. Recorded at the billing choke point via + // `record_cost_usd`. Both the OpenInference and gen_ai keys are + // stamped so either backend vocabulary renders it. + "gen_ai.usage.cost" = Empty, + "llm.usage.total_cost" = Empty, + // WOR-1231: error semantics. `otel.status_code` is the field the + // tracing-opentelemetry bridge maps to the OTel span status, so a + // failed generation surfaces as an ERROR span in trace backends. + // `error.type` carries the failure class (gen_ai / OTel convention). + "otel.status_code" = Empty, + "otel.status_message" = Empty, + "error.type" = Empty, + // WOR-1228: OpenInference prompt / completion content. Empty unless + // the origin sets `trace_content: true`; the dispatch path redacts + // the text (secrets + PII) before recording it here, so backends can + // show the conversation, not just token counts. + "input.value" = Empty, + "output.value" = Empty, "gen_ai.response.finish_reasons" = Empty, "llm.provider" = Empty, "llm.model_name" = Empty, @@ -254,6 +274,60 @@ pub fn record_pricing_version(span: &Span, version: &str) { span.record("sbproxy.ai.pricing_version", version); } +/// Stamp the derived USD cost of the request onto an AI span (WOR-1229). +/// +/// Records `gen_ai.usage.cost` (gen_ai vocabulary) and +/// `llm.usage.total_cost` (OpenInference vocabulary) so both trace-backend +/// conventions render spend per generation. `cost_usd` is the same value +/// the FinOps cost metric uses, derived from the token counts and the +/// pricing catalog stamped via [`record_pricing_version`]. +pub fn record_cost_usd(span: &Span, cost_usd: f64) { + span.record("gen_ai.usage.cost", cost_usd); + span.record("llm.usage.total_cost", cost_usd); +} + +/// Well-known failure classes for an AI generation, recorded as the OTel +/// `error.type` attribute (WOR-1231). Kept as string constants so call +/// sites and trace queries agree. +pub mod error_type { + /// A guardrail (input or output) blocked the request. + pub const GUARDRAIL_BLOCKED: &str = "guardrail_blocked"; + /// The provider returned HTTP 429 (rate limited). + pub const RATE_LIMITED: &str = "rate_limited"; + /// The provider returned a 5xx server error. + pub const PROVIDER_ERROR: &str = "provider_error"; + /// The provider's content filter rejected the request or response. + pub const CONTENT_FILTER: &str = "content_filter"; +} + +/// Mark an AI span as failed (WOR-1231). +/// +/// Sets `otel.status_code = "ERROR"` (which the tracing-opentelemetry bridge +/// maps to the OTel span status, so the span shows as an error in trace +/// backends), records the failure class as `error.type`, and stores a short +/// human-readable message. Use the [`error_type`] constants for `kind`. +pub fn record_error(span: &Span, kind: &str, message: &str) { + span.record("otel.status_code", "ERROR"); + span.record("error.type", kind); + span.record("otel.status_message", message); +} + +/// Record the prompt text as the OpenInference `input.value` span attribute +/// (WOR-1228). The caller MUST have already redacted the content and gated +/// on the origin's `trace_content` flag: this helper only writes the field, +/// it does not redact. Off-by-default content capture lives in the dispatch +/// path, which routes the text through the secret + PII redactors first. +pub fn record_input_content(span: &Span, redacted: &str) { + span.record("input.value", redacted); +} + +/// Record the completion text as the OpenInference `output.value` span +/// attribute (WOR-1228). Same contract as [`record_input_content`]: the +/// caller redacts and gates; this only writes the field. +pub fn record_output_content(span: &Span, redacted: &str) { + span.record("output.value", redacted); +} + /// Stamp the response model and identifier onto an AI span. /// /// `model` becomes `gen_ai.response.model`; `response_id` becomes @@ -572,6 +646,102 @@ mod tests { assert_field(span, "sbproxy.ai.pricing_version", "catalog-2026-06-01"); } + /// WOR-1229: derived USD cost lands on both the gen_ai and + /// OpenInference cost keys so either trace backend renders spend. + #[test] + fn record_cost_usd_stamps_both_vocabularies() { + use tracing_subscriber::prelude::*; + let layer = CaptureLayer::default(); + let subscriber = tracing_subscriber::registry().with(layer.clone()); + tracing::subscriber::with_default(subscriber, || { + let span = ai_request_span("chat", "POST"); + record_cost_usd(&span, 0.001234); + }); + let spans = snapshot_spans(&layer); + let span = find_span(&spans, "ai.request"); + assert_field(span, "gen_ai.usage.cost", "0.001234"); + assert_field(span, "llm.usage.total_cost", "0.001234"); + } + + /// WOR-1231: a failed generation marks the span ERROR with an + /// `error.type` so trace backends surface it as a failure. + #[test] + fn record_error_marks_span_failed() { + use tracing_subscriber::prelude::*; + let layer = CaptureLayer::default(); + let subscriber = tracing_subscriber::registry().with(layer.clone()); + tracing::subscriber::with_default(subscriber, || { + let span = ai_request_span("chat", "POST"); + record_error( + &span, + error_type::GUARDRAIL_BLOCKED, + "blocked by input guardrail", + ); + }); + let spans = snapshot_spans(&layer); + let span = find_span(&spans, "ai.request"); + assert_field(span, "otel.status_code", "ERROR"); + assert_field(span, "error.type", "guardrail_blocked"); + } + + /// WOR-1228: prompt / completion content lands on the OpenInference + /// `input.value` / `output.value` span attributes (already redacted by + /// the caller). + #[test] + fn record_content_stamps_input_and_output_values() { + use tracing_subscriber::prelude::*; + let layer = CaptureLayer::default(); + let subscriber = tracing_subscriber::registry().with(layer.clone()); + tracing::subscriber::with_default(subscriber, || { + let span = ai_request_span("chat", "POST"); + record_input_content(&span, "summarize this [redacted]"); + record_output_content(&span, "here is the summary"); + }); + let spans = snapshot_spans(&layer); + let span = find_span(&spans, "ai.request"); + assert_field(span, "input.value", "summarize this [redacted]"); + assert_field(span, "output.value", "here is the summary"); + } + + /// WOR-1232: GenAI semantic-convention conformance. Pin the version and + /// the required `gen_ai.*` attribute set so a span never silently drifts + /// off-spec. Recording into a field that `ai_request_span` does not + /// declare is a no-op, so dropping a required attribute fails this test. + #[test] + fn ai_request_span_conforms_to_pinned_genai_semconv() { + // Bump deliberately when re-validating against a newer semconv. + const GEN_AI_SEMCONV_VERSION: &str = "1.36.0"; + const REQUIRED_GEN_AI_FIELDS: &[&str] = &[ + "gen_ai.system", + "gen_ai.request.model", + "gen_ai.response.model", + "gen_ai.response.id", + "gen_ai.usage.input_tokens", + "gen_ai.usage.output_tokens", + "gen_ai.usage.cost", + "gen_ai.response.finish_reasons", + ]; + assert!( + !GEN_AI_SEMCONV_VERSION.is_empty(), + "semconv version must be pinned" + ); + + use tracing_subscriber::prelude::*; + let layer = CaptureLayer::default(); + let subscriber = tracing_subscriber::registry().with(layer.clone()); + tracing::subscriber::with_default(subscriber, || { + let span = ai_request_span("chat", "POST"); + for field in REQUIRED_GEN_AI_FIELDS { + span.record(*field, "conformance-probe"); + } + }); + let spans = snapshot_spans(&layer); + let span = find_span(&spans, "ai.request"); + for field in REQUIRED_GEN_AI_FIELDS { + assert_field(span, field, "conformance-probe"); + } + } + /// `UsageTokens::total()` (WOR-1084) sums every dimension, /// not just prompt + completion. Pinned so a downstream /// dashboard's "total tokens" math stays consistent. diff --git a/crates/sbproxy-classifier-client/src/lib.rs b/crates/sbproxy-classifier-client/src/lib.rs index 2b412b1d..ab77aafa 100644 --- a/crates/sbproxy-classifier-client/src/lib.rs +++ b/crates/sbproxy-classifier-client/src/lib.rs @@ -21,7 +21,8 @@ use std::path::{Path, PathBuf}; use std::time::Duration; use sbproxy_classifier_proto::{ - ClassifyRequest, ClassifyResponse, InferenceServiceClient, VersionRequest, VersionResponse, + ClassifyRequest, ClassifyResponse, EmbedRequest, InferenceServiceClient, VersionRequest, + VersionResponse, }; use tokio::net::UnixStream; use tonic::transport::{Channel, Endpoint}; @@ -193,6 +194,33 @@ impl ClassifierClient { } } + /// Embed `inputs` with the named model (empty = the sidecar's default). + /// + /// Returns one L2-normalized vector per input, in request order. Used by + /// the AI gateway semantic cache to vectorize prompts locally instead of + /// via a paid embedding-provider API. + pub async fn embed( + &self, + model: &str, + inputs: &[String], + ) -> Result>, ClassifierClientError> { + let request = EmbedRequest { + model: model.to_string(), + texts: inputs.to_vec(), + }; + let mut client = self.inner.clone(); + match tokio::time::timeout(self.timeout, client.embed(request)).await { + Ok(Ok(resp)) => Ok(resp + .into_inner() + .embeddings + .into_iter() + .map(|e| e.values) + .collect()), + Ok(Err(status)) => Err(ClassifierClientError::Rpc(status.to_string())), + Err(_) => Err(ClassifierClientError::Timeout(self.timeout)), + } + } + /// Probe the sidecar's version + served model ids (startup capability check). pub async fn version(&self) -> Result { let mut client = self.inner.clone(); @@ -208,7 +236,7 @@ impl ClassifierClient { mod tests { use super::*; use sbproxy_classifier_proto::{ - EmbedRequest, EmbedResponse, InferenceService, InferenceServiceServer, Label, + EmbedRequest, EmbedResponse, Embedding, InferenceService, InferenceServiceServer, Label, ModelInfoRequest, ModelInfoResponse, }; use tonic::{Request, Response, Status}; @@ -234,9 +262,19 @@ mod tests { } async fn embed( &self, - _req: Request, + req: Request, ) -> Result, Status> { - Err(Status::unimplemented("stub")) + // Echo one fixed 2-dim vector per input so the client mapping is + // exercised without a real ONNX model. + let n = req.into_inner().texts.len(); + Ok(Response::new(EmbedResponse { + embeddings: (0..n) + .map(|_| Embedding { + values: vec![1.0, 0.0], + }) + .collect(), + latency_us: 1, + })) } async fn model_info( &self, @@ -294,6 +332,24 @@ mod tests { assert_eq!(version.models, vec!["stub".to_string()]); } + #[tokio::test] + async fn embed_round_trips_against_a_stub() { + let endpoint = spawn_stub().await; + let client = + ClassifierClient::connect(&endpoint, Duration::from_secs(2), Duration::from_secs(2)) + .await + .expect("connect"); + let vecs = client + .embed( + "all-MiniLM-L6-v2", + &["hello".to_string(), "world".to_string()], + ) + .await + .unwrap(); + assert_eq!(vecs.len(), 2); + assert_eq!(vecs[0], vec![1.0, 0.0]); + } + #[tokio::test] async fn connect_to_dead_endpoint_errors() { // Port 1 refuses immediately; connect must surface a Connect error diff --git a/crates/sbproxy-classifier-sidecar/src/main.rs b/crates/sbproxy-classifier-sidecar/src/main.rs index 2b7a3217..398f87f4 100644 --- a/crates/sbproxy-classifier-sidecar/src/main.rs +++ b/crates/sbproxy-classifier-sidecar/src/main.rs @@ -30,11 +30,11 @@ use std::sync::Arc; use anyhow::{Context, Result}; use clap::Parser; use sbproxy_classifier_proto::{ - ClassifyRequest, ClassifyResponse, EmbedRequest, EmbedResponse, InferenceService, + ClassifyRequest, ClassifyResponse, EmbedRequest, EmbedResponse, Embedding, InferenceService, InferenceServiceServer, Label, ModelInfoRequest, ModelInfoResponse, VersionRequest, VersionResponse, }; -use sbproxy_classifiers::OnnxClassifier; +use sbproxy_classifiers::{OnnxClassifier, OnnxEmbedder}; use tonic::transport::Server; use tonic::{Request, Response, Status}; @@ -42,8 +42,13 @@ use tonic::{Request, Response, Status}; /// tract ONNX classifiers keyed by logical model id. struct SidecarService { models: HashMap>, - /// Model used when a request leaves `model` empty. + /// Embedding models keyed by logical id, paired with the embedding + /// dimension learned at load time (for `ModelInfo`). + embedders: HashMap, u32)>, + /// Classifier used when a `Classify` request leaves `model` empty. default_model: Option, + /// Embedder used when an `Embed` request leaves `model` empty. + default_embed_model: Option, /// Reported by the `Version` RPC. version: String, } @@ -58,6 +63,17 @@ impl SidecarService { }; self.models.get(&id).map(|m| (id, Arc::clone(m))) } + + /// Resolve a request's `model` field (or the default) to a loaded + /// embedder. + fn resolve_embedder(&self, model: &str) -> Option<(String, Arc)> { + let id = if model.is_empty() { + self.default_embed_model.clone()? + } else { + model.to_string() + }; + self.embedders.get(&id).map(|(e, _)| (id, Arc::clone(e))) + } } #[tonic::async_trait] @@ -88,10 +104,34 @@ impl InferenceService for SidecarService { })) } - async fn embed(&self, _req: Request) -> Result, Status> { - Err(Status::unimplemented( - "embeddings are not supported by the minimal OSS classifier sidecar", - )) + async fn embed(&self, req: Request) -> Result, Status> { + let req = req.into_inner(); + let (_id, embedder) = self.resolve_embedder(&req.model).ok_or_else(|| { + Status::failed_precondition(format!( + "no embedding model loaded for {:?}; start the sidecar with --embed-model", + req.model + )) + })?; + let texts = req.texts; + let started = std::time::Instant::now(); + // tract inference is synchronous and CPU-bound: run it on the blocking + // pool so it never stalls a gRPC async worker. + let vectors = tokio::task::spawn_blocking(move || { + texts + .iter() + .map(|t| embedder.embed(t)) + .collect::>>() + }) + .await + .map_err(|e| Status::internal(format!("embed task panicked: {e}")))? + .map_err(|e| Status::internal(format!("embed failed: {e}")))?; + Ok(Response::new(EmbedResponse { + embeddings: vectors + .into_iter() + .map(|v| Embedding { values: v.values }) + .collect(), + latency_us: started.elapsed().as_micros() as u64, + })) } async fn model_info( @@ -99,19 +139,34 @@ impl InferenceService for SidecarService { req: Request, ) -> Result, Status> { let req = req.into_inner(); - let resp = match self.resolve(&req.model) { - Some((id, _)) => ModelInfoResponse { + // Classifiers first, then embedders (which report their dimension). + let resp = if let Some((id, _)) = self.resolve(&req.model) { + ModelInfoResponse { model: id, loaded: true, labels: Vec::new(), embedding_dim: 0, - }, - None => ModelInfoResponse { - model: req.model, - loaded: false, - labels: Vec::new(), - embedding_dim: 0, - }, + } + } else { + let embed_id = if req.model.is_empty() { + self.default_embed_model.clone() + } else { + Some(req.model.clone()) + }; + match embed_id.and_then(|id| self.embedders.get(&id).map(|(_, dim)| (id, *dim))) { + Some((id, dim)) => ModelInfoResponse { + model: id, + loaded: true, + labels: Vec::new(), + embedding_dim: dim, + }, + None => ModelInfoResponse { + model: req.model, + loaded: false, + labels: Vec::new(), + embedding_dim: 0, + }, + } }; Ok(Response::new(resp)) } @@ -154,6 +209,15 @@ struct Cli { /// single loaded model when exactly one is configured. #[arg(long)] default_model: Option, + /// Embedding model to load, as `id=:`. + /// Repeatable. Enables the `Embed` RPC (used by the AI gateway semantic + /// cache); without one, `Embed` returns FAILED_PRECONDITION. + #[arg(long = "embed-model", value_name = "ID=MODEL:TOKENIZER")] + embed_models: Vec, + /// Embedding model id used when an `Embed` request leaves `model` empty. + /// Defaults to the single loaded embedder when exactly one is configured. + #[arg(long)] + default_embed_model: Option, } /// Parse one `id=:` spec and load the classifier. @@ -169,6 +233,25 @@ fn load_model_spec(spec: &str) -> Result<(String, Arc)> { Ok((id.to_string(), Arc::new(classifier))) } +/// Parse one `id=:` spec and load the embedder, learning +/// its output dimension via a one-time warmup embed so `ModelInfo` can +/// report it. +fn load_embed_spec(spec: &str) -> Result<(String, Arc, u32)> { + let (id, paths) = spec + .split_once('=') + .with_context(|| format!("--embed-model must be ID=MODEL:TOKENIZER, got {spec:?}"))?; + let (model_path, tokenizer_path) = paths + .split_once(':') + .with_context(|| format!("--embed-model paths must be MODEL:TOKENIZER, got {paths:?}"))?; + let embedder = OnnxEmbedder::load(Path::new(model_path), Path::new(tokenizer_path)) + .with_context(|| format!("loading embed model {id:?}"))?; + let dim = embedder + .embed("dimension probe") + .map(|o| o.values.len() as u32) + .unwrap_or(0); + Ok((id.to_string(), Arc::new(embedder), dim)) +} + #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt() @@ -191,10 +274,26 @@ async fn main() -> Result<()> { } }); + let mut embedders = HashMap::new(); + for spec in &cli.embed_models { + let (id, embedder, dim) = load_embed_spec(spec)?; + embedders.insert(id, (embedder, dim)); + } + + let default_embed_model = cli.default_embed_model.or_else(|| { + if embedders.len() == 1 { + embedders.keys().next().cloned() + } else { + None + } + }); + let service = SidecarService { version: format!("sbproxy-classifier-sidecar {}", env!("CARGO_PKG_VERSION")), default_model, + default_embed_model, models, + embedders, }; if let Some(uds_path) = cli.listen_uds.as_ref() { @@ -246,7 +345,9 @@ mod tests { fn empty_service() -> SidecarService { SidecarService { models: HashMap::new(), + embedders: HashMap::new(), default_model: None, + default_embed_model: None, version: "sbproxy-classifier-sidecar test".to_string(), } } @@ -266,16 +367,35 @@ mod tests { } #[tokio::test] - async fn embed_is_unimplemented() { + async fn embed_without_model_is_failed_precondition() { let svc = empty_service(); let err = svc .embed(Request::new(EmbedRequest { model: String::new(), - texts: Vec::new(), + texts: vec!["hi".to_string()], + })) + .await + .expect_err("embed must fail when no embed model is loaded"); + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + } + + #[tokio::test] + async fn embed_unknown_model_is_failed_precondition() { + let svc = empty_service(); + let err = svc + .embed(Request::new(EmbedRequest { + model: "nope".to_string(), + texts: vec!["hi".to_string()], })) .await - .expect_err("embed must be unimplemented"); - assert_eq!(err.code(), tonic::Code::Unimplemented); + .expect_err("unknown embed model must fail"); + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + } + + #[test] + fn load_embed_spec_rejects_malformed() { + assert!(load_embed_spec("no-equals").is_err()); + assert!(load_embed_spec("id=only-one-path").is_err()); } #[tokio::test] diff --git a/crates/sbproxy-classifiers/src/embedder.rs b/crates/sbproxy-classifiers/src/embedder.rs new file mode 100644 index 00000000..0721a1bd --- /dev/null +++ b/crates/sbproxy-classifiers/src/embedder.rs @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: Apache-2.0 +//! ONNX sentence-embedding support for the tract engine. +//! +//! Mirrors [`crate::OnnxClassifier`]'s loader, size-budget, and cache +//! conventions but produces an L2-normalized mean-pooled sentence vector +//! instead of a class label. Used by the OSS classifier sidecar's `Embed` +//! RPC and by the in-process embedding option for the AI gateway semantic +//! cache. Runs on pure-Rust `tract`, so it cross-compiles and air-gaps the +//! same way the classifier does. + +use std::path::Path; + +use anyhow::{anyhow, Context, Result}; +use tokenizers::Tokenizer; +use tract_onnx::prelude::*; + +use crate::{check_size_budget, LoadOptions, RunnableOnnxModel}; + +/// Mean-pool a `[seq_len, dim]` hidden-state matrix (row-major, flat) over +/// tokens, weighting each token by its attention mask. Tokens with mask 0 +/// are excluded (they are padding). Returns a `dim`-length vector. If every +/// token is masked out, returns an all-zero vector of length `dim`. +pub(crate) fn mean_pool(hidden: &[f32], mask: &[i64], seq_len: usize, dim: usize) -> Vec { + let mut acc = vec![0.0f32; dim]; + let mut count = 0.0f32; + for t in 0..seq_len { + if mask.get(t).copied().unwrap_or(0) == 0 { + continue; + } + count += 1.0; + let base = t * dim; + for d in 0..dim { + acc[d] += hidden[base + d]; + } + } + if count > 0.0 { + for v in &mut acc { + *v /= count; + } + } + acc +} + +/// L2-normalize a vector in place. A zero vector is left unchanged so the +/// caller never divides by zero. +pub(crate) fn l2_normalize(v: &mut [f32]) { + let norm = v.iter().map(|x| x * x).sum::().sqrt(); + if norm > 0.0 { + for x in v.iter_mut() { + *x /= norm; + } + } +} + +/// Result of [`OnnxEmbedder::embed`]: an L2-normalized sentence vector. +#[derive(Debug, Clone)] +pub struct EmbeddingOutput { + /// L2-normalized embedding. Its length is the model's hidden size + /// (384 for `all-MiniLM-L6-v2`). + pub values: Vec, +} + +/// A loaded ONNX sentence-embedding model paired with its tokenizer. +/// +/// Construction is the slow path (parse + optimise the graph). `embed` +/// is cheap enough to call on the request hot path for short prompts. +pub struct OnnxEmbedder { + model: RunnableOnnxModel, + tokenizer: Tokenizer, +} + +impl OnnxEmbedder { + /// Load with default [`LoadOptions`] (200 MB budget, no signatures). + pub fn load(model_path: &Path, tokenizer_path: &Path) -> Result { + Self::load_with_options(model_path, tokenizer_path, &LoadOptions::default()) + } + + /// Load an embedding model + tokenizer, enforcing the size budget. + pub fn load_with_options( + model_path: &Path, + tokenizer_path: &Path, + options: &LoadOptions, + ) -> Result { + check_size_budget(model_path, "model", options.effective_model_limit())?; + check_size_budget( + tokenizer_path, + "tokenizer", + options.effective_tokenizer_limit(), + )?; + let tokenizer = Tokenizer::from_file(tokenizer_path) + .map_err(|e| anyhow!("failed to load tokenizer at {tokenizer_path:?}: {e}"))?; + let model = tract_onnx::onnx() + .model_for_path(model_path) + .with_context(|| format!("failed to parse ONNX model at {model_path:?}"))? + .into_optimized() + .context("failed to optimise ONNX model")? + .into_runnable() + .context("failed to make ONNX model runnable")?; + Ok(Self { model, tokenizer }) + } + + /// Embed one text into an L2-normalized vector. + /// + /// Tokenises `text`, runs the model, mean-pools the last hidden state + /// over tokens weighted by the attention mask, then L2-normalizes so + /// the dot product of two embeddings is their cosine similarity. + pub fn embed(&self, text: &str) -> Result { + let encoding = self + .tokenizer + .encode(text, true) + .map_err(|e| anyhow!("tokenizer encode failed: {e}"))?; + let ids: Vec = encoding.get_ids().iter().map(|i| *i as i64).collect(); + let mask: Vec = encoding + .get_attention_mask() + .iter() + .map(|m| *m as i64) + .collect(); + let seq_len = ids.len(); + if seq_len == 0 { + return Err(anyhow!("tokenizer produced empty encoding")); + } + + let input_ids = + tract_ndarray::Array2::from_shape_vec((1, seq_len), ids).map_err(|e| anyhow!(e))?; + let attention_mask = tract_ndarray::Array2::from_shape_vec((1, seq_len), mask.clone()) + .map_err(|e| anyhow!(e))?; + + // Route inputs by declared name, matching OnnxClassifier::classify so + // exports that take input_ids / attention_mask / token_type_ids all work. + let input_names: Vec = self + .model + .model() + .input_outlets()? + .iter() + .map(|outlet| self.model.model().node(outlet.node).name.clone()) + .collect(); + + let mut inputs: TVec = tvec!(); + for name in &input_names { + let lower = name.to_ascii_lowercase(); + if lower.contains("input_ids") || lower == "ids" { + inputs.push(input_ids.clone().into_tensor().into()); + } else if lower.contains("attention_mask") || lower.contains("mask") { + inputs.push(attention_mask.clone().into_tensor().into()); + } else if lower.contains("token_type_ids") { + let zeros: Vec = vec![0; seq_len]; + let token_type_ids = tract_ndarray::Array2::from_shape_vec((1, seq_len), zeros) + .map_err(|e| anyhow!(e))?; + inputs.push(token_type_ids.into_tensor().into()); + } else { + inputs.push(input_ids.clone().into_tensor().into()); + } + } + + let outputs = self + .model + .run(inputs) + .map_err(|e| anyhow!("ONNX inference failed: {e}"))?; + // Sentence-transformer exports put the token embeddings first: + // last_hidden_state with shape [1, seq_len, dim]. + let hidden = outputs + .into_iter() + .next() + .ok_or_else(|| anyhow!("ONNX model returned no outputs"))?; + let view = hidden + .to_array_view::() + .map_err(|e| anyhow!("output tensor was not f32: {e}"))?; + let shape = view.shape(); + if shape.len() != 3 || shape[0] != 1 { + return Err(anyhow!( + "expected [1, seq, dim] hidden state, got shape {shape:?}" + )); + } + let dim = shape[2]; + let flat: Vec = view.iter().copied().collect(); + let mut pooled = mean_pool(&flat, &mask, seq_len, dim); + l2_normalize(&mut pooled); + Ok(EmbeddingOutput { values: pooled }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mean_pool_averages_unmasked_tokens() { + // 2 tokens, dim 2. token0=[1,3] token1=[3,7], both unmasked. + let hidden = vec![1.0, 3.0, 3.0, 7.0]; + let mask = vec![1i64, 1]; + assert_eq!(mean_pool(&hidden, &mask, 2, 2), vec![2.0, 5.0]); + } + + #[test] + fn mean_pool_excludes_masked_tokens() { + let hidden = vec![1.0, 1.0, 9.0, 9.0]; // token1 is padding + let mask = vec![1i64, 0]; + assert_eq!(mean_pool(&hidden, &mask, 2, 2), vec![1.0, 1.0]); + } + + #[test] + fn mean_pool_all_masked_is_zero() { + let hidden = vec![5.0, 5.0]; + let mask = vec![0i64]; + assert_eq!(mean_pool(&hidden, &mask, 1, 2), vec![0.0, 0.0]); + } + + #[test] + fn l2_normalize_gives_unit_length() { + let mut v = vec![3.0, 4.0]; + l2_normalize(&mut v); + let norm = (v[0] * v[0] + v[1] * v[1]).sqrt(); + assert!((norm - 1.0).abs() < 1e-6); + assert!((v[0] - 0.6).abs() < 1e-6 && (v[1] - 0.8).abs() < 1e-6); + } + + #[test] + fn l2_normalize_zero_vector_unchanged() { + let mut v = vec![0.0, 0.0]; + l2_normalize(&mut v); + assert_eq!(v, vec![0.0, 0.0]); + } + + // Gated: needs a downloaded MiniLM model. Run locally with + // SBPROXY_TEST_EMBED_MODEL=/path/model.onnx + // SBPROXY_TEST_EMBED_TOKENIZER=/path/tokenizer.json + #[test] + fn embed_real_model_is_normalized_and_self_similar() { + let (Ok(m), Ok(t)) = ( + std::env::var("SBPROXY_TEST_EMBED_MODEL"), + std::env::var("SBPROXY_TEST_EMBED_TOKENIZER"), + ) else { + eprintln!("skipping: set SBPROXY_TEST_EMBED_MODEL/_TOKENIZER to run"); + return; + }; + let emb = OnnxEmbedder::load(Path::new(&m), Path::new(&t)).unwrap(); + let a = emb.embed("the cat sat on the mat").unwrap(); + let norm: f32 = a.values.iter().map(|x| x * x).sum::().sqrt(); + assert!((norm - 1.0).abs() < 1e-3, "vector must be L2-normalized"); + let b = emb.embed("the cat sat on the mat").unwrap(); + let cos: f32 = a.values.iter().zip(&b.values).map(|(x, y)| x * y).sum(); + assert!( + cos > 0.999, + "identical text cosine should be ~1.0, got {cos}" + ); + } +} diff --git a/crates/sbproxy-classifiers/src/known_models.rs b/crates/sbproxy-classifiers/src/known_models.rs index 30719d2a..1c799217 100644 --- a/crates/sbproxy-classifiers/src/known_models.rs +++ b/crates/sbproxy-classifiers/src/known_models.rs @@ -140,9 +140,37 @@ pub const PROMPT_INJECTION_V2_MODEL: KnownModel = KnownModel { revision_pinned_at: "2026-04-27", }; +/// Default sentence-embedding model for the AI gateway semantic cache and +/// the in-process / sidecar embedder. +/// +/// `all-MiniLM-L6-v2` is a 6-layer, 384-dim sentence-transformer under +/// Apache-2.0. It runs on the pure-Rust tract engine, is small enough to +/// cache and ship to air-gapped sites, and its quality is well-suited to +/// near-duplicate prompt detection (the semantic-cache use case). +/// +/// SHA pins are empty for the same reason as `prompt-injection-v2`: the +/// build sandbox cannot reach the upstream URL, so operators compute the +/// hash on first download. The `no_known_model_has_unpinned_sha256` test +/// stays `#[ignore]`'d until the pinning follow-up lands. +pub const ALL_MINILM_L6_V2_MODEL: KnownModel = KnownModel { + name: "all-MiniLM-L6-v2", + model_url: concat!( + "https://huggingface.co/sentence-transformers/", + "all-MiniLM-L6-v2/resolve/main/onnx/model.onnx" + ), + model_sha256: "", + tokenizer_url: concat!( + "https://huggingface.co/sentence-transformers/", + "all-MiniLM-L6-v2/resolve/main/tokenizer.json" + ), + tokenizer_sha256: "", + license: "Apache-2.0", + revision_pinned_at: "2026-06-08", +}; + /// Every entry the registry knows about. Add new pins here; tests /// assert that the array stays unique by `name`. -pub const KNOWN_MODELS: &[KnownModel] = &[PROMPT_INJECTION_V2_MODEL]; +pub const KNOWN_MODELS: &[KnownModel] = &[PROMPT_INJECTION_V2_MODEL, ALL_MINILM_L6_V2_MODEL]; static INDEX: OnceLock> = OnceLock::new(); diff --git a/crates/sbproxy-classifiers/src/lib.rs b/crates/sbproxy-classifiers/src/lib.rs index be806446..e7922e61 100644 --- a/crates/sbproxy-classifiers/src/lib.rs +++ b/crates/sbproxy-classifiers/src/lib.rs @@ -66,6 +66,9 @@ pub mod agent_classifier_types; pub mod judge_rpc; pub mod known_models; +mod embedder; +pub use embedder::{EmbeddingOutput, OnnxEmbedder}; + pub use agent_class::{ AgentClass, AgentClassCatalog, AgentId, AgentIdSource, AgentPurpose, DEFAULT_CATALOG_YAML, }; @@ -106,7 +109,7 @@ pub const MAX_MODEL_BYTES_DEFAULT: u64 = 200 * 1024 * 1024; /// Type alias for the optimised, runnable tract graph held inside /// [`OnnxClassifier`]. The full path is unwieldy and trips /// `clippy::type_complexity`. -type RunnableOnnxModel = +pub(crate) type RunnableOnnxModel = SimplePlan, Graph>>; /// A loaded ONNX classifier paired with its tokenizer. @@ -208,11 +211,11 @@ impl LoadOptions { self } - fn effective_model_limit(&self) -> u64 { + pub(crate) fn effective_model_limit(&self) -> u64 { self.max_model_bytes.unwrap_or(MAX_MODEL_BYTES_DEFAULT) } - fn effective_tokenizer_limit(&self) -> u64 { + pub(crate) fn effective_tokenizer_limit(&self) -> u64 { self.max_tokenizer_bytes.unwrap_or(MAX_MODEL_BYTES_DEFAULT) } } @@ -670,7 +673,7 @@ fn ensure_cached_file( /// Reject `path` if its on-disk size exceeds `max_bytes`. `max_bytes` /// of `0` is treated as "no limit" so callers that explicitly opt out /// of the budget still get the rest of the load-time pipeline. -fn check_size_budget(path: &Path, kind: &str, max_bytes: u64) -> Result<()> { +pub(crate) fn check_size_budget(path: &Path, kind: &str, max_bytes: u64) -> Result<()> { if max_bytes == 0 { return Ok(()); } diff --git a/crates/sbproxy-core/Cargo.toml b/crates/sbproxy-core/Cargo.toml index 45094e8a..27f83bd3 100644 --- a/crates/sbproxy-core/Cargo.toml +++ b/crates/sbproxy-core/Cargo.toml @@ -49,6 +49,11 @@ tls-fingerprint = [ # Real views land in follow-up tickets; this feature only wires # the scaffold mount. embed-admin-ui = ["dep:include_dir"] +# WOR-1235: run the semantic-cache embedding model in-process (tract) for a +# single-binary deploy. Pulls in the ONNX engine. The `sbproxy` binary turns +# this on by default; library consumers leave it off (the `source: inprocess` +# config then returns a clear error and the cache treats it as a miss). +inprocess-embed = ["dep:sbproxy-classifiers"] [dependencies] sbproxy-plugin.workspace = true diff --git a/crates/sbproxy-core/src/server/ai_dispatch.rs b/crates/sbproxy-core/src/server/ai_dispatch.rs index 124425ee..bedb1d72 100644 --- a/crates/sbproxy-core/src/server/ai_dispatch.rs +++ b/crates/sbproxy-core/src/server/ai_dispatch.rs @@ -68,6 +68,11 @@ pub(super) async fn handle_ai_proxy( method = %method_str, "AI proxy: per-surface rate limit hit; returning 429" ); + sbproxy_ai::tracing_spans::record_error( + &tracing::Span::current(), + sbproxy_ai::tracing_spans::error_type::RATE_LIMITED, + "per-surface rate limit exceeded", + ); send_error(session, 429, "per-surface rate limit exceeded").await?; return Ok(()); } @@ -901,6 +906,11 @@ pub(super) async fn handle_ai_proxy( ); let retry = rej.retry_after_secs.to_string(); let extra: Option<(&str, &str)> = Some(("retry-after", &retry)); + sbproxy_ai::tracing_spans::record_error( + &tracing::Span::current(), + sbproxy_ai::tracing_spans::error_type::RATE_LIMITED, + "model rate limit exceeded", + ); send_response_with_extra( session, 429, @@ -927,6 +937,20 @@ pub(super) async fn handle_ai_proxy( // and the intent detection hook so we do not re-parse the body twice. let extracted_prompt = extract_prompt_text(&body); + // WOR-1228: emit the prompt as the OpenInference `input.value` span + // attribute when the origin opts into content capture. Off by default; + // the text is routed through the always-on secret redactor and the + // origin's PII redactor (if any) before it lands on the span, so a + // trace backend never sees raw secrets or PII. + if config.trace_content && !extracted_prompt.is_empty() { + let secrets_redacted = sbproxy_observe::redact::redact_secrets(&extracted_prompt); + let redacted = match config.pii_redactor() { + Some(redactor) => redactor.redact(&secrets_redacted).into_owned(), + None => secrets_redacted, + }; + sbproxy_ai::tracing_spans::record_input_content(&tracing::Span::current(), &redacted); + } + if let Some(hook) = pipeline.hooks.prompt_classifier.as_ref().cloned() { if !extracted_prompt.is_empty() { let model_id = if model.is_empty() { @@ -1021,6 +1045,11 @@ pub(super) async fn handle_ai_proxy( reason = %block.reason, "AI proxy: input guardrail blocked request" ); + sbproxy_ai::tracing_spans::record_error( + &tracing::Span::current(), + sbproxy_ai::tracing_spans::error_type::GUARDRAIL_BLOCKED, + &block.reason, + ); let error_body = serde_json::json!({ "error": { "message": block.reason, @@ -1200,101 +1229,164 @@ pub(super) async fn handle_ai_proxy( req_header_value(session, "authorization").as_deref(), ); if !extracted_prompt.is_empty() { - match config.providers.iter().find(|p| p.name == cache.provider()) { - Some(provider) => { - let ai_client = AI_CLIENT.load_full(); - match sbproxy_ai::semantic_cache::compute_embedding( - &ai_client, - provider, - cache.model(), - &extracted_prompt, - ) - .await + // WOR-1223: vectorize the prompt via the configured source. + // Provider hits the embedding API (costs money, egresses the + // prompt); sidecar uses the local classifier sidecar (free, no + // egress). Any error falls through to an uncached upstream call. + let query_vec_result: anyhow::Result> = match cache.source() { + sbproxy_ai::semantic_cache::EmbeddingSource::Provider => { + match config.providers.iter().find(|p| p.name == cache.provider()) { + Some(provider) => { + let ai_client = AI_CLIENT.load_full(); + sbproxy_ai::semantic_cache::compute_embedding( + &ai_client, + provider, + cache.model(), + &extracted_prompt, + ) + .await + } + None => Err(anyhow::anyhow!( + "semantic cache embedding provider {} not found in providers list", + cache.provider() + )), + } + } + sbproxy_ai::semantic_cache::EmbeddingSource::Sidecar => { + match cache.sidecar_config() { + Some(sc) => { + sbproxy_ai::semantic_cache::compute_embedding_sidecar( + sc, + &extracted_prompt, + ) + .await + } + None => Err(anyhow::anyhow!( + "semantic cache sidecar source has no sidecar config" + )), + } + } + sbproxy_ai::semantic_cache::EmbeddingSource::Inprocess => { + #[cfg(feature = "inprocess-embed")] { - Ok(query_vec) => { - if let Some(hit) = cache.lookup(&query_vec, &cache_scope) { - sbproxy_ai::ai_metrics::record_cache_result( - cache.provider(), - "semantic", - true, - ); - sbproxy_ai::ai_metrics::record_semantic_similarity( - cache.provider(), - hit.score, - ); - debug!( - origin = %hostname, - score = hit.score, - status = hit.response.status, - "AI proxy: embedding semantic cache HIT; replaying" - ); - let mut header = pingora_http::ResponseHeader::build( - hit.response.status, - Some(hit.response.headers.len() + 1), - ) - .map_err(|e| { - Error::because( - ErrorType::InternalError, - "embedding cache: failed to build response header", - e, - ) - })?; - for (name, value) in &hit.response.headers { - if name == "transfer-encoding" || name == "connection" { - continue; - } - let _ = header.insert_header(name.clone(), value.clone()); - } - let _ = header.insert_header("x-semcache", "HIT"); - let body = bytes::Bytes::from(hit.response.body); - // WOR-1094: a cache hit is a zero-cost - // ledger transaction, not an absent one. - // Record the served tokens under the - // cache_read dimension so the hit still - // shows up as savings. - crate::server::ai_support::record_cache_hit_savings( - cache.provider(), - cache.model(), - surface_label, - &body, - &ctx.attribution_tags, - ); - session - .write_response_header(Box::new(header), false) - .await?; - session.write_response_body(Some(body), true).await?; - return Ok(()); - } - sbproxy_ai::ai_metrics::record_cache_result( - cache.provider(), - "semantic", - false, - ); - embed_miss = Some(( - std::sync::Arc::clone(cache), - sbproxy_ai::EmbeddingCache::prompt_key( - &cache_scope, - &extracted_prompt, - ), - query_vec, - cache_scope, - )); + match cache.inprocess_config() { + Some(cfg) => crate::server::ai_support::inprocess_embed( + cfg, + &extracted_prompt, + ), + None => Err(anyhow::anyhow!( + "inprocess embedding source has no inprocess config" + )), } - Err(e) => { - warn!( - origin = %hostname, - error = %e, - "AI proxy: embedding cache lookup failed (fail-open)" - ); + } + #[cfg(not(feature = "inprocess-embed"))] + { + Err(anyhow::anyhow!( + "in-process embedding not compiled in this build; rebuild with \ + --features inprocess-embed or use source: sidecar" + )) + } + } + }; + let source_label: &str = match cache.source() { + sbproxy_ai::semantic_cache::EmbeddingSource::Provider => "provider", + sbproxy_ai::semantic_cache::EmbeddingSource::Sidecar => "sidecar", + sbproxy_ai::semantic_cache::EmbeddingSource::Inprocess => "inprocess", + }; + match query_vec_result { + Ok(query_vec) => { + if let Some(hit) = cache.lookup(&query_vec, &cache_scope) { + sbproxy_ai::ai_metrics::record_cache_result( + cache.provider(), + "semantic", + true, + ); + sbproxy_observe::metrics::record_semantic_cache( + ctx.tenant_id.as_str(), + hostname, + source_label, + "hit", + ); + sbproxy_ai::ai_metrics::record_semantic_similarity( + cache.provider(), + hit.score, + ); + debug!( + tenant = %ctx.tenant_id, + origin = %hostname, + score = hit.score, + status = hit.response.status, + "AI proxy: embedding semantic cache HIT; replaying" + ); + let mut header = pingora_http::ResponseHeader::build( + hit.response.status, + Some(hit.response.headers.len() + 1), + ) + .map_err(|e| { + Error::because( + ErrorType::InternalError, + "embedding cache: failed to build response header", + e, + ) + })?; + for (name, value) in &hit.response.headers { + if name == "transfer-encoding" || name == "connection" { + continue; + } + let _ = header.insert_header(name.clone(), value.clone()); } + let _ = header.insert_header("x-semcache", "HIT"); + let body = bytes::Bytes::from(hit.response.body); + // WOR-1094: a cache hit is a zero-cost + // ledger transaction, not an absent one. + // Record the served tokens under the + // cache_read dimension so the hit still + // shows up as savings. + crate::server::ai_support::record_cache_hit_savings( + ctx.tenant_id.as_str(), + hostname, + cache.provider(), + cache.model(), + surface_label, + &body, + &ctx.attribution_tags, + ); + session + .write_response_header(Box::new(header), false) + .await?; + session.write_response_body(Some(body), true).await?; + return Ok(()); } + sbproxy_ai::ai_metrics::record_cache_result( + cache.provider(), + "semantic", + false, + ); + sbproxy_observe::metrics::record_semantic_cache( + ctx.tenant_id.as_str(), + hostname, + source_label, + "miss", + ); + embed_miss = Some(( + std::sync::Arc::clone(cache), + sbproxy_ai::EmbeddingCache::prompt_key(&cache_scope, &extracted_prompt), + query_vec, + cache_scope, + )); } - None => { + Err(e) => { + sbproxy_observe::metrics::record_semantic_cache( + ctx.tenant_id.as_str(), + hostname, + source_label, + "error", + ); warn!( + tenant = %ctx.tenant_id, origin = %hostname, - provider = %cache.provider(), - "AI proxy: semantic cache embedding provider not found in \ - providers list; skipping cache" + error = %e, + "AI proxy: embedding cache lookup failed (fail-open)" ); } } diff --git a/crates/sbproxy-core/src/server/ai_support.rs b/crates/sbproxy-core/src/server/ai_support.rs index b77144ec..201e39f7 100644 --- a/crates/sbproxy-core/src/server/ai_support.rs +++ b/crates/sbproxy-core/src/server/ai_support.rs @@ -772,6 +772,10 @@ pub(super) fn emit_ai_billing_event( 0, cost_usd, ); + // WOR-1229: stamp the derived USD cost onto the AI request span so trace + // backends show spend per generation. This is the same choke point the + // cost metric uses, so the span and the metric agree. + sbproxy_ai::tracing_spans::record_cost_usd(&tracing::Span::current(), cost_usd); // WOR-1095: realtime + audio surfaces consume seconds, not tokens, // and realtime has no catalogue price, so the token / cost @@ -850,6 +854,8 @@ pub(super) fn resolve_attribution_tags( /// and its `model` field (falling back to `fallback_model`) gives the /// model label. pub(super) fn record_cache_hit_savings( + tenant: &str, + origin: &str, provider: &str, fallback_model: &str, surface: &str, @@ -889,6 +895,93 @@ pub(super) fn record_cache_hit_savings( 0, 0.0, ); + // WOR-1225: SOTA usage tracking. Attribute the tokens and cost this hit + // avoided (the upstream call that did not happen), using the same cost + // table as spent cost so saved and spent reconcile. + let cost_micros = + (sbproxy_ai::estimate_cost(model, prompt, completion) * 1_000_000.0).max(0.0) as u64; + sbproxy_observe::metrics::record_cache_savings( + tenant, + origin, + model, + prompt, + completion, + cost_micros, + ); +} + +/// WOR-1235: compute a prompt embedding with an in-process tract embedder +/// for the semantic cache (`source: inprocess`). The embedder is loaded once +/// from the config's `model_path` + `tokenizer_path` (with the +/// `max_model_bytes` guard) and held for the process lifetime. Available only +/// when built with the `inprocess-embed` feature; otherwise the +/// `EmbeddingSource::Inprocess` arm returns a clear error and the cache treats +/// the lookup as a miss. +#[cfg(feature = "inprocess-embed")] +pub(super) fn inprocess_embed( + cfg: &sbproxy_ai::semantic_cache::InprocessEmbeddingConfig, + text: &str, +) -> anyhow::Result> { + use std::sync::{Arc, OnceLock}; + static EMBEDDER: OnceLock>> = OnceLock::new(); + let started = std::time::Instant::now(); + let embedder = EMBEDDER.get_or_init(|| { + let (Some(model_path), Some(tokenizer_path)) = + (cfg.model_path.as_ref(), cfg.tokenizer_path.as_ref()) + else { + warn!( + "inprocess embedding source requires model_path and tokenizer_path; \ + the cache will treat lookups as misses" + ); + return None; + }; + let mut options = sbproxy_classifiers::LoadOptions::default(); + if let Some(bytes) = cfg.max_model_bytes { + options = options.with_max_model_bytes(bytes); + } + match sbproxy_classifiers::OnnxEmbedder::load_with_options( + std::path::Path::new(model_path), + std::path::Path::new(tokenizer_path), + &options, + ) { + Ok(e) => Some(Arc::new(e)), + Err(e) => { + warn!(error = %e, "failed to load in-process embedder"); + None + } + } + }); + let model_label = if cfg.model.is_empty() { + "inprocess" + } else { + cfg.model.as_str() + }; + match embedder { + Some(e) => { + let out = e.embed(text); + let result = if out.is_ok() { "ok" } else { "error" }; + sbproxy_observe::metrics::record_inference( + "embed", + "inprocess", + model_label, + result, + started.elapsed().as_secs_f64(), + ); + Ok(out?.values) + } + None => { + sbproxy_observe::metrics::record_inference( + "embed", + "inprocess", + model_label, + "error", + started.elapsed().as_secs_f64(), + ); + Err(anyhow::anyhow!( + "in-process embedder not loaded; check model_path and tokenizer_path" + )) + } + } } pub(super) fn record_budget_usage( diff --git a/crates/sbproxy-modules/src/policy/prompt_injection_v2/inprocess.rs b/crates/sbproxy-modules/src/policy/prompt_injection_v2/inprocess.rs new file mode 100644 index 00000000..e353475f --- /dev/null +++ b/crates/sbproxy-modules/src/policy/prompt_injection_v2/inprocess.rs @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: Apache-2.0 +//! In-process ONNX detector for `prompt_injection_v2` (opt-in). +//! +//! Runs the tract ONNX classifier inside the proxy address space. WOR-612 +//! removed the original in-process detector because an unsandboxed model +//! parse could OOM the proxy; this brings it back only behind an explicit +//! `detector: "inprocess"` opt-in plus a hard `max_model_bytes` size guard, +//! and the operator supplies the model + tokenizer paths. Operators who +//! want process isolation should still prefer `detector: "sidecar"`. + +use std::path::Path; +use std::sync::Arc; + +use sbproxy_classifiers::{LoadOptions, OnnxClassifier}; +use serde::Deserialize; + +use super::detector::{DetectionLabel, DetectionResult, Detector}; + +/// Config name selecting this detector (`detector: "inprocess"`). +pub const INPROCESS_DETECTOR_NAME: &str = "inprocess"; + +const DEFAULT_INJECTION_LABEL: &str = "INJECTION"; +const DEFAULT_THRESHOLD: f64 = 0.5; + +/// Map a `[0,1]` injection score onto the v2 label vocabulary. Same +/// cutoffs as the sidecar detector so the two report identically. +fn classify_score(score: f64, threshold: f64) -> DetectionLabel { + if score >= threshold { + DetectionLabel::Injection + } else if score >= 0.3 { + DetectionLabel::Suspicious + } else { + DetectionLabel::Clean + } +} + +/// Deserializable `detector_config` block for the in-process detector. +#[derive(Debug, Deserialize)] +struct InprocessDetectorConfig { + /// Path to the ONNX model file. + model_path: String, + /// Path to the tokenizer.json file. + tokenizer_path: String, + /// Optional class labels indexed by output class. When omitted, the + /// model's argmax is reported as `class_`. + #[serde(default)] + labels: Option>, + /// Label name (case-insensitive) treated as the injection verdict. + #[serde(default = "default_injection_label")] + injection_label: String, + /// Score at or above which a verdict is labelled `injection`. + #[serde(default = "default_threshold")] + threshold: f64, + /// Hard upper bound on the ONNX model file size in bytes. None uses + /// the engine default (200 MB). This is the guard that bounds the + /// OOM risk WOR-612 flagged. + #[serde(default)] + max_model_bytes: Option, +} + +fn default_injection_label() -> String { + DEFAULT_INJECTION_LABEL.to_string() +} +fn default_threshold() -> f64 { + DEFAULT_THRESHOLD +} + +/// Detector that runs ONNX classification in-process via tract. +pub struct InprocessDetector { + classifier: OnnxClassifier, + injection_label: String, + threshold: f64, + name: &'static str, +} + +impl InprocessDetector { + /// Build from the policy's `detector_config` block. Loads the model at + /// construction time (the slow path) so `detect` stays cheap; the + /// size guard is enforced before the graph is parsed. + pub fn from_config(value: &serde_json::Value) -> anyhow::Result> { + let cfg: InprocessDetectorConfig = serde_json::from_value(value.clone()) + .map_err(|e| anyhow::anyhow!("inprocess detector config: {e}"))?; + let mut options = LoadOptions::default(); + if let Some(bytes) = cfg.max_model_bytes { + options = options.with_max_model_bytes(bytes); + } + let classifier = OnnxClassifier::load_with_options( + Path::new(&cfg.model_path), + Path::new(&cfg.tokenizer_path), + cfg.labels, + &options, + ) + .map_err(|e| anyhow::anyhow!("inprocess detector: {e}"))?; + Ok(Arc::new(Self { + classifier, + injection_label: cfg.injection_label, + threshold: cfg.threshold, + name: INPROCESS_DETECTOR_NAME, + })) + } +} + +impl Detector for InprocessDetector { + fn detect(&self, prompt: &str) -> DetectionResult { + match self.classifier.classify(prompt) { + Ok(output) => { + let score = output.score as f64; + let is_injection_label = output.label.eq_ignore_ascii_case(&self.injection_label); + // A non-injection top label is read as confidence the prompt + // is benign, so invert it (mirrors the sidecar detector). + let (score_for_policy, label) = if is_injection_label { + (score, classify_score(score, self.threshold)) + } else { + (1.0 - score, classify_score(1.0 - score, self.threshold)) + }; + DetectionResult { + score: score_for_policy, + label, + reason: Some(format!( + "inprocess label={} score={:.3}", + output.label, output.score + )), + } + } + Err(e) => { + // Inference failure fails open (clean) so a model hiccup never + // wedges the request path; operators who want fail-closed use + // the sidecar detector's policy. + tracing::warn!(error = %e, "inprocess prompt-injection inference failed; failing open"); + DetectionResult::clean() + } + } + } + + fn name(&self) -> &str { + self.name + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn from_config_requires_model_and_tokenizer_paths() { + // Missing both paths: a config error, not a panic. + let err = match InprocessDetector::from_config(&serde_json::json!({})) { + Ok(_) => panic!("config without paths must fail"), + Err(e) => e, + }; + assert!(err.to_string().contains("inprocess detector config")); + } + + #[test] + fn from_config_missing_model_file_errors() { + let err = match InprocessDetector::from_config(&serde_json::json!({ + "model_path": "/nonexistent/model.onnx", + "tokenizer_path": "/nonexistent/tokenizer.json" + })) { + Ok(_) => panic!("nonexistent model must fail at load"), + Err(e) => e, + }; + assert!(err.to_string().contains("inprocess detector")); + } + + #[test] + fn classify_score_maps_to_the_v2_vocabulary() { + // At or above threshold => injection. + assert_eq!(classify_score(0.9, 0.5), DetectionLabel::Injection); + assert_eq!(classify_score(0.5, 0.5), DetectionLabel::Injection); + // In [0.3, threshold) => suspicious. + assert_eq!(classify_score(0.49, 0.5), DetectionLabel::Suspicious); + assert_eq!(classify_score(0.3, 0.5), DetectionLabel::Suspicious); + // Below 0.3 => clean. + assert_eq!(classify_score(0.29, 0.5), DetectionLabel::Clean); + assert_eq!(classify_score(0.0, 0.5), DetectionLabel::Clean); + } + + #[test] + fn classify_score_threshold_is_inclusive_and_configurable() { + // A higher threshold widens the suspicious band. + assert_eq!(classify_score(0.85, 0.9), DetectionLabel::Suspicious); + assert_eq!(classify_score(0.9, 0.9), DetectionLabel::Injection); + // A low threshold collapses suspicious: 0.3 still suspicious, 0.31 injects. + assert_eq!(classify_score(0.31, 0.31), DetectionLabel::Injection); + } + + #[test] + fn default_injection_label_and_threshold_are_stable() { + assert_eq!(DEFAULT_INJECTION_LABEL, "INJECTION"); + assert_eq!(default_injection_label(), "INJECTION"); + assert_eq!(default_threshold(), 0.5); + } + + #[test] + fn from_config_rejects_paths_only_partially_given() { + // model_path without tokenizer_path is a config error, not a panic. + let err = InprocessDetector::from_config(&serde_json::json!({ + "model_path": "/some/model.onnx" + })) + .err() + .expect("partial paths must fail"); + assert!(err.to_string().contains("inprocess detector config")); + } +} diff --git a/crates/sbproxy-modules/src/policy/prompt_injection_v2/mod.rs b/crates/sbproxy-modules/src/policy/prompt_injection_v2/mod.rs index 6b0fbecf..0a2aef1f 100644 --- a/crates/sbproxy-modules/src/policy/prompt_injection_v2/mod.rs +++ b/crates/sbproxy-modules/src/policy/prompt_injection_v2/mod.rs @@ -16,6 +16,7 @@ mod body_aware; mod detector; mod heuristic; +mod inprocess; mod sidecar; pub use body_aware::{ @@ -27,6 +28,7 @@ pub use detector::{ DetectorFactory, }; pub use heuristic::{HeuristicDetector, HEURISTIC_DETECTOR_NAME}; +pub use inprocess::{InprocessDetector, INPROCESS_DETECTOR_NAME}; pub use sidecar::{SidecarDetector, SIDECAR_DETECTOR_NAME}; use std::sync::Arc; @@ -220,6 +222,10 @@ impl PromptInjectionV2Policy { )); } else if raw.detector == sidecar::SIDECAR_DETECTOR_NAME { sidecar::SidecarDetector::from_config(&raw.detector_config)? + } else if raw.detector == inprocess::INPROCESS_DETECTOR_NAME { + // WOR-1224: opt-in in-process ONNX classify. Bounded by a + // max_model_bytes guard; prefer detector: "sidecar" for isolation. + inprocess::InprocessDetector::from_config(&raw.detector_config)? } else { lookup_detector(&raw.detector).ok_or_else(|| { anyhow!( diff --git a/crates/sbproxy-observe/src/export/otlp_grpc.rs b/crates/sbproxy-observe/src/export/otlp_grpc.rs index 812ff144..a97811ff 100644 --- a/crates/sbproxy-observe/src/export/otlp_grpc.rs +++ b/crates/sbproxy-observe/src/export/otlp_grpc.rs @@ -58,6 +58,8 @@ pub fn init_grpc_pipeline(config: &OtlpGrpcConfig) -> anyhow::Result<()> { service_name: config.service_name.clone(), sample_rate: None, always_sample_errors: true, + keep_over_budget_usd: None, + keep_slower_than_secs: None, propagation: None, resource_attrs: std::collections::BTreeMap::new(), export_metrics: false, diff --git a/crates/sbproxy-observe/src/metrics.rs b/crates/sbproxy-observe/src/metrics.rs index 49ff9ba6..8c2d96c9 100644 --- a/crates/sbproxy-observe/src/metrics.rs +++ b/crates/sbproxy-observe/src/metrics.rs @@ -265,6 +265,25 @@ pub struct ProxyMetrics { /// Counter `sbproxy_ai_tokens_total` of AI token usage labelled by hostname, provider, and direction. pub ai_tokens_total: IntCounterVec, + // --- Local inference + semantic cache (WOR-1225) --- + /// Counter `sbproxy_semantic_cache_results_total` of semantic-cache + /// outcomes labelled by tenant, origin, embedding source, and result. + pub semantic_cache_results: IntCounterVec, + /// Counter `sbproxy_inference_requests_total` of local inference calls + /// labelled by kind (embed|classify), backend (sidecar|inprocess), + /// model, and result (ok|error). + pub inference_requests: IntCounterVec, + /// Histogram `sbproxy_inference_duration_seconds` of local inference + /// latency labelled by kind, backend, and model. + pub inference_duration: HistogramVec, + /// Counter `sbproxy_ai_tokens_saved_total` of tokens a semantic-cache + /// hit avoided, labelled by tenant, origin, model, and kind + /// (prompt|completion). + pub ai_tokens_saved: IntCounterVec, + /// Counter `sbproxy_ai_cost_saved_micros_total` of micro-USD a + /// semantic-cache hit avoided, labelled by tenant, origin, and model. + pub ai_cost_saved_micros: IntCounterVec, + // --- Per-origin metrics (Sprint 1A) --- /// Total HTTP requests with origin, method, and status labels. pub per_origin_requests_total: CounterVec, @@ -401,6 +420,57 @@ impl ProxyMetrics { ) .unwrap(); + // --- Local inference + semantic cache (WOR-1225) --- + + let semantic_cache_results = IntCounterVec::new( + Opts::new( + "sbproxy_semantic_cache_results_total", + "Semantic-cache hit/miss/error counts", + ), + // tenant: multi-tenant attribution; source: provider|sidecar|inprocess; result: hit|miss|error + &["tenant", "origin", "source", "result"], + ) + .unwrap(); + + let inference_requests = IntCounterVec::new( + Opts::new( + "sbproxy_inference_requests_total", + "Local inference call counts", + ), + &["kind", "backend", "model", "result"], // kind: embed|classify; result: ok|error + ) + .unwrap(); + + let inference_duration = HistogramVec::new( + prometheus::HistogramOpts::new( + "sbproxy_inference_duration_seconds", + "Local inference latency in seconds", + ) + .buckets(vec![ + 0.0005, 0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, + ]), + &["kind", "backend", "model"], + ) + .unwrap(); + + let ai_tokens_saved = IntCounterVec::new( + Opts::new( + "sbproxy_ai_tokens_saved_total", + "Tokens avoided by a semantic-cache hit", + ), + &["tenant", "origin", "model", "kind"], // kind: prompt|completion + ) + .unwrap(); + + let ai_cost_saved_micros = IntCounterVec::new( + Opts::new( + "sbproxy_ai_cost_saved_micros_total", + "Micro-USD avoided by a semantic-cache hit", + ), + &["tenant", "origin", "model"], + ) + .unwrap(); + // --- Per-origin metrics (Sprint 1A) --- let per_origin_requests_total = CounterVec::new( @@ -590,6 +660,21 @@ impl ProxyMetrics { registry .register(Box::new(ai_tokens_total.clone())) .unwrap(); + registry + .register(Box::new(semantic_cache_results.clone())) + .unwrap(); + registry + .register(Box::new(inference_requests.clone())) + .unwrap(); + registry + .register(Box::new(inference_duration.clone())) + .unwrap(); + registry + .register(Box::new(ai_tokens_saved.clone())) + .unwrap(); + registry + .register(Box::new(ai_cost_saved_micros.clone())) + .unwrap(); registry .register(Box::new(per_origin_requests_total.clone())) .unwrap(); @@ -645,6 +730,11 @@ impl ProxyMetrics { active_connections, cache_hits, ai_tokens_total, + semantic_cache_results, + inference_requests, + inference_duration, + ai_tokens_saved, + ai_cost_saved_micros, per_origin_requests_total, per_origin_request_duration, per_origin_active_connections, @@ -900,6 +990,72 @@ pub fn record_phase_duration(phase: &str, origin: &str, duration_secs: f64) { ); } +/// Record a semantic-cache outcome (WOR-1225), attributed per tenant. +/// `source` is provider|sidecar|inprocess; `result` is hit|miss|error. +pub fn record_semantic_cache(tenant: &str, origin: &str, source: &str, result: &str) { + let tenant = sanitize_label("tenant", tenant); + let origin = sanitize_label("origin", origin); + metrics() + .semantic_cache_results + .with_label_values(&[tenant.as_str(), origin.as_str(), source, result]) + .inc(); +} + +/// Record a local inference call and its latency (WOR-1225). `kind` is +/// embed|classify; `backend` is sidecar|inprocess; `result` is ok|error. +pub fn record_inference(kind: &str, backend: &str, model: &str, result: &str, duration_secs: f64) { + let model = sanitize_label("model", model); + metrics() + .inference_requests + .with_label_values(&[kind, backend, model.as_str(), result]) + .inc(); + if duration_secs > 0.0 { + metrics() + .inference_duration + .with_label_values(&[kind, backend, model.as_str()]) + .observe(duration_secs); + } +} + +/// Attribute the tokens and cost a semantic-cache hit avoided (WOR-1225): +/// the upstream call that did not happen. This is the value-delivered side +/// of usage tracking, so saved cost uses the same cost table as spent cost. +pub fn record_cache_savings( + tenant: &str, + origin: &str, + model: &str, + prompt_tokens: u64, + completion_tokens: u64, + cost_micros: u64, +) { + let tenant = sanitize_label("tenant", tenant); + let origin = sanitize_label("origin", origin); + let model = sanitize_label("model", model); + if prompt_tokens > 0 { + metrics() + .ai_tokens_saved + .with_label_values(&[tenant.as_str(), origin.as_str(), model.as_str(), "prompt"]) + .inc_by(prompt_tokens); + } + if completion_tokens > 0 { + metrics() + .ai_tokens_saved + .with_label_values(&[ + tenant.as_str(), + origin.as_str(), + model.as_str(), + "completion", + ]) + .inc_by(completion_tokens); + } + if cost_micros > 0 { + metrics() + .ai_cost_saved_micros + .with_label_values(&[tenant.as_str(), origin.as_str(), model.as_str()]) + .inc_by(cost_micros); + } +} + /// Record a policy trigger (allow or deny) for an origin. /// /// Legacy entry point: stamps the per-agent labels with the empty @@ -2659,6 +2815,44 @@ mod tests { // Helper functions that call metrics() use the global instance, so those // tests verify the global registry path. + #[test] + fn local_inference_and_savings_metrics_registered() { + let m = ProxyMetrics::new(); + m.semantic_cache_results + .with_label_values(&["acme", "o", "sidecar", "hit"]) + .inc(); + m.inference_requests + .with_label_values(&["embed", "sidecar", "all-MiniLM-L6-v2", "ok"]) + .inc(); + m.inference_duration + .with_label_values(&["embed", "sidecar", "all-MiniLM-L6-v2"]) + .observe(0.001); + m.ai_tokens_saved + .with_label_values(&["acme", "o", "gpt-4o", "prompt"]) + .inc_by(120); + m.ai_cost_saved_micros + .with_label_values(&["acme", "o", "gpt-4o"]) + .inc_by(900); + let names: Vec = m + .registry + .gather() + .iter() + .map(|f| f.name().to_string()) + .collect(); + for expected in [ + "sbproxy_semantic_cache_results_total", + "sbproxy_inference_requests_total", + "sbproxy_inference_duration_seconds", + "sbproxy_ai_tokens_saved_total", + "sbproxy_ai_cost_saved_micros_total", + ] { + assert!( + names.iter().any(|n| n == expected), + "missing metric {expected}" + ); + } + } + #[test] fn test_increment_requests() { let m = ProxyMetrics::new(); diff --git a/crates/sbproxy-observe/src/telemetry.rs b/crates/sbproxy-observe/src/telemetry.rs index ad8957f0..3cef8847 100644 --- a/crates/sbproxy-observe/src/telemetry.rs +++ b/crates/sbproxy-observe/src/telemetry.rs @@ -71,10 +71,22 @@ pub struct TelemetryConfig { #[serde(default)] pub sample_rate: Option, /// When `true`, every 5xx / policy-block / ledger-denial root span - /// is sampled at 100% even if the head ratio would have dropped it. + /// is kept at 100% even if the head ratio would have dropped it. /// Default `true`. #[serde(default = "default_always_sample_errors")] pub always_sample_errors: bool, + /// WOR-1230: keep any trace whose derived USD cost is at or above this + /// threshold, regardless of the head ratio. `None` disables the + /// cost-based keep. Cost is known at request end, so this is consulted + /// by the tail-sampling policy (the reference collector), not the + /// head sampler. + #[serde(default)] + pub keep_over_budget_usd: Option, + /// WOR-1230: keep any trace whose wall-clock duration is at or above + /// this many seconds, regardless of the head ratio. `None` disables the + /// latency-based keep. Like cost, a tail-sampling concern. + #[serde(default)] + pub keep_slower_than_secs: Option, /// Propagation format: `"w3c"` (default), `"b3"`, or `"jaeger"`. /// Only ships W3C; the other variants land in a follow-up. #[serde(default)] @@ -125,6 +137,8 @@ impl Default for TelemetryConfig { service_name: default_service_name(), sample_rate: None, always_sample_errors: true, + keep_over_budget_usd: None, + keep_slower_than_secs: None, propagation: None, resource_attrs: std::collections::BTreeMap::new(), export_metrics: false, @@ -133,6 +147,31 @@ impl Default for TelemetryConfig { } } +/// WOR-1230: cost-aware tail-sampling decision for an AI trace. +/// +/// Head sampling (ParentBased + TraceIdRatio, configured via +/// `sample_rate`) decides at span start, before the outcome is known. +/// Whether a finished trace should be kept regardless of that ratio, +/// because it errored, cost over a budget, or ran slow, is a tail +/// decision: it is evaluated at request end and applied by the reference +/// collector's tail-sampling policy. This pure helper is the single +/// source of truth for that decision so the proxy and the collector +/// policy agree. +/// +/// Returns `true` when the trace should be force-kept. +pub fn should_force_sample( + is_error: bool, + cost_usd: f64, + latency_secs: f64, + always_sample_errors: bool, + keep_over_budget_usd: Option, + keep_slower_than_secs: Option, +) -> bool { + (always_sample_errors && is_error) + || keep_over_budget_usd.is_some_and(|budget| cost_usd >= budget) + || keep_slower_than_secs.is_some_and(|threshold| latency_secs >= threshold) +} + // --- OTLP exporter --- /// Initialise the OTLP tracing pipeline. @@ -660,6 +699,53 @@ mod tests { assert!(config.endpoint.is_none()); assert!(config.sample_rate.is_none()); assert!(config.propagation.is_none()); + // WOR-1230: errors kept by default; cost/latency keeps off unless set. + assert!(config.always_sample_errors); + assert!(config.keep_over_budget_usd.is_none()); + assert!(config.keep_slower_than_secs.is_none()); + } + + #[test] + fn force_sample_keeps_errors_when_enabled() { + assert!(should_force_sample(true, 0.0, 0.0, true, None, None)); + // Disabled error keep: an error alone does not force a keep. + assert!(!should_force_sample(true, 0.0, 0.0, false, None, None)); + } + + #[test] + fn force_sample_keeps_over_budget_and_slow() { + // Over the cost budget. + assert!(should_force_sample( + false, + 0.05, + 0.0, + true, + Some(0.01), + None + )); + assert!(!should_force_sample( + false, + 0.005, + 0.0, + true, + Some(0.01), + None + )); + // Slower than the latency threshold. + assert!(should_force_sample(false, 0.0, 2.0, true, None, Some(1.0))); + assert!(!should_force_sample(false, 0.0, 0.5, true, None, Some(1.0))); + } + + #[test] + fn force_sample_is_false_for_a_cheap_fast_success() { + assert!(!should_force_sample( + false, + 0.001, + 0.05, + true, + Some(1.0), + Some(5.0) + )); } #[test] diff --git a/crates/sbproxy/Cargo.toml b/crates/sbproxy/Cargo.toml index 9a6405f6..f9eebe4a 100644 --- a/crates/sbproxy/Cargo.toml +++ b/crates/sbproxy/Cargo.toml @@ -38,8 +38,12 @@ default = [ "licensing-tdmrep", "llms-txt", "tls-fingerprint", + "inprocess-embed", ] tiered-pricing = ["sbproxy-modules/tiered-pricing"] +# WOR-1235: in-process semantic-cache embedder (single-binary). Forwards to +# sbproxy-core/inprocess-embed. +inprocess-embed = ["sbproxy-core/inprocess-embed"] agent-class = ["sbproxy-core/agent-class"] http-ledger = ["sbproxy-modules/http-ledger"] content-negotiate = ["sbproxy-modules/content-negotiate"] diff --git a/deny.toml b/deny.toml index 41e16642..2aa60f4f 100644 --- a/deny.toml +++ b/deny.toml @@ -118,6 +118,13 @@ ignore = [ # transitive; revisit when tract-onnx and tokenizers update. "RUSTSEC-2024-0436", + # RUSTSEC-2026-0173: proc-macro-error2 (a maintained fork of the older + # proc-macro-error) is now itself unmaintained. Pulled in transitively as a + # build-time proc-macro by the tonic/prost gRPC stack; no runtime artifact + # and no maintained replacement reachable without an upstream bump. Predates + # this change. Transitive; revisit when tonic/prost drop it. + "RUSTSEC-2026-0173", + # RUSTSEC-2025-0134: rustls-pemfile 2.2.0. Paths: rustls-pemfile -> # pingora-rustls -> pingora-core (direct dep), -> object_store (direct # dep), -> kube-client -> kube (direct dep), and -> axum-server (direct diff --git a/docs/README.md b/docs/README.md index e9211206..6bf0f8b7 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,6 +25,7 @@ New here? Read [manual.md](manual.md) for install and CLI, then [configuration.m ### AI gateway - [ai-gateway.md](ai-gateway.md) - providers, routing strategies, guardrails, budgets, streaming. +- [local-inference.md](local-inference.md) - run embeddings (semantic cache) and prompt-injection classify on local ONNX models via the sidecar or in-process. - [ai-lb-benchmark.md](ai-lb-benchmark.md) - P50/P95/P99/P99.9 latency comparison across AI router strategies under skewed load. - [providers.md](providers.md) - the catalog of supported LLM providers. - [scripting.md](scripting.md) - CEL, Lua, JavaScript, and WASM scripting reference. diff --git a/docs/local-inference.md b/docs/local-inference.md new file mode 100644 index 00000000..2cb1c15e --- /dev/null +++ b/docs/local-inference.md @@ -0,0 +1,193 @@ +# Local inference (embeddings and prompt-injection classify) +*Last modified: 2026-06-08* + +SBproxy can run two AI-gateway features on local ONNX models instead of paid +APIs: + +- The **embedding semantic cache** vectorizes prompts to serve near-duplicate + requests from cache. +- **Prompt-injection v2** classifies prompts for injection attempts. + +Running these locally means no per-call API cost, no prompt egress (the prompt +never leaves your network), low loopback latency, and air-gap support. Models +run on a pure-Rust engine (`tract`), so there is no Python and no native +ONNX Runtime install. + +There are two ways to run local inference: + +- **Sidecar (recommended).** A small co-located process holds the model. A bad + or oversized model can only OOM the sidecar, which the proxy restarts; it + never takes the proxy down. +- **In-process (opt-in).** The model loads inside the proxy for a true single + binary. Simpler to deploy, but a model parse runs in the proxy's address + space, so it is gated behind explicit config and a size guard. + +## Models + +| Use | Default model | License | Size | +|---|---|---|---| +| Embeddings | `all-MiniLM-L6-v2` (384-dim) | Apache-2.0 | ~90 MB | +| Prompt-injection classify | `protectai/deberta-v3-base-prompt-injection-v2` | Apache-2.0 | ~70 MB int8 | + +Both are operator-supplied runtime data, not bundled with the binary. Download +them once and point the sidecar (or the in-process config) at the files. + +### Download the models + +```bash +mkdir -p /var/lib/sbproxy/models/minilm /var/lib/sbproxy/models/injection + +# Embedding model (all-MiniLM-L6-v2) +curl -fSL -o /var/lib/sbproxy/models/minilm/model.onnx \ + https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx +curl -fSL -o /var/lib/sbproxy/models/minilm/tokenizer.json \ + https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json + +# Prompt-injection classifier +curl -fSL -o /var/lib/sbproxy/models/injection/model.onnx \ + https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2/resolve/main/onnx/model.onnx +curl -fSL -o /var/lib/sbproxy/models/injection/tokenizer.json \ + https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2/resolve/main/tokenizer.json +``` + +Air-gapped sites: download on a connected host, verify the SHA-256 against the +upstream model card, then copy the files into place. The engine validates a +pinned hash when one is configured, and otherwise trusts the local file. + +## Run the sidecar + +The sidecar binary is `sbproxy-classifier-sidecar`. It serves both `Classify` +and `Embed` over gRPC (TCP or a Unix domain socket). Load whichever models you +need: + +```bash +sbproxy-classifier-sidecar \ + --listen 127.0.0.1:9440 \ + --model prompt-injection=/var/lib/sbproxy/models/injection/model.onnx:/var/lib/sbproxy/models/injection/tokenizer.json \ + --embed-model all-MiniLM-L6-v2=/var/lib/sbproxy/models/minilm/model.onnx:/var/lib/sbproxy/models/minilm/tokenizer.json +``` + +Health and readiness are on the same host; the proxy connects lazily, so the +sidecar does not have to be up before the proxy starts. For a co-located +deployment, use `--listen-uds /run/sbproxy/classifier.sock` instead of +`--listen` to skip the loopback TCP round trip. + +## Enable the local semantic cache + +Point the `semantic_cache` block at the sidecar with `source: sidecar`: + +```yaml +ai: + semantic_cache: + enabled: true + threshold: 0.85 # cosine similarity for a near-duplicate hit + ttl_secs: 3600 + max_entries: 1024 + source: sidecar + sidecar: + endpoint: http://127.0.0.1:9440 + model: all-MiniLM-L6-v2 + timeout_ms: 500 +``` + +On a miss the proxy vectorizes the prompt via the sidecar, scans the cache, and +replays the closest cached response when cosine similarity meets `threshold`. If +the sidecar is unreachable, the lookup is treated as a miss and the request +proceeds to the upstream uncached. The cache never wedges a request. + +The default `source` is `provider`, which calls an AI provider's `/v1/embeddings` +API. Existing configs are unchanged. + +## Enable first-class ONNX prompt-injection + +Select the sidecar detector in the `prompt_injection_v2` policy: + +```yaml +policies: + - type: prompt_injection_v2 + threshold: 0.8 + action: block + detector: sidecar + detector_config: + endpoint: http://127.0.0.1:9440 + model: prompt-injection + injection_label: INJECTION + timeout_ms: 250 + fail_closed: false # a sidecar outage degrades to "clean" (allow) +``` + +The default detector is `heuristic-v1` (a zero-dependency regex pass). Choosing +`detector: sidecar` runs the ONNX classifier in the sidecar. + +## In-process opt-in + +For a single binary, run either feature in-process. This loads a model into the +proxy address space, so it is gated behind explicit config and a +`max_model_bytes` guard. Prefer the sidecar for isolation. + +Prompt-injection in-process: + +```yaml +policies: + - type: prompt_injection_v2 + threshold: 0.8 + action: block + detector: inprocess + detector_config: + model_path: /var/lib/sbproxy/models/injection/model.onnx + tokenizer_path: /var/lib/sbproxy/models/injection/tokenizer.json + injection_label: INJECTION + max_model_bytes: 209715200 # 200 MB guard +``` + +In-process semantic cache embeddings: + +```yaml +ai: + semantic_cache: + enabled: true + threshold: 0.85 + source: inprocess + inprocess: + model: all-MiniLM-L6-v2 + model_path: /var/lib/sbproxy/models/minilm/model.onnx + tokenizer_path: /var/lib/sbproxy/models/minilm/tokenizer.json + max_model_bytes: 209715200 # 200 MB guard +``` + +The released `sbproxy` binary is built with the `inprocess-embed` feature, so +`source: inprocess` works out of the box. If you build from source without the +default features, add `--features inprocess-embed`; without it, `source: +inprocess` returns a clear error and the cache treats lookups as misses. + +## Metrics and usage tracking + +Local inference and the semantic cache emit `sbproxy_*` metrics, attributed per +tenant where relevant (see [metrics-stability.md](./metrics-stability.md)): + +| Metric | What it tells you | +|---|---| +| `sbproxy_semantic_cache_results_total{tenant,origin,source,result}` | Cache hit / miss / error rate by embedding source | +| `sbproxy_inference_requests_total{kind,backend,model,result}` | Embed and classify call counts | +| `sbproxy_inference_duration_seconds{kind,backend,model}` | Embed and classify latency | +| `sbproxy_ai_tokens_saved_total{tenant,origin,model,kind}` | Tokens a cache hit avoided | +| `sbproxy_ai_cost_saved_micros_total{tenant,origin,model}` | Micro-USD a cache hit avoided | + +The saved-cost metric uses the same cost table as spent cost, so a dashboard can +show spend and savings side by side and they reconcile. Saved cost is the value +the cache delivered, not just its hit rate. + +## Troubleshooting + +- **Cache never hits.** Confirm the sidecar is up and `--embed-model` is loaded + (`sbproxy_inference_requests_total{kind="embed"}` should increment). Lower + `threshold` if near-duplicates are scored just under it. +- **`Embed` returns FAILED_PRECONDITION.** The sidecar has no embedding model + loaded. Start it with `--embed-model`. +- **Classify always allows.** Check the `injection_label` matches the model's + label set, and that `--model` is loaded on the sidecar. +- **Dimension mismatch after a model change.** The cache skips entries with a + different vector length and logs a warning once. Clear the cache (restart) or + let entries age out via `ttl_secs`. +- **In-process load fails fast.** The model exceeds `max_model_bytes`. Raise the + guard or use the sidecar. diff --git a/docs/metrics-stability.md b/docs/metrics-stability.md index 3fe72eef..8ac98988 100644 --- a/docs/metrics-stability.md +++ b/docs/metrics-stability.md @@ -1,6 +1,6 @@ # Metrics stability -*Last modified: 2026-06-05* +*Last modified: 2026-06-08* Naming conventions, stability guarantees, and the full catalogue of metrics emitted by SBproxy. @@ -642,6 +642,91 @@ and `sbproxy_ai_ratelimit_rejected_total`. The label is the resolved tenant (`__default__` for single-tenant deployments) and is run through the cardinality limiter. +### Local inference and semantic cache + +#### `sbproxy_semantic_cache_results_total` + +| Property | Value | +|---|---| +| Type | Counter | +| Stability | **beta** | +| Description | Embedding semantic-cache outcomes, attributed per tenant. | + +**Labels:** + +| Label | Description | Example values | +|---|---|---| +| `tenant` | Tenant id the request was attributed to | `acme`, `__default__` | +| `origin` | Virtual hostname | `api.example.com` | +| `source` | Embedding source that vectorized the prompt | `provider`, `sidecar`, `inprocess` | +| `result` | Lookup outcome | `hit`, `miss`, `error` | + +#### `sbproxy_inference_requests_total` + +| Property | Value | +|---|---| +| Type | Counter | +| Stability | **beta** | +| Description | Local ONNX inference calls (embeddings and classify) and their outcome. | + +**Labels:** + +| Label | Description | Example values | +|---|---|---| +| `kind` | Inference kind | `embed`, `classify` | +| `backend` | Where inference ran | `sidecar`, `inprocess` | +| `model` | Logical model id | `all-MiniLM-L6-v2`, `prompt-injection-v2` | +| `result` | Call outcome | `ok`, `error` | + +#### `sbproxy_inference_duration_seconds` + +| Property | Value | +|---|---| +| Type | Histogram | +| Stability | **beta** | +| Description | Local ONNX inference latency in seconds. | + +**Labels:** + +| Label | Description | Example values | +|---|---|---| +| `kind` | Inference kind | `embed`, `classify` | +| `backend` | Where inference ran | `sidecar`, `inprocess` | +| `model` | Logical model id | `all-MiniLM-L6-v2` | + +#### `sbproxy_ai_tokens_saved_total` + +| Property | Value | +|---|---| +| Type | Counter | +| Stability | **beta** | +| Description | Tokens a semantic-cache hit avoided (the upstream call that did not happen). The value-delivered side of usage tracking, attributed per tenant. | + +**Labels:** + +| Label | Description | Example values | +|---|---|---| +| `tenant` | Tenant id the savings are attributed to | `acme`, `__default__` | +| `origin` | Virtual hostname | `api.example.com` | +| `model` | Model id from the cached response | `gpt-4o`, `claude-sonnet-4-5` | +| `kind` | Token kind | `prompt`, `completion` | + +#### `sbproxy_ai_cost_saved_micros_total` + +| Property | Value | +|---|---| +| Type | Counter | +| Stability | **beta** | +| Description | Micro-USD a semantic-cache hit avoided. Saved cost uses the same cost table as spent cost, so saved and spent reconcile. Attributed per tenant. | + +**Labels:** + +| Label | Description | Example values | +|---|---|---| +| `tenant` | Tenant id the savings are attributed to | `acme`, `__default__` | +| `origin` | Virtual hostname | `api.example.com` | +| `model` | Model id from the cached response | `gpt-4o` | + --- ## Deprecation process diff --git a/docs/observability.md b/docs/observability.md index 01fc3e19..166df24f 100644 --- a/docs/observability.md +++ b/docs/observability.md @@ -1,5 +1,5 @@ # Observability -*Last modified: 2026-06-01* +*Last modified: 2026-06-08* SBproxy ships metrics, logs, and traces from one process. This guide covers the Wave 1 substrate: the SLO catalog, the metric label budget, the log schema and redaction policy, the trace propagation contract, the health endpoints, the dashboards, and the reference Compose stack you can boot in one command. @@ -612,6 +612,36 @@ Span attributes include the OTel semantic conventions (`http.request.method`, `h High-cardinality attributes (`request_id`, `agent_id`) are span attributes only, never Prometheus labels. +### AI gateway spans (gen_ai / OpenInference) + +The AI request span (`ai.request`) follows the OpenTelemetry GenAI semantic conventions (`gen_ai.*`) and dual-emits the OpenInference (`llm.*`) vocabulary, so LLM-native trace backends render a full generation without remapping. Per request it carries: + +| Concept | gen_ai | OpenInference | +|---|---|---| +| Provider / model | `gen_ai.system`, `gen_ai.request.model`, `gen_ai.response.model` | `llm.provider`, `llm.model_name` | +| Tokens (with cache + reasoning split) | `gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens`, `gen_ai.usage.cache_read_tokens`, `gen_ai.usage.cache_write_tokens`, `gen_ai.usage.reasoning_tokens` | `llm.token_count.prompt`, `llm.token_count.completion`, `llm.token_count.total` | +| Derived USD cost | `gen_ai.usage.cost` | `llm.usage.total_cost` | +| Pricing catalog revision | `sbproxy.ai.pricing_version` | n/a | +| Failure | `otel.status_code = ERROR` plus `error.type` (`guardrail_blocked`, `rate_limited`, `provider_error`, `content_filter`) | n/a | +| Tenant | `sbproxy.tenant_id` | n/a | + +Token counting happens at the proxy (not trusted from the upstream's self-report), cost is derived from the catalog stamped in `sbproxy.ai.pricing_version`, and the GenAI attribute set is pinned by a conformance test (semconv 1.36.0) so emitted spans cannot silently drift off-spec. + +#### Compatible backends + +OTLP is vendor-agnostic: point `telemetry.endpoint` at any OTLP-compatible backend. These render SBproxy AI spans as LLM trajectories with no custom mapping: + +| Backend | Path | +|---|---| +| Arize Phoenix | OTLP gRPC; reads `gen_ai.*` and OpenInference `llm.*` | +| Langfuse | OTLP; reads `gen_ai.*` | +| Jaeger | OTLP via the OTel Collector | +| Grafana Tempo | OTLP via the Collector | +| Datadog | OTLP via the Datadog Agent or Collector | +| Honeycomb | OTLP direct | + +The reference Compose stack under `deploy/observability/` boots an OTel Collector that fans traces out to these backends. + ### Sampling Wave 1 ships head-based sampling, evaluated at the root span: diff --git a/docs/prompt-injection-v2.md b/docs/prompt-injection-v2.md index cfab4117..5ced0904 100644 --- a/docs/prompt-injection-v2.md +++ b/docs/prompt-injection-v2.md @@ -52,6 +52,31 @@ pre-loads state at startup, not in `detect` itself. |------|-------------| | `heuristic-v1` | Case-insensitive substring matching against the OWASP-LLM-01 vocabulary plus a small "suspicious" cue list. Default; works out of the box. | | `sidecar` | Runs inference in a separate process over gRPC instead of in the proxy. The proxy holds one client; the sidecar (minimal OSS or richer enterprise) implements the shared `InferenceService`. Isolates the model runtime so a bad model cannot exhaust the proxy. Fail-open by default. See [Running detection out of process](#running-detection-out-of-process-the-sidecar-detector). | +| `inprocess` | Runs the ONNX classifier inside the proxy via the pure-Rust tract engine. No second process, but the model parse and inference share the proxy's address space, so it is gated behind an explicit opt-in plus a `max_model_bytes` size guard. Prefer `sidecar` for isolation; use `inprocess` for a single-binary deploy. See [In-process detection](#in-process-detection-the-inprocess-detector). | + +### In-process detection (the `inprocess` detector) + +For a single binary, run the ONNX classifier in the proxy. WOR-612 removed the original in-process detector because an unsandboxed model parse could exhaust the proxy; this brings it back only behind the explicit `detector: inprocess` choice plus a hard `max_model_bytes` cap, and the operator supplies the model and tokenizer paths (OSS ships no weights). + +```yaml +policies: + - type: prompt_injection_v2 + action: block + detector: inprocess + threshold: 0.8 + detector_config: + # On-disk ONNX model + tokenizer the operator provides. + model_path: /var/lib/sbproxy/models/injection/model.onnx + tokenizer_path: /var/lib/sbproxy/models/injection/tokenizer.json + # Label the model emits for an injection verdict (case-insensitive). + injection_label: INJECTION + # Optional class labels indexed by output class; omit to report class_. + # labels: ["SAFE", "INJECTION"] + # Hard upper bound on the model file size in bytes (default 200 MB). + max_model_bytes: 209715200 +``` + +The detector loads the model at config-compile time (the slow path), so a missing or oversized model fails fast at startup rather than on the first request. `detect` then runs cheap tract inference per prompt and maps the top label and score onto the v2 vocabulary using the same cutoffs as the sidecar detector: at or above `threshold` is `injection`, `[0.3, threshold)` is `suspicious`, below `0.3` is `clean`. A non-injection top label is read as confidence the prompt is benign, so its score is inverted. Inference failures fail open (clean); operators who want fail-closed should use the sidecar detector. Because the model loads eagerly, this detector cannot appear in the `examples/` validation sweep; see `docs/local-inference.md` for the full deployment recipe. ## Registering a custom detector diff --git a/e2e/Cargo.toml b/e2e/Cargo.toml index 4e942c09..2acecc29 100644 --- a/e2e/Cargo.toml +++ b/e2e/Cargo.toml @@ -79,6 +79,15 @@ futures-util = "0.3" tonic = "0.12" prost = "0.13" +# WOR-1233: span-arrival e2e. A mock OTLP gRPC collector (built from the +# OTLP proto's TraceServiceServer) receives the proxy's AI spans, and the +# test asserts the gen_ai / OpenInference vocabulary arrives intact. The +# proto + opentelemetry crates are pinned to the same majors already in the +# lock via opentelemetry-otlp, so no duplicate compile. +sbproxy-ai.workspace = true +opentelemetry-proto = { version = "0.27", features = ["gen-tonic", "trace"] } +tokio-stream = { version = "0.1", features = ["net"] } + # Compression decoders for compression.rs e2e assertions. Pinned to the # same versions sbproxy-middleware ships so the test bytes round-trip # through the same encoder/decoder pair the proxy uses at runtime. diff --git a/e2e/tests/otlp_span_arrival_e2e.rs b/e2e/tests/otlp_span_arrival_e2e.rs new file mode 100644 index 00000000..10cf14f9 --- /dev/null +++ b/e2e/tests/otlp_span_arrival_e2e.rs @@ -0,0 +1,134 @@ +//! WOR-1233: span-arrival end to end. +//! +//! Stands up a mock OTLP/gRPC collector (the real OTLP `TraceService` +//! contract), points the proxy's telemetry pipeline at it, emits an AI +//! request span through `sbproxy_ai::tracing_spans`, and asserts the span +//! lands at the collector with the GenAI (`gen_ai.*`) and OpenInference +//! (`llm.*`) vocabulary intact, including the derived USD cost and the +//! error/status fields this PR added. +//! +//! This lives in the e2e crate (not the unit-test gate) because it installs +//! a process-global tracer provider + subscriber and waits for the async +//! batch exporter's scheduled export, which makes it slow (a few seconds) and +//! timing-sensitive. Run it directly: +//! cargo test -p sbproxy-e2e --test otlp_span_arrival_e2e + +use std::collections::HashSet; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use opentelemetry_proto::tonic::collector::trace::v1::{ + trace_service_server::{TraceService, TraceServiceServer}, + ExportTraceServiceRequest, ExportTraceServiceResponse, +}; +use tonic::{Request, Response, Status}; + +/// Mock OTLP/gRPC collector: stores every export request it receives. +#[derive(Clone, Default)] +struct MockCollector { + received: Arc>>, +} + +#[tonic::async_trait] +impl TraceService for MockCollector { + async fn export( + &self, + req: Request, + ) -> Result, Status> { + self.received.lock().unwrap().push(req.into_inner()); + Ok(Response::new(ExportTraceServiceResponse::default())) + } +} + +/// Collect every span name and attribute key seen across all export requests. +fn observed( + received: &Arc>>, +) -> (HashSet, HashSet) { + let mut names = HashSet::new(); + let mut attrs = HashSet::new(); + for req in received.lock().unwrap().iter() { + for rs in &req.resource_spans { + for ss in &rs.scope_spans { + for span in &ss.spans { + names.insert(span.name.clone()); + for kv in &span.attributes { + attrs.insert(kv.key.clone()); + } + } + } + } + } + (names, attrs) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn ai_span_lands_at_otlp_collector_with_genai_vocabulary() { + // 1. Start the mock collector on an ephemeral port. + let received: Arc>> = Arc::new(Mutex::new(Vec::new())); + let collector = MockCollector { + received: received.clone(), + }; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let stream = tokio_stream::wrappers::TcpListenerStream::new(listener); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(TraceServiceServer::new(collector)) + .serve_with_incoming(stream) + .await + .unwrap(); + }); + + // 2. Point the telemetry pipeline at the mock, sampling everything. + let cfg = sbproxy_observe::telemetry::TelemetryConfig { + enabled: true, + endpoint: Some(format!("http://{addr}")), + transport: sbproxy_observe::telemetry::OtlpTransport::Grpc, + sample_rate: Some(1.0), + ..Default::default() + }; + sbproxy_observe::telemetry::init_otlp_pipeline(&cfg).expect("init OTLP pipeline"); + + // 3. Emit an AI request span with the full vocabulary this PR stamps. + { + let span = sbproxy_ai::tracing_spans::ai_request_span("chat", "POST"); + let _entered = span.enter(); + span.record("gen_ai.system", "openai"); + span.record("gen_ai.request.model", "gpt-4o"); + sbproxy_ai::tracing_spans::record_token_usage(&span, 17, 42); + sbproxy_ai::tracing_spans::record_cost_usd(&span, 0.0012); + sbproxy_ai::tracing_spans::record_error( + &span, + sbproxy_ai::tracing_spans::error_type::RATE_LIMITED, + "rate limited", + ); + } + + // 4. Wait for the batch processor's scheduled export to reach the + // collector (default delay is 5s; an explicit shutdown flush does not + // reliably force it across the crate's global-provider boundary), then + // shut the pipeline down cleanly. + tokio::time::sleep(Duration::from_secs(6)).await; + sbproxy_observe::telemetry::shutdown_otlp_pipeline(); + + // 5. Assert the AI span arrived with the GenAI + OpenInference vocabulary. + let (names, attrs) = observed(&received); + assert!( + names.contains("ai.request"), + "the ai.request span must arrive at the OTLP collector; saw names {names:?}" + ); + for key in [ + "gen_ai.system", + "gen_ai.request.model", + "gen_ai.usage.input_tokens", + "gen_ai.usage.output_tokens", + "gen_ai.usage.cost", + "llm.token_count.total", + "error.type", + ] { + assert!( + attrs.contains(key), + "exported span is missing the {key:?} attribute; saw {attrs:?}" + ); + } +} diff --git a/examples/README.md b/examples/README.md index 1056b7d9..fc720e65 100644 --- a/examples/README.md +++ b/examples/README.md @@ -134,6 +134,7 @@ directory (with its `sb.yml` and README). Regenerated from `examples/` on 2026-0 | [robots-llms-txt](robots-llms-txt/) | Demonstrates the Wave 4 text-format policy-graph projections. | | [rsl-licensing](rsl-licensing/) | Demonstrates the Wave 4 policy-graph projections. A single | | [security-headers](security-headers/) | Demonstrates the `security_headers` policy. Every response from the `test.sbproxy.dev` upstream gains the standard browser hardening set: `Strict-Transport | +| [semantic-cache-local](semantic-cache-local/) | The AI gateway's embedding semantic cache, vectorizing prompts on-box via the local classifier sidecar (`source: sidecar`) instead of a paid provider embedding API. No per-call cost, no prompt egress, low loopback latency. | | [semantic-constraint](semantic-constraint/) | A natural-language policy enforced by an LLM-as-judge backend. The `semantic_constraint` policy renders a prompt template against the request envelope | | [service-discovery](service-discovery/) | Demonstrates `service_discovery` on a `proxy` action. Without service discovery, Pingora resolves the upstream hostname once when a connection is esta | | [sessions](sessions/) | The `session` block on `app.local` configures the encrypted cookie used to carry session state across requests. Cookie name is `sb_session`, max age i | diff --git a/examples/observability-stack/README.md b/examples/observability-stack/README.md index ec749c0c..3c8b834d 100644 --- a/examples/observability-stack/README.md +++ b/examples/observability-stack/README.md @@ -16,6 +16,11 @@ Then open: - Prometheus at http://localhost:9090 - Loki ready endpoint at http://localhost:3100/ready - Tempo (queried via Grafana, no first-class UI) +- Arize Phoenix at http://localhost:6006 (LLM-native trace view: SBproxy AI spans render as full generations with tokens, USD cost, latency, and error status) + +The collector applies a cost-aware tail-sampling policy (`tail_sampling` in `otel-collector/config.yaml`): errors and slow traces are always kept, the rest at a configurable base rate. Mirror `keep_over_budget_usd` / `keep_slower_than_secs` from the proxy's telemetry config into that policy. + +Langfuse is a second LLM-native backend. Its v3 self-host needs its own multi-service stack (Postgres, ClickHouse, Redis, object store), so it is not embedded here: run it from its own compose and uncomment the `otlphttp/langfuse` exporter in the collector config, pointing it at the Langfuse OTLP endpoint (`/api/public/otel`). Verify everything is healthy: diff --git a/examples/observability-stack/docker-compose.yml b/examples/observability-stack/docker-compose.yml index 47e8f670..f9b2736f 100644 --- a/examples/observability-stack/docker-compose.yml +++ b/examples/observability-stack/docker-compose.yml @@ -126,6 +126,29 @@ services: networks: - obsnet + # --- Arize Phoenix: LLM-native trace backend (WOR-1227) --- + # + # Renders SBproxy AI spans as full LLM trajectories: the proxy emits the + # OpenTelemetry GenAI (`gen_ai.*`) plus OpenInference (`llm.*`) vocabulary, + # which Phoenix reads natively (tokens, USD cost, latency, error status). + # The collector fans the traces pipeline here in addition to Tempo. UI on + # 6006; Phoenix exposes an OTLP receiver the collector pushes to. + phoenix: + image: arizephoenix/phoenix:7.0.0 + container_name: sbproxy_obs_phoenix + ports: + - "6006:6006" # Phoenix UI + OTLP HTTP ingress + networks: + - obsnet + + # --- Langfuse (OTLP target, run separately) --- + # + # Langfuse v3 self-host needs its own multi-service stack (Postgres + + # ClickHouse + Redis + object store), so it is not embedded here. Run + # Langfuse from its own compose and point the collector's + # `otlphttp/langfuse` exporter at its OTLP endpoint + # (`/api/public/otel`). See the README. + networks: obsnet: driver: bridge diff --git a/examples/observability-stack/otel-collector/config.yaml b/examples/observability-stack/otel-collector/config.yaml index 61cd3891..acff9586 100644 --- a/examples/observability-stack/otel-collector/config.yaml +++ b/examples/observability-stack/otel-collector/config.yaml @@ -27,6 +27,26 @@ processors: limit_percentage: 80 spike_limit_percentage: 25 + # WOR-1230: cost-aware tail sampling. Head sampling at the proxy keeps a + # ratio of roots; this tail policy additionally always-keeps the traces + # operators never want to lose, evaluated once the trace is complete: + # errors (span status ERROR), slow requests, and a probabilistic base + # rate for the rest. Mirror keep_over_budget_usd / keep_slower_than_secs + # from the proxy's telemetry config here. + tail_sampling: + decision_wait: 10s + num_traces: 50000 + policies: + - name: keep-errors + type: status_code + status_code: { status_codes: [ERROR] } + - name: keep-slow + type: latency + latency: { threshold_ms: 2000 } + - name: base-rate + type: probabilistic + probabilistic: { sampling_percentage: 100 } + exporters: # Traces -> Tempo via OTLP gRPC inside the bridge network. otlp/tempo: @@ -49,6 +69,20 @@ exporters: tls: insecure: true + # Traces -> Arize Phoenix (WOR-1227), the LLM-native backend that renders + # SBproxy AI spans as full generations. Phoenix accepts OTLP HTTP on 6006. + otlphttp/phoenix: + endpoint: http://phoenix:6006 + tls: + insecure: true + + # Traces -> Langfuse (run separately, see docker-compose.yml). Uncomment + # and set the endpoint + Basic auth header for your Langfuse instance. + # otlphttp/langfuse: + # endpoint: http://langfuse:3000/api/public/otel + # headers: + # Authorization: "Basic ${LANGFUSE_OTEL_BASIC_AUTH}" + # The collector publishes its own self-metrics on :8888 so Prometheus can # scrape it directly (see prometheus.yml). service: @@ -58,8 +92,11 @@ service: pipelines: traces: receivers: [otlp] - processors: [memory_limiter, batch] - exporters: [otlp/tempo] + # tail_sampling must run before batch so the keep/drop decision sees + # the whole trace. Fans out to Tempo (general traces) and Phoenix + # (LLM-native rendering of the AI spans). + processors: [memory_limiter, tail_sampling, batch] + exporters: [otlp/tempo, otlphttp/phoenix] metrics: receivers: [otlp] processors: [memory_limiter, batch] diff --git a/examples/semantic-cache-local/README.md b/examples/semantic-cache-local/README.md new file mode 100644 index 00000000..91a058bd --- /dev/null +++ b/examples/semantic-cache-local/README.md @@ -0,0 +1,30 @@ +# Local embedding semantic cache + +Serves near-duplicate AI prompts from cache, vectorizing prompts on-box via the +classifier sidecar instead of a paid provider embedding API. No per-call cost, +no prompt egress, low loopback latency. + +## Run + +Start the sidecar with an embedding model (supply your own ONNX model and +tokenizer; the OSS build ships no weights): + +```bash +cargo run -p sbproxy-classifier-sidecar -- \ + --listen 127.0.0.1:9440 \ + --embed-model all-MiniLM-L6-v2=/models/minilm/model.onnx:/models/minilm/tokenizer.json +``` + +Then the proxy: + +```bash +make run CONFIG=examples/semantic-cache-local/sb.yml +``` + +Send two near-duplicate prompts; the second is served from cache (`x-semcache: +HIT`) with no second upstream call. Watch `sbproxy_semantic_cache_results_total`, +`sbproxy_inference_requests_total{kind="embed"}`, and the savings counters +`sbproxy_ai_tokens_saved_total` / `sbproxy_ai_cost_saved_micros_total`. + +See [docs/local-inference.md](../../docs/local-inference.md) for the in-process +option and the full metric set. diff --git a/examples/semantic-cache-local/sb.yml b/examples/semantic-cache-local/sb.yml new file mode 100644 index 00000000..cb1e7541 --- /dev/null +++ b/examples/semantic-cache-local/sb.yml @@ -0,0 +1,45 @@ +# yaml-language-server: $schema=../../schemas/sb-config.schema.json +# Local embedding semantic cache (no provider embedding API). +# +# The AI gateway's semantic cache serves near-duplicate prompts from +# cache. By default it vectorizes prompts via a paid provider embedding +# API; this example points it at the local classifier sidecar instead +# (`source: sidecar`), so prompts are embedded on-box: no per-call cost, +# no prompt egress, low loopback latency. +# +# Run the OSS sidecar with an embedding model (supply your own ONNX +# model + tokenizer; OSS ships no weights): +# cargo run -p sbproxy-classifier-sidecar -- \ +# --listen 127.0.0.1:9440 \ +# --embed-model all-MiniLM-L6-v2=/models/minilm/model.onnx:/models/minilm/tokenizer.json +# +# Run the proxy: +# make run CONFIG=examples/semantic-cache-local/sb.yml +# +# See docs/local-inference.md for the full recipe and the metrics emitted. + +proxy: + http_bind_port: 8080 + +origins: + "ai.local": + action: + type: ai_proxy + providers: + - name: openai + api_key: ${OPENAI_API_KEY} + default_model: gpt-4o-mini + models: + - gpt-4o-mini + routing: round_robin + # Embedding semantic cache, vectorized locally via the sidecar. + semantic_cache: + enabled: true + threshold: 0.85 # cosine similarity for a near-duplicate hit + ttl_secs: 3600 + max_entries: 1024 + source: sidecar + sidecar: + endpoint: http://127.0.0.1:9440 + model: all-MiniLM-L6-v2 + timeout_ms: 500 diff --git a/schemas/sb-config.schema.json b/schemas/sb-config.schema.json index a13dd4b4..60925132 100644 --- a/schemas/sb-config.schema.json +++ b/schemas/sb-config.schema.json @@ -285,7 +285,7 @@ ] }, "http3": { - "description": "Optional HTTP/3 (QUIC) listener configuration.", + "description": "Optional HTTP/3 (QUIC) listener configuration.\n\nTemporarily inert: HTTP/3 is disabled until native QUIC support lands in the underlying proxy engine. The field still parses so existing configs keep loading, but enabling it only logs a warning and does not start a listener.", "default": null, "anyOf": [ { @@ -578,11 +578,11 @@ } }, "Http3Config": { - "description": "HTTP/3 (QUIC) configuration.", + "description": "HTTP/3 (QUIC) configuration.\n\nTemporarily inert: HTTP/3 is disabled until native QUIC support lands in the underlying proxy engine. These fields still parse, but the listener is not started; enabling it logs a warning instead.", "type": "object", "properties": { "enabled": { - "description": "Whether to enable the HTTP/3 (QUIC) listener.", + "description": "Whether to enable the HTTP/3 (QUIC) listener.\n\nCurrently ignored: HTTP/3 is temporarily disabled (see the struct docs). Setting this to `true` logs a warning and starts no listener.", "default": false, "type": "boolean" },