diff --git a/AGENTS.md b/AGENTS.md index b7a48b1..efa389f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,7 +1,7 @@ # Memory Context -# [otelcontext] recent context, 2026-04-28 1:14am UTC +# [otelcontext] recent context, 2026-04-28 6:43am UTC No previous sessions found. \ No newline at end of file diff --git a/internal/api/tenant_middleware.go b/internal/api/tenant_middleware.go index e6f45e3..dd25c39 100644 --- a/internal/api/tenant_middleware.go +++ b/internal/api/tenant_middleware.go @@ -8,6 +8,7 @@ import ( "github.com/RandomCodeSpace/otelcontext/internal/storage" ) + // TenantHeader is the canonical HTTP header carrying the tenant ID on // read-side (query) requests. Ingest paths resolve tenant separately via gRPC // metadata / OTLP resource attributes and do not go through this middleware. @@ -34,7 +35,11 @@ func TenantMiddleware(cfg *config.Config) func(http.Handler) http.Handler { next.ServeHTTP(w, r) return } - tenant := strings.TrimSpace(r.Header.Get(TenantHeader)) + // SanitizeTenantID returns "" for empty / over-length / control-char + // values so they fall through to the configured default — see + // storage.SanitizeTenantID. Hostile or misconfigured clients cannot + // inject newlines into structured logs or overflow VARCHAR(64). 
+ tenant := storage.SanitizeTenantID(r.Header.Get(TenantHeader)) if tenant == "" { tenant = defaultTenant } diff --git a/internal/api/tenant_middleware_sanitize_test.go b/internal/api/tenant_middleware_sanitize_test.go new file mode 100644 index 0000000..2463ad2 --- /dev/null +++ b/internal/api/tenant_middleware_sanitize_test.go @@ -0,0 +1,43 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/RandomCodeSpace/otelcontext/internal/storage" +) + +// TestTenantMiddleware_SanitizesHeader verifies that an X-Tenant-ID header +// containing control characters or excessive length is rejected back to the +// configured default — preventing log injection on slog structured fields +// and silent VARCHAR truncation at the GORM layer. +func TestTenantMiddleware_SanitizesHeader(t *testing.T) { + mw := TenantMiddleware(nil) // nil cfg → DefaultTenantID fallback + cases := []struct { + header string + wantTenant string + }{ + {"acme", "acme"}, + {"foo\nbar", storage.DefaultTenantID}, // log-injection attempt + {strings.Repeat("x", 200), storage.DefaultTenantID}, // over-length + {"", storage.DefaultTenantID}, // empty + {" ", storage.DefaultTenantID}, // whitespace-only + {"foo\x00bar", storage.DefaultTenantID}, // NUL byte + } + for _, tc := range cases { + t.Run(tc.header, func(t *testing.T) { + var got string + h := mw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + got = storage.TenantFromContext(r.Context()) + })) + r := httptest.NewRequest(http.MethodGet, "/api/foo", nil) + r.Header.Set(TenantHeader, tc.header) + h.ServeHTTP(httptest.NewRecorder(), r) + if got != tc.wantTenant { + t.Errorf("header=%q -> tenant=%q, want %q", tc.header, got, tc.wantTenant) + } + }) + } +} diff --git a/internal/config/selfinstr_guard.go b/internal/config/selfinstr_guard.go new file mode 100644 index 0000000..50035f6 --- /dev/null +++ b/internal/config/selfinstr_guard.go @@ -0,0 +1,100 @@ +package config + +import ( + 
"log/slog" + "net" + "strings" +) + +// SelfServiceName is the OTel service.name attribute the binary attaches to +// its own self-instrumentation spans. main.initTracerProvider passes this +// constant to semconv.ServiceName, so it is the single source of truth. +const SelfServiceName = "otelcontext" + +// GuardSelfInstrumentation prevents an amplification loop when +// OTEL_EXPORTER_OTLP_ENDPOINT points at the binary's own gRPC port. Without +// this, every span the OTel SDK emits would re-enter Export, generate more +// spans (one per Export call), and re-enter again — unbounded fan-out. +// +// Strategy: when the host of the configured endpoint passes isLoopbackHost +// (a literal check, no DNS lookup), the own service name is prepended to +// IngestExcludedServices so the ingest filter drops self-emitted batches. +// Entries the operator already configured are kept as-is — the guard only +// ADDS, never removes — and a Warn log records the auto-exclusion. +// +// No-op when the receiver is nil, self-instrumentation is disabled (empty +// endpoint), the endpoint is non-loopback (a separate collector, the +// operator's responsibility), or the own service is already in the list. +func (c *Config) GuardSelfInstrumentation() { + if c == nil || c.OTelExporterEndpoint == "" { + return + } + host := hostFromEndpoint(c.OTelExporterEndpoint) + if !isLoopbackHost(host) { + return + } + if hasService(c.IngestExcludedServices, SelfServiceName) { + return + } + if c.IngestExcludedServices == "" { + c.IngestExcludedServices = SelfServiceName + } else { + c.IngestExcludedServices = SelfServiceName + "," + c.IngestExcludedServices + } + slog.Warn("self-instrumentation guard: auto-excluded own service from ingest to break feedback loop", + "endpoint", c.OTelExporterEndpoint, + "self_service", SelfServiceName, + "ingest_excluded_services", c.IngestExcludedServices, + ) +} + +// hostFromEndpoint extracts the host portion of an OTLP endpoint string. +// Tolerates "host", "host:port", and "scheme://host:port" forms — the OTel +// SDK accepts all three. Returns the lowercase host, or "" when the endpoint +// is blank; input that net.SplitHostPort rejects is treated as a bare host. 
func hostFromEndpoint(endpoint string) string {
	endpoint = strings.TrimSpace(endpoint)
	if endpoint == "" {
		return ""
	}
	// Strip the scheme if present (e.g. "http://localhost:4317").
	if _, rest, ok := strings.Cut(endpoint, "://"); ok {
		endpoint = rest
	}
	// Drop any path component (e.g. "/v1/traces"); Cut returns the input
	// unchanged when there is no slash.
	endpoint, _, _ = strings.Cut(endpoint, "/")
	host, _, err := net.SplitHostPort(endpoint)
	if err != nil {
		// No port (or not splittable) — treat the whole thing as the host.
		host = endpoint
	}
	// Bracketed IPv6 literals without a port ("[::1]") keep their brackets
	// after the fallback above, so strip them before lowercasing.
	return strings.ToLower(strings.Trim(host, "[]"))
}

// isLoopbackHost reports whether host refers to the local machine: the
// well-known name "localhost", a literal loopback IP (127.0.0.0/8, ::1), or a
// literal unspecified IP (0.0.0.0, ::). Unspecified addresses are included
// because Go's dialer assumes the local system for them, so an exporter
// endpoint such as "0.0.0.0:4317" loops back into this process exactly like
// "localhost:4317" does. The empty host is treated as local because OTel SDKs
// fall back to "localhost" when the endpoint is bare.
//
// The check is purely lexical — no DNS lookup is performed, so other
// hostnames that happen to resolve to a loopback address return false.
func isLoopbackHost(host string) bool {
	if host == "" || host == "localhost" {
		return true
	}
	ip := net.ParseIP(host)
	if ip == nil {
		return false
	}
	return ip.IsLoopback() || ip.IsUnspecified()
}

// hasService reports whether the comma-separated list contains the given
// service name. Whitespace around list entries is tolerated so the same
// helper can validate operator-supplied lists.
+func hasService(list, service string) bool { + if list == "" || service == "" { + return false + } + for _, s := range strings.Split(list, ",") { + if strings.TrimSpace(s) == service { + return true + } + } + return false +} diff --git a/internal/config/selfinstr_guard_test.go b/internal/config/selfinstr_guard_test.go new file mode 100644 index 0000000..d28d777 --- /dev/null +++ b/internal/config/selfinstr_guard_test.go @@ -0,0 +1,119 @@ +package config + +import "testing" + +func TestHostFromEndpoint(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"localhost:4317", "localhost"}, + {"127.0.0.1:4317", "127.0.0.1"}, + {"[::1]:4317", "::1"}, + {"http://localhost:4317", "localhost"}, + {"https://collector.example.com:4317", "collector.example.com"}, + {"otelcollector:4317", "otelcollector"}, + {"localhost", "localhost"}, + {" ", ""}, + {"", ""}, + {"http://localhost:4317/v1/traces", "localhost"}, + } + for _, tc := range cases { + if got := hostFromEndpoint(tc.in); got != tc.want { + t.Errorf("hostFromEndpoint(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestIsLoopbackHost(t *testing.T) { + loopback := []string{"", "localhost", "127.0.0.1", "127.1.2.3", "::1"} + for _, h := range loopback { + if !isLoopbackHost(h) { + t.Errorf("isLoopbackHost(%q) = false, want true", h) + } + } + notLoopback := []string{"otelcollector", "10.0.0.1", "collector.example.com", "192.168.1.1"} + for _, h := range notLoopback { + if isLoopbackHost(h) { + t.Errorf("isLoopbackHost(%q) = true, want false", h) + } + } +} + +func TestHasService(t *testing.T) { + t.Parallel() + cases := []struct { + list, service string + want bool + }{ + {"", "otelcontext", false}, + {"otelcontext", "otelcontext", true}, + {"a,b,otelcontext,c", "otelcontext", true}, + {"a, otelcontext , c", "otelcontext", true}, + {"a,b,c", "otelcontext", false}, + {"otelcontextual", "otelcontext", false}, + } + for _, tc := range cases { + if got := hasService(tc.list, tc.service); got != 
tc.want { + t.Errorf("hasService(%q, %q) = %v, want %v", tc.list, tc.service, got, tc.want) + } + } +} + +func TestGuardSelfInstrumentation(t *testing.T) { + t.Run("NoOpWhenEndpointEmpty", func(t *testing.T) { + c := &Config{IngestExcludedServices: "foo"} + c.GuardSelfInstrumentation() + if c.IngestExcludedServices != "foo" { + t.Fatalf("modified excluded list when endpoint empty: %q", c.IngestExcludedServices) + } + }) + + t.Run("AutoAddsWhenLoopback", func(t *testing.T) { + c := &Config{OTelExporterEndpoint: "localhost:4317"} + c.GuardSelfInstrumentation() + if !hasService(c.IngestExcludedServices, SelfServiceName) { + t.Fatalf("self service not auto-added: %q", c.IngestExcludedServices) + } + }) + + t.Run("PrependsToExistingList", func(t *testing.T) { + c := &Config{ + OTelExporterEndpoint: "127.0.0.1:4317", + IngestExcludedServices: "noisy-svc", + } + c.GuardSelfInstrumentation() + want := SelfServiceName + ",noisy-svc" + if c.IngestExcludedServices != want { + t.Fatalf("got %q, want %q", c.IngestExcludedServices, want) + } + }) + + t.Run("IdempotentWhenAlreadyExcluded", func(t *testing.T) { + c := &Config{ + OTelExporterEndpoint: "[::1]:4317", + IngestExcludedServices: "a," + SelfServiceName + ",b", + } + before := c.IngestExcludedServices + c.GuardSelfInstrumentation() + if c.IngestExcludedServices != before { + t.Fatalf("guard mutated already-excluded list: %q -> %q", before, c.IngestExcludedServices) + } + }) + + t.Run("NoOpForRemoteEndpoint", func(t *testing.T) { + c := &Config{ + OTelExporterEndpoint: "collector.example.com:4317", + IngestExcludedServices: "foo", + } + c.GuardSelfInstrumentation() + if hasService(c.IngestExcludedServices, SelfServiceName) { + t.Fatalf("guard fired on remote endpoint: %q", c.IngestExcludedServices) + } + }) + + t.Run("NilSafe", func(t *testing.T) { + var c *Config + c.GuardSelfInstrumentation() + }) +} diff --git a/internal/ingest/otlp.go b/internal/ingest/otlp.go index 6048c22..0bfac13 100644 --- a/internal/ingest/otlp.go 
+++ b/internal/ingest/otlp.go @@ -49,8 +49,13 @@ func tenantFromContext(ctx context.Context) string { return storage.TenantFromContext(ctx) } if md, ok := metadata.FromIncomingContext(ctx); ok { - if vals := md.Get(tenantHeader); len(vals) > 0 && vals[0] != "" { - return vals[0] + if vals := md.Get(tenantHeader); len(vals) > 0 { + // Reject empty, over-length, or control-char values via shared + // sanitizer so HTTP and gRPC paths apply identical input-safety + // rules. Empty return falls through to the configured default. + if t := storage.SanitizeTenantID(vals[0]); t != "" { + return t + } } } return "" @@ -84,11 +89,13 @@ func resolveTenant(ctx context.Context, resourceAttrs []*commonpb.KeyValue, fall // tenantFromResource looks for an OTLP resource attribute "tenant.id". // Only consulted when cfg.TrustResourceTenant=true (off by default) — -// see resolveTenant. +// see resolveTenant. The value is run through SanitizeTenantID so a +// compromised SDK cannot smuggle control characters or oversized strings +// even on the trusted-resource path. func tenantFromResource(attrs []*commonpb.KeyValue) string { for _, kv := range attrs { if kv.Key == "tenant.id" { - return kv.Value.GetStringValue() + return storage.SanitizeTenantID(kv.Value.GetStringValue()) } } return "" diff --git a/internal/storage/tenant_ctx.go b/internal/storage/tenant_ctx.go index 1347715..ef7834e 100644 --- a/internal/storage/tenant_ctx.go +++ b/internal/storage/tenant_ctx.go @@ -1,6 +1,46 @@ package storage -import "context" +import ( + "context" + "strings" + "unicode" +) + +// MaxTenantIDLength caps the length of an accepted tenant ID. Tenant IDs are +// stored in a VARCHAR(64) column on every domain row plus propagate into +// structured logs and Prometheus labels. The cap is a defense in depth against +// silent VARCHAR truncation at insert time and unbounded label cardinality +// from a hostile or misconfigured client. 
const MaxTenantIDLength = 64 // must not exceed the VARCHAR(64) tenant column width

// SanitizeTenantID validates and normalizes a tenant ID supplied by an HTTP
// header, gRPC metadata key, or OTLP resource attribute. It returns the empty
// string for any value the caller should reject (and substitute with their
// configured default), so the rejection contract is uniform across transports.
//
// Rejection criteria:
//   - empty after TrimSpace
//   - length in bytes exceeds MaxTenantIDLength after trim
//   - contains a Unicode control character (\n, \r, \t, NUL, escape codes)
//
// The length is measured in bytes, which is never smaller than the character
// count, so a value that passes cannot be truncated by a 64-character column.
//
// On the happy path it returns the trimmed value verbatim — no case folding,
// no allowlist, since legitimate tenant IDs may be UUIDs, slugs, or
// organisation names in non-ASCII scripts.
func SanitizeTenantID(s string) string {
	s = strings.TrimSpace(s)
	if s == "" || len(s) > MaxTenantIDLength {
		return ""
	}
	// Any control rune makes the whole value unusable; it is rejected rather
	// than stripped so a mangled ID never maps onto a different tenant.
	if strings.IndexFunc(s, unicode.IsControl) >= 0 {
		return ""
	}
	return s
}

// tenantCtxKey is the private context key used to carry the resolved tenant ID
// through an HTTP (or gRPC) request down into the repository layer.
diff --git a/internal/storage/tenant_sanitize_test.go b/internal/storage/tenant_sanitize_test.go new file mode 100644 index 0000000..5d2d509 --- /dev/null +++ b/internal/storage/tenant_sanitize_test.go @@ -0,0 +1,77 @@ +package storage + +import ( + "strings" + "testing" +) + +func TestSanitizeTenantID(t *testing.T) { + t.Parallel() + + t.Run("AcceptsValid", func(t *testing.T) { + cases := []string{ + "acme", + "team-alpha", + "customer_42", + "a.b.c", + "550e8400-e29b-41d4-a716-446655440000", // UUID + "münchen", // non-ASCII letters allowed + strings.Repeat("a", MaxTenantIDLength), // exactly at the cap + } + for _, in := range cases { + if got := SanitizeTenantID(in); got != in { + t.Errorf("SanitizeTenantID(%q) = %q, want %q (unchanged)", in, got, in) + } + } + }) + + t.Run("Trims", func(t *testing.T) { + if got := SanitizeTenantID(" acme "); got != "acme" { + t.Errorf("got %q, want %q", got, "acme") + } + }) + + t.Run("RejectsEmpty", func(t *testing.T) { + for _, in := range []string{"", " ", "\t\t"} { + if got := SanitizeTenantID(in); got != "" { + t.Errorf("SanitizeTenantID(%q) = %q, want empty", in, got) + } + } + }) + + t.Run("RejectsControlChars", func(t *testing.T) { + // Newline injection — would corrupt slog/loki structured output if accepted. + // Carriage return — same. NUL — silently truncates strings on some toolchains. + // Tab is also a control character per unicode.IsControl. 
+ cases := []string{ + "foo\nbar", + "foo\rbar", + "foo\x00bar", + "foo\x1bbar", // ESC + "foo\tbar", + } + for _, in := range cases { + if got := SanitizeTenantID(in); got != "" { + t.Errorf("SanitizeTenantID(%q) = %q, want empty (control char rejected)", in, got) + } + } + }) + + t.Run("RejectsOverLength", func(t *testing.T) { + long := strings.Repeat("x", MaxTenantIDLength+1) + if got := SanitizeTenantID(long); got != "" { + t.Errorf("SanitizeTenantID(over-length) = %q, want empty", got) + } + }) + + t.Run("TrimThenLengthCheck", func(t *testing.T) { + // Whitespace doesn't bypass the length cap — the trimmed length is what matters. + // SanitizeTenantID trims FIRST and then measures len(s): leading whitespace + // followed by MaxTenantIDLength+1 "x" bytes trims to MaxTenantIDLength+1 bytes, + // which is still one over the cap, so the value must be rejected. + s := " " + strings.Repeat("x", MaxTenantIDLength+1) + if got := SanitizeTenantID(s); got != "" { + t.Errorf("SanitizeTenantID(over-length-after-trim) = %q, want empty", got) + } + }) +} diff --git a/main.go b/main.go index 92e0f6c..b14180e 100644 --- a/main.go +++ b/main.go @@ -135,6 +135,9 @@ func main() { if err := cfg.Validate(); err != nil { fatal("invalid configuration", err) } + // Auto-exclude own service when self-instrumentation points to a loopback + // address (otherwise every span emitted re-enters Export and amplifies). + cfg.GuardSelfInstrumentation() if err := cfg.ValidateDBForEnv(); err != nil { fatal("DB/Env validation", err) } @@ -961,7 +964,7 @@ func initTracerProvider(endpoint string) (*sdktrace.TracerProvider, error) { res, err := sdkresource.New(ctx, sdkresource.WithAttributes( - semconv.ServiceName("otelcontext"), + semconv.ServiceName(config.SelfServiceName), semconv.ServiceVersion(Version), ), )