Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aigateway_regression_live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ func TestLiveAIGatewayRegressionModels(t *testing.T) {
model: "magistral-medium-latest",
},
{
name: "cohere command-r uses compatibility chat completions",
name: "cohere strips provider-prefixed model",
provider: "cohere",
baseURL: "https://api.cohere.com/compatibility",
envKey: "GATEWAY_COHERE_API_KEY",
model: "command-r-08-2024",
model: "cohere/command-a-plus-05-2026",
},
}

Expand Down
1 change: 1 addition & 0 deletions autorouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ var knownProviderPrefixes = map[string]bool{
"azure": true,
"mistral": true,
"deepseek": true,
"cohere": true,
}

func (a *AutoRouter) validateModelSurface(providerName string, model string, apiType APIType) error {
Expand Down
64 changes: 64 additions & 0 deletions autorouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,69 @@ func TestAutoRouter_DeepseekV4StripsProviderPrefixBeforeForwarding(t *testing.T)
}
}

func TestAutoRouter_CohereStripsProviderPrefixBeforeForwarding(t *testing.T) {
var upstreamModel string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body struct {
Model string `json:"model"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode upstream request: %v", err)
}
upstreamModel = body.Model
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"id":"chatcmpl-cohere","model":"command-a-plus-05-2026","choices":[]}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "cohere",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(data, &req); err != nil {
return BodyMetadata{}, nil, err
}
return BodyMetadata{Model: req.Model}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return ParseURL(upstream.URL + "/v1/chat/completions")
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "chatcmpl-cohere", Model: "command-a-plus-05-2026"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string {
return "cohere"
})),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{
"model": "cohere/command-a-plus-05-2026",
"messages": [{"role":"user","content":"Reply with OK and nothing else."}]
}`)))
resp, _, err := router.Forward(context.Background(), req)
if err != nil {
t.Fatalf("Forward() error = %v", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", resp.StatusCode)
}
if upstreamModel != "command-a-plus-05-2026" {
t.Fatalf("upstream model = %q, want command-a-plus-05-2026", upstreamModel)
}
}

func TestAutoRouter_DeepseekReasoningOffDisablesThinking(t *testing.T) {
var upstreamBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -976,6 +1039,7 @@ func TestStripProviderPrefix(t *testing.T) {
{"azure prefix", "azure/gpt-4-deployment", "gpt-4-deployment", true},
{"mistral prefix", "mistral/codestral-2508", "codestral-2508", true},
{"deepseek prefix", "deepseek/deepseek-v4-pro", "deepseek-v4-pro", true},
{"cohere prefix", "cohere/command-a-plus-05-2026", "command-a-plus-05-2026", true},
{"multiple slashes preserved", "openai/gpt-4/turbo", "gpt-4/turbo", true},
{"empty string", "", "", false},
{"slash only - not a provider", "/", "/", false},
Expand Down
8 changes: 4 additions & 4 deletions detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ func DetectProviderFromModel(model string) string {
return ""
}

// Check for explicit provider prefix (e.g., "openai/gpt-4", "anthropic/claude-3-opus")
// Check for explicit provider prefix (e.g., "openai/gpt-4", "anthropic/claude-3-opus").
// knownProviderPrefixes is the single source of truth shared with stripProviderPrefix,
// so a prefix the router strips is always one it can also route by name.
if idx := strings.Index(model, "/"); idx >= 0 {
prefix := model[:idx]
switch prefix {
case "openai", "anthropic", "googleai", "groq", "fireworks", "xai", "perplexity", "bedrock", "azure", "mistral":
if prefix := model[:idx]; knownProviderPrefixes[prefix] {
return prefix
}
}
Expand Down
15 changes: 15 additions & 0 deletions detection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ func TestDetectProviderFromModel(t *testing.T) {
{"bedrock/claude prefix", "bedrock/anthropic.claude-3", "bedrock"},
{"azure/gpt-4 prefix", "azure/gpt-4", "azure"},
{"mistral/codestral prefix", "mistral/codestral-2508", "mistral"},
{"deepseek/deepseek-chat prefix", "deepseek/deepseek-chat", "deepseek"},
{"cohere/command prefix", "cohere/command-a-plus-05-2026", "cohere"},
{"unknown/ prefix returns unknown", "unknown/model", ""},
{"single slash only", "/", ""},
}
Expand All @@ -153,6 +155,19 @@ func TestDetectProviderFromModel(t *testing.T) {
}
}

// Every prefix the router strips (knownProviderPrefixes) must also be routable
// by name via DetectProviderFromModel. Otherwise a direct consumer that sends
// "cohere/<model>" with no X-Provider header gets ErrNoProvider. This guards
// against the two lists drifting apart again (deepseek/cohere previously did).
func TestDetectProviderFromModel_KnownPrefixesAreDetected(t *testing.T) {
for prefix := range knownProviderPrefixes {
model := prefix + "/some-model"
if got := DetectProviderFromModel(model); got != prefix {
t.Errorf("DetectProviderFromModel(%q) = %q, want %q", model, got, prefix)
}
}
}

func TestDefaultProviderDetector(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading