Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ require (
go.opentelemetry.io/otel/trace v1.40.0
)

require github.com/aws/aws-sdk-go-v2 v1.30.3

require (
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
Expand Down
16 changes: 15 additions & 1 deletion intercept/client_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,19 @@ var authHeaders = []string{
"X-Api-Key",
}

// proxyHeaders describe the path the inbound request took to reach
// aibridge. On bridge routes aibridge acts as a client, not a proxy,
// so these headers are not meaningful on the outbound request.
var proxyHeaders = []string{
"X-Forwarded-For",
"X-Forwarded-Host",
"X-Forwarded-Proto",
"X-Forwarded-Port",
"Forwarded",
}

// PrepareClientHeaders returns a copy of the client headers with hop-by-hop,
// transport, and auth headers removed.
// transport, auth, and proxy headers removed.
func PrepareClientHeaders(clientHeaders http.Header) http.Header {
prepared := clientHeaders.Clone()
for _, h := range hopByHopHeaders {
Expand All @@ -49,6 +60,9 @@ func PrepareClientHeaders(clientHeaders http.Header) http.Header {
for _, h := range authHeaders {
prepared.Del(h)
}
for _, h := range proxyHeaders {
prepared.Del(h)
}
return prepared
}

Expand Down
22 changes: 22 additions & 0 deletions intercept/client_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,28 @@ func TestPrepareClientHeaders(t *testing.T) {
assert.Equal(t, "preserved", result.Get("X-Custom"))
})

t.Run("proxy headers are removed", func(t *testing.T) {
t.Parallel()

input := http.Header{
"X-Forwarded-For": {"203.0.113.50"},
"X-Forwarded-Host": {"app.example.com"},
"X-Forwarded-Proto": {"https"},
"X-Forwarded-Port": {"443"},
"Forwarded": {"for=203.0.113.50;proto=https"},
"X-Custom": {"preserved"},
}

result := PrepareClientHeaders(input)

assert.Empty(t, result.Get("X-Forwarded-For"))
assert.Empty(t, result.Get("X-Forwarded-Host"))
assert.Empty(t, result.Get("X-Forwarded-Proto"))
assert.Empty(t, result.Get("X-Forwarded-Port"))
assert.Empty(t, result.Get("Forwarded"))
assert.Equal(t, "preserved", result.Get("X-Custom"))
})

t.Run("multi-value headers are preserved", func(t *testing.T) {
t.Parallel()

Expand Down
165 changes: 165 additions & 0 deletions internal/integrationtest/bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@ package integrationtest
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/aws/aws-sdk-go-v2/aws"
v4signer "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/coder/aibridge"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/fixtures"
Expand Down Expand Up @@ -455,6 +461,151 @@ func TestAWSBedrockIntegration(t *testing.T) {
}
}
})
// SigV4 signs all headers on the outbound Bedrock request. If any header
// is modified in transit (e.g. an egress proxy appending to X-Forwarded-For),
// the signature becomes invalid and AWS rejects the request with:
// 403: "The request signature we calculated does not match the signature
// you provided."
t.Run("SigV4 signed headers", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool)

proxyHeaders := http.Header{
"X-Forwarded-For": {"203.0.113.50, 10.0.0.1"},
"X-Forwarded-Host": {"app.example.com"},
"X-Forwarded-Proto": {"https"},
}

// Credentials used for both the Bedrock config and the mock's
// signature re-verification.
accessKey := "test-access-key"
secretKey := "test-secret-key"
region := "us-west-2"

var signatureValid atomic.Bool

// Mock Bedrock endpoint (simulates AWS). The OnRequest callback
// re-signs the received request using only the declared
// SignedHeaders and stores whether the signatures match.
fixResp := newFixtureResponse(fix)
fixResp.OnRequest = func(r *http.Request, body []byte) {
authHeader := r.Header.Get("Authorization")
// Passthrough requests have no SigV4 auth; skip verification.
if !strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256") {
return
}
originalSig := extractSigV4Field(authHeader, "Signature=")

// Rebuild the request the way AWS would: keep only
// the declared SignedHeaders.
signedHeaders := strings.Split(extractSigV4Field(authHeader, "SignedHeaders="), ";")
verifyReq := r.Clone(r.Context())
verifyReq.Header.Del("Authorization")
for h := range verifyReq.Header {
if !slices.Contains(signedHeaders, strings.ToLower(h)) {
verifyReq.Header.Del(h)
}
}
// Restore ContentLength: Go's HTTP server parses it
// from the request but does not put it in r.Header;
// the SigV4 signer reads the struct field.
verifyReq.ContentLength = int64(len(body))

// Re-sign with the same credentials, body hash, and
// timestamp. SigV4 derives the signature from all three,
// so any difference means a header was altered in transit.
signingTime, err := time.Parse("20060102T150405Z", verifyReq.Header.Get("X-Amz-Date"))
require.NoError(t, err)
bodyHash := sha256.Sum256(body)
err = v4signer.NewSigner().SignHTTP(
ctx,
aws.Credentials{AccessKeyID: accessKey, SecretAccessKey: secretKey},
verifyReq, hex.EncodeToString(bodyHash[:]),
"bedrock", region, signingTime,
)
require.NoError(t, err)

recomputedSig := extractSigV4Field(verifyReq.Header.Get("Authorization"), "Signature=")
signatureValid.Store(originalSig == recomputedSig)
}
mockBedrock := newMockUpstream(t, ctx, fixResp)
mockBedrock.AllowOverflow = true

// Simulated egress proxy: modifies X-Forwarded-For and
// forwards to mockBedrock, preserving the original Host.
mockEgressProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
r.Header.Set("X-Forwarded-For", xff+", 10.255.0.1")
}

proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, mockBedrock.URL+r.URL.Path, r.Body)
require.NoError(t, err)
proxyReq.Header = r.Header.Clone()
proxyReq.Host = r.Host // preserve signed Host

resp, err := http.DefaultClient.Do(proxyReq)
require.NoError(t, err)
defer resp.Body.Close()

for k, vs := range resp.Header {
for _, v := range vs {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
_, _ = io.Copy(w, resp.Body)
}))
t.Cleanup(mockEgressProxy.Close)

bCfg := bedrockCfg(mockEgressProxy.URL)
bCfg.AccessKey = accessKey
bCfg.AccessKeySecret = secretKey
bCfg.Region = region

bridgeServer := newBridgeTestServer(t, ctx, mockEgressProxy.URL,
withCustomProvider(provider.NewAnthropic(anthropicCfg(mockEgressProxy.URL, apiKey), bCfg)),
)

// Sends a bridge request through a mock egress proxy that
// mutates X-Forwarded-For, then verifies the SigV4 signature
// still matches at the mock Bedrock endpoint.
t.Run("bridge SigV4 signature valid", func(t *testing.T) {
reqBody, err := sjson.SetBytes(fix.Request(), "stream", false)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, proxyHeaders)
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)

assert.True(t, signatureValid.Load(),
"SigV4 signature mismatch: a header modified in transit "+
"was included in the signed-headers set")
})

// Passthrough routes use httputil.ReverseProxy, which forwards
// the request as-is without SigV4 signing, so proxy headers
// are safe to include. ReverseProxy sets its own X-Forwarded-*
// headers via SetXForwarded. This verifies they arrive upstream.
t.Run("passthrough proxy sets own forwarded headers", func(t *testing.T) {
resp := bridgeServer.makeRequest(t, http.MethodGet, "/anthropic/v1/models", nil, proxyHeaders)
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)

received := mockBedrock.receivedRequests()
require.NotEmpty(t, received)
last := received[len(received)-1]

assert.NotEmpty(t, last.Header.Get("X-Forwarded-For"),
"passthrough should set X-Forwarded-For via SetXForwarded")
assert.NotEmpty(t, last.Header.Get("X-Forwarded-Host"),
"passthrough should set X-Forwarded-Host via SetXForwarded")
assert.NotEmpty(t, last.Header.Get("X-Forwarded-Proto"),
"passthrough should set X-Forwarded-Proto via SetXForwarded")
})
})
}

func TestOpenAIChatCompletions(t *testing.T) {
Expand Down Expand Up @@ -2090,3 +2241,17 @@ func TestActorHeaders(t *testing.T) {
}
}
}

// extractSigV4Field extracts a named field from an AWS SigV4
// Authorization header value.
func extractSigV4Field(authHeader, prefix string) string {
idx := strings.Index(authHeader, prefix)
if idx == -1 {
return ""
}
val := authHeader[idx+len(prefix):]
if end := strings.IndexByte(val, ','); end != -1 {
val = val[:end]
}
return strings.TrimSpace(val)
}
Loading