From dfdc31ccf8f7bb4672906cdd78d27d8ff59011e9 Mon Sep 17 00:00:00 2001 From: nanxingw Date: Tue, 14 Apr 2026 17:35:49 +0800 Subject: [PATCH] feat: add OpenRouter LLM and embedding support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a dedicated OpenRouterClient (inherits OpenAIClient) that skips convert_to_structured_output — non-OpenAI models (Anthropic, Gemini) routed through OpenRouter do not support strict mode. Also adds OpenRouterEmbedding with Bearer auth for the embedding endpoint. New endpoint type "openrouter" is registered in LLMClient factory, LLMConfig, and EmbeddingConfig schemas. Co-Authored-By: Claude Opus 4.6 (1M context) --- mirix/embeddings.py | 42 +++++ mirix/llm_api/llm_client.py | 6 + mirix/llm_api/openrouter_client.py | 71 +++++++ mirix/schemas/embedding_config.py | 1 + mirix/schemas/llm_config.py | 1 + tests/test_openrouter.py | 287 +++++++++++++++++++++++++++++ 6 files changed, 408 insertions(+) create mode 100644 mirix/llm_api/openrouter_client.py create mode 100644 tests/test_openrouter.py diff --git a/mirix/embeddings.py b/mirix/embeddings.py index 91835cf1b..f1d6a0081 100755 --- a/mirix/embeddings.py +++ b/mirix/embeddings.py @@ -328,6 +328,39 @@ async def get_text_embedding(self, text: str) -> List[float]: return await embedding_with_retry(lambda: self._call_api(text)) +class OpenRouterEmbedding(EmbeddingEndpoint): + """EmbeddingEndpoint subclass that adds Bearer auth for OpenRouter.""" + + def __init__(self, api_key: str, **kwargs: Any): + super().__init__(**kwargs) + self._api_key = api_key + + async def _call_api(self, text: str) -> List[float]: + import httpx + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + } + json_data = {"input": text, "model": self.model_name} + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._base_url}/embeddings", + headers=headers, + json=json_data, + timeout=self._timeout, + ) + + response_json = response.json() + if isinstance(response_json, dict): + try: + return response_json["data"][0]["embedding"] + except (KeyError, IndexError): + raise TypeError(f"Unexpected embedding response: {response_json}") + raise TypeError(f"Unexpected embedding response type: {response_json}") + + class AzureOpenAIEmbedding: def __init__( self, @@ -526,5 +559,14 @@ async def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] ) return model + elif endpoint_type == "openrouter": + api_key = config.api_key or model_settings.openai_api_key + return OpenRouterEmbedding( + api_key=api_key, + model=config.embedding_model, + base_url=config.embedding_endpoint, + user=str(user_id) if user_id else "", + ) + else: raise ValueError(f"Unknown endpoint type {endpoint_type}") diff --git a/mirix/llm_api/llm_client.py b/mirix/llm_api/llm_client.py index 8b51bd107..59917fa49 100644 --- a/mirix/llm_api/llm_client.py +++ b/mirix/llm_api/llm_client.py @@ -48,5 +48,11 @@ def create( return GoogleAIClient( llm_config=llm_config, ) + case "openrouter": + from mirix.llm_api.openrouter_client import OpenRouterClient + + return OpenRouterClient( + llm_config=llm_config, + ) case _: return None diff --git a/mirix/llm_api/openrouter_client.py b/mirix/llm_api/openrouter_client.py new file mode 100644 index 000000000..ea9e9cab1 --- /dev/null +++ b/mirix/llm_api/openrouter_client.py @@ -0,0 +1,71 @@ +from typing import List, Optional + +from mirix.llm_api.openai_client import OpenAIClient +from mirix.log import get_logger +from mirix.schemas.llm_config import LLMConfig +from mirix.schemas.message import Message as PydanticMessage +from mirix.schemas.openai.chat_completion_request import ( + ChatCompletionRequest, + Tool as OpenAITool, + ToolFunctionChoice, + cast_message_to_subtype, +) +from mirix.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall + +logger = get_logger(__name__) + + +class OpenRouterClient(OpenAIClient): + """LLM client for OpenRouter API. + + Inherits from OpenAIClient and overrides behaviour that is + incompatible with non-OpenAI models routed through OpenRouter: + + 1. Skips ``convert_to_structured_output`` (``strict`` mode is + OpenAI-specific and breaks Anthropic / Gemini models). + 2. Defaults ``tool_choice`` to ``"auto"`` instead of ``"required"`` + so that models can emit reasoning text between tool calls. + """ + + async def build_request_data( + self, + messages: List[PydanticMessage], + llm_config: LLMConfig, + tools: Optional[List[dict]] = None, + force_tool_call: Optional[str] = None, + existing_file_uris: Optional[List[str]] = None, + ) -> dict: + use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3") + + openai_message_list = [ + cast_message_to_subtype(m.to_openai_dict(use_developer_message=use_developer_message)) + for m in messages + ] + + model = llm_config.model or None + + tool_choice = "required" if tools else None + + if force_tool_call is not None: + tool_choice = ToolFunctionChoice( + type="function", + function=ToolFunctionChoiceFunctionCall(name=force_tool_call), + ) + + data = ChatCompletionRequest( + model=model, + messages=await self.fill_image_content_in_messages(openai_message_list), + tools=([OpenAITool(type="function", function=f) for f in tools] if tools else None), + tool_choice=tool_choice, + user=str(), + max_completion_tokens=llm_config.max_tokens, + temperature=llm_config.temperature, + ) + + if not (data.tools is not None and len(data.tools) > 0): + delattr(data, "tool_choice") + + # Skip convert_to_structured_output entirely — non-OpenAI models + # do not support strict mode / additionalProperties. + + return data.model_dump(exclude_unset=True) diff --git a/mirix/schemas/embedding_config.py b/mirix/schemas/embedding_config.py index 43a2522d7..3d838bbe9 100755 --- a/mirix/schemas/embedding_config.py +++ b/mirix/schemas/embedding_config.py @@ -40,6 +40,7 @@ class EmbeddingConfig(BaseModel): "hugging-face", "mistral", "together", # completions endpoint + "openrouter", ] = Field(..., description="The endpoint type for the model.") embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).") embedding_model: str = Field(..., description="The model for the embedding.") diff --git a/mirix/schemas/llm_config.py b/mirix/schemas/llm_config.py index e3a243cad..8743f6e5e 100755 --- a/mirix/schemas/llm_config.py +++ b/mirix/schemas/llm_config.py @@ -50,6 +50,7 @@ class LLMConfig(BaseModel): "bedrock", "deepseek", "xai", + "openrouter", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") diff --git a/tests/test_openrouter.py b/tests/test_openrouter.py new file mode 100644 index 000000000..12506271c --- /dev/null +++ b/tests/test_openrouter.py @@ -0,0 +1,287 @@ +"""Tests for OpenRouter support: LLM client, embedding client, schema validation.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from mirix.llm_api.llm_client import LLMClient +from mirix.llm_api.openrouter_client import OpenRouterClient +from mirix.embeddings import OpenRouterEmbedding +from mirix.schemas.llm_config import LLMConfig +from mirix.schemas.embedding_config import EmbeddingConfig + + +# --------------------------------------------------------------------------- +# Schema validation +# --------------------------------------------------------------------------- + +class TestSchemaValidation: + def test_llm_config_accepts_openrouter(self): + config = LLMConfig( + model="anthropic/claude-haiku-4.5", + model_endpoint_type="openrouter", + model_endpoint="https://openrouter.ai/api/v1", + context_window=128000, + ) + assert config.model_endpoint_type == "openrouter" + + def test_embedding_config_accepts_openrouter(self): + config = EmbeddingConfig( + embedding_model="google/gemini-embedding-001", + embedding_endpoint_type="openrouter", + embedding_endpoint="https://openrouter.ai/api/v1", + embedding_dim=3072, + ) + assert config.embedding_endpoint_type == "openrouter" + + +# --------------------------------------------------------------------------- +# LLMClient factory +# --------------------------------------------------------------------------- + +class TestLLMClientFactory: + def test_creates_openrouter_client(self): + config = LLMConfig( + model="anthropic/claude-haiku-4.5", + model_endpoint_type="openrouter", + model_endpoint="https://openrouter.ai/api/v1", + context_window=128000, + ) + client = LLMClient.create(config) + assert isinstance(client, OpenRouterClient) + + def test_openai_still_creates_openai_client(self): + from mirix.llm_api.openai_client import OpenAIClient + config = LLMConfig( + model="gpt-4o-mini", + model_endpoint_type="openai", + context_window=128000, + ) + client = LLMClient.create(config) + assert isinstance(client, OpenAIClient) + assert not isinstance(client, OpenRouterClient) + + +# --------------------------------------------------------------------------- +# OpenRouterClient.build_request_data +# --------------------------------------------------------------------------- + +class TestOpenRouterClient: + @pytest.fixture + def llm_config(self): + return LLMConfig( + model="anthropic/claude-haiku-4.5", + model_endpoint_type="openrouter", + model_endpoint="https://openrouter.ai/api/v1", + context_window=128000, + max_tokens=1024, + temperature=0.7, + ) + + @pytest.fixture + def client(self, llm_config): + return OpenRouterClient(llm_config=llm_config) + + @pytest.fixture + def mock_message(self): + msg = MagicMock() + msg.to_openai_dict.return_value = { + "role": "user", + "content": "Hello", + } + return msg + + @pytest.fixture + def sample_tools(self): + return [ + { + "name": "search_memory", + "description": "Search memories", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"], + }, + } + ] + + @pytest.mark.asyncio + async def test_no_strict_mode_in_tools(self, client, llm_config, mock_message, sample_tools): + """OpenRouter must NOT add strict/additionalProperties to tool schemas.""" + data = await client.build_request_data( + messages=[mock_message], + llm_config=llm_config, + tools=sample_tools, + ) + for tool in data.get("tools", []): + func = tool.get("function", {}) + assert "strict" not in func, "strict should not be set for OpenRouter" + params = func.get("parameters", {}) + assert "additionalProperties" not in params, "additionalProperties should not be set for OpenRouter" + + @pytest.mark.asyncio + async def test_tool_choice_required_with_tools(self, client, llm_config, mock_message, sample_tools): + """tool_choice should be 'required' when tools are provided.""" + data = await client.build_request_data( + messages=[mock_message], + llm_config=llm_config, + tools=sample_tools, + ) + assert data.get("tool_choice") == "required" + + @pytest.mark.asyncio + async def test_no_tool_choice_without_tools(self, client, llm_config, mock_message): + """tool_choice should be absent when no tools are provided.""" + data = await client.build_request_data( + messages=[mock_message], + llm_config=llm_config, + tools=None, + ) + assert "tool_choice" not in data + + @pytest.mark.asyncio + async def test_force_tool_call(self, client, llm_config, mock_message, sample_tools): + """force_tool_call should set tool_choice to specific function.""" + data = await client.build_request_data( + messages=[mock_message], + llm_config=llm_config, + tools=sample_tools, + force_tool_call="search_memory", + ) + tc = data.get("tool_choice", {}) + assert tc.get("type") == "function" + assert tc.get("function", {}).get("name") == "search_memory" + + @pytest.mark.asyncio + async def test_model_passed_through(self, client, llm_config, mock_message): + """Model name should be passed as-is (e.g. 'anthropic/claude-haiku-4.5').""" + data = await client.build_request_data( + messages=[mock_message], + llm_config=llm_config, + ) + assert data["model"] == "anthropic/claude-haiku-4.5" + + @pytest.mark.asyncio + async def test_inherits_request_method(self, client): + """OpenRouterClient should inherit the request() method from OpenAIClient.""" + assert hasattr(client, "request") + assert hasattr(client, "convert_response_to_chat_completion") + + +# --------------------------------------------------------------------------- +# OpenRouterEmbedding +# --------------------------------------------------------------------------- + +class TestOpenRouterEmbedding: + def test_init(self): + emb = OpenRouterEmbedding( + api_key="test-key", + model="google/gemini-embedding-001", + base_url="https://openrouter.ai/api/v1", + user="test-user", + ) + assert emb._api_key == "test-key" + assert emb.model_name == "google/gemini-embedding-001" + + @pytest.mark.asyncio + async def test_call_api_sends_bearer_auth(self): + """Embedding requests must include Bearer auth header.""" + emb = OpenRouterEmbedding( + api_key="sk-test-123", + model="google/gemini-embedding-001", + base_url="https://openrouter.ai/api/v1", + user="user-1", + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [{"embedding": [0.1, 0.2, 0.3]}] + } + + with patch("httpx.AsyncClient") as MockClient: + mock_client_instance = AsyncMock() + mock_client_instance.post.return_value = mock_response + MockClient.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await emb._call_api("test text") + + assert result == [0.1, 0.2, 0.3] + + call_args = mock_client_instance.post.call_args + headers = call_args.kwargs.get("headers") or call_args[1].get("headers") + assert headers["Authorization"] == "Bearer sk-test-123" + assert headers["Content-Type"] == "application/json" + + json_data = call_args.kwargs.get("json") or call_args[1].get("json") + assert json_data["model"] == "google/gemini-embedding-001" + assert json_data["input"] == "test text" + + @pytest.mark.asyncio + async def test_call_api_error_handling(self): + """Should raise TypeError on unexpected response format.""" + emb = OpenRouterEmbedding( + api_key="sk-test", + model="test-model", + base_url="https://openrouter.ai/api/v1", + user="user-1", + ) + + mock_response = MagicMock() + mock_response.json.return_value = {"error": "invalid model"} + + with patch("httpx.AsyncClient") as MockClient: + mock_client_instance = AsyncMock() + mock_client_instance.post.return_value = mock_response + MockClient.return_value.__aenter__ = AsyncMock(return_value=mock_client_instance) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + + with pytest.raises(TypeError, match="Unexpected embedding response"): + await emb._call_api("test text") + + +# --------------------------------------------------------------------------- +# OpenAI client comparison — ensure OpenAI path still adds strict mode +# --------------------------------------------------------------------------- + +class TestOpenAIClientNotAffected: + @pytest.mark.asyncio + async def test_openai_client_adds_structured_output(self): + """Verify the original OpenAIClient still converts to structured output + (adds additionalProperties: false to parameters).""" + from mirix.llm_api.openai_client import OpenAIClient + + config = LLMConfig( + model="gpt-4o-mini", + model_endpoint_type="openai", + context_window=128000, + max_tokens=1024, + temperature=0.7, + ) + client = OpenAIClient(llm_config=config) + msg = MagicMock() + msg.to_openai_dict.return_value = {"role": "user", "content": "Hi"} + + tools = [{ + "name": "test_fn", + "description": "A test function", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string", "description": "query"}}, + "required": ["q"], + }, + }] + + data = await client.build_request_data( + messages=[msg], llm_config=config, tools=tools, + ) + + # OpenAI client should have additionalProperties (from convert_to_structured_output) + for tool in data.get("tools", []): + params = tool.get("function", {}).get("parameters", {}) + assert params.get("additionalProperties") is False, \ + "OpenAI client should add additionalProperties=False"