From 8d8a851c0166112d97ff1e96a619502edf806e9f Mon Sep 17 00:00:00 2001 From: Parteek Singh Date: Mon, 15 Jun 2026 09:31:19 -0700 Subject: [PATCH 1/2] Fix Cohere provider-prefixed routing --- aigateway_regression_live_test.go | 4 +- autorouter.go | 1 + autorouter_test.go | 64 +++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/aigateway_regression_live_test.go b/aigateway_regression_live_test.go index a7af4c4..25e4df4 100644 --- a/aigateway_regression_live_test.go +++ b/aigateway_regression_live_test.go @@ -44,11 +44,11 @@ func TestLiveAIGatewayRegressionModels(t *testing.T) { model: "magistral-medium-latest", }, { - name: "cohere command-r uses compatibility chat completions", + name: "cohere strips provider-prefixed model", provider: "cohere", baseURL: "https://api.cohere.com/compatibility", envKey: "GATEWAY_COHERE_API_KEY", - model: "command-r-08-2024", + model: "cohere/command-a-plus-05-2026", }, } diff --git a/autorouter.go b/autorouter.go index c528ac2..d6f7e93 100644 --- a/autorouter.go +++ b/autorouter.go @@ -777,6 +777,7 @@ var knownProviderPrefixes = map[string]bool{ "azure": true, "mistral": true, "deepseek": true, + "cohere": true, } func (a *AutoRouter) validateModelSurface(providerName string, model string, apiType APIType) error { diff --git a/autorouter_test.go b/autorouter_test.go index 5045556..bcb7cff 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -347,6 +347,69 @@ func TestAutoRouter_DeepseekV4StripsProviderPrefixBeforeForwarding(t *testing.T) } } +func TestAutoRouter_CohereStripsProviderPrefixBeforeForwarding(t *testing.T) { + var upstreamModel string + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body struct { + Model string `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode upstream request: %v", err) + } + upstreamModel = body.Model + w.Header().Set("Content-Type", 
"application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-cohere","model":"command-a-plus-05-2026","choices":[]}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "cohere", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var req struct { + Model string `json:"model"` + } + if err := json.Unmarshal(data, &req); err != nil { + return BodyMetadata{}, nil, err + } + return BodyMetadata{Model: req.Model}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL + "/v1/chat/completions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "chatcmpl-cohere", Model: "command-a-plus-05-2026"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "cohere" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{ + "model": "cohere/command-a-plus-05-2026", + "messages": [{"role":"user","content":"Reply with OK and nothing else."}] + }`))) + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", resp.StatusCode) + } + if upstreamModel != "command-a-plus-05-2026" { + t.Fatalf("upstream model = %q, want command-a-plus-05-2026", upstreamModel) + } +} + func TestAutoRouter_DeepseekReasoningOffDisablesThinking(t *testing.T) { var upstreamBody map[string]any upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -976,6 +1039,7 @@ 
func TestStripProviderPrefix(t *testing.T) { {"azure prefix", "azure/gpt-4-deployment", "gpt-4-deployment", true}, {"mistral prefix", "mistral/codestral-2508", "codestral-2508", true}, {"deepseek prefix", "deepseek/deepseek-v4-pro", "deepseek-v4-pro", true}, + {"cohere prefix", "cohere/command-a-plus-05-2026", "command-a-plus-05-2026", true}, {"multiple slashes preserved", "openai/gpt-4/turbo", "gpt-4/turbo", true}, {"empty string", "", "", false}, {"slash only - not a provider", "/", "/", false}, From 303a325f59a7bcab3794c94527a2b02a366ff663 Mon Sep 17 00:00:00 2001 From: Parteek Singh Date: Wed, 17 Jun 2026 10:42:49 -0700 Subject: [PATCH 2/2] Detect provider from shared prefix list - `DetectProviderFromModel` now reads `knownProviderPrefixes` - Routes `cohere/` and `deepseek/` IDs, not `ErrNoProvider` - Removes the duplicate prefix `switch` that had drifted - Adds test asserting every stripped prefix is routable --- detection.go | 8 ++++---- detection_test.go | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/detection.go b/detection.go index 28609f8..07fd025 100644 --- a/detection.go +++ b/detection.go @@ -80,11 +80,11 @@ func DetectProviderFromModel(model string) string { return "" } - // Check for explicit provider prefix (e.g., "openai/gpt-4", "anthropic/claude-3-opus") + // Check for explicit provider prefix (e.g., "openai/gpt-4", "anthropic/claude-3-opus"). + // knownProviderPrefixes is the single source of truth shared with stripProviderPrefix, + // so a prefix the router strips is always one it can also route by name. 
if idx := strings.Index(model, "/"); idx >= 0 { - prefix := model[:idx] - switch prefix { - case "openai", "anthropic", "googleai", "groq", "fireworks", "xai", "perplexity", "bedrock", "azure", "mistral": + if prefix := model[:idx]; knownProviderPrefixes[prefix] { return prefix } } diff --git a/detection_test.go b/detection_test.go index 3157aff..9a24ed0 100644 --- a/detection_test.go +++ b/detection_test.go @@ -139,6 +139,8 @@ func TestDetectProviderFromModel(t *testing.T) { {"bedrock/claude prefix", "bedrock/anthropic.claude-3", "bedrock"}, {"azure/gpt-4 prefix", "azure/gpt-4", "azure"}, {"mistral/codestral prefix", "mistral/codestral-2508", "mistral"}, + {"deepseek/deepseek-chat prefix", "deepseek/deepseek-chat", "deepseek"}, + {"cohere/command prefix", "cohere/command-a-plus-05-2026", "cohere"}, {"unknown/ prefix returns unknown", "unknown/model", ""}, {"single slash only", "/", ""}, } @@ -153,6 +155,19 @@ func TestDetectProviderFromModel(t *testing.T) { } } +// Every prefix the router strips (knownProviderPrefixes) must also be routable +// by name via DetectProviderFromModel. Otherwise a direct consumer that sends +// "cohere/<model>" with no X-Provider header gets ErrNoProvider. This guards +// against the two lists drifting apart again (deepseek/cohere previously did). +func TestDetectProviderFromModel_KnownPrefixesAreDetected(t *testing.T) { + for prefix := range knownProviderPrefixes { + model := prefix + "/some-model" + if got := DetectProviderFromModel(model); got != prefix { + t.Errorf("DetectProviderFromModel(%q) = %q, want %q", model, got, prefix) + } + } +} + func TestDefaultProviderDetector(t *testing.T) { tests := []struct { name string