class OpenRouterEmbedding(EmbeddingEndpoint):
    """EmbeddingEndpoint subclass that adds Bearer auth for OpenRouter.

    Requests are POSTed to ``{base_url}/embeddings`` with an
    ``Authorization: Bearer <api_key>`` header and an OpenAI-style JSON body
    (``{"input": ..., "model": ...}``).
    """

    def __init__(self, api_key: str, **kwargs: Any):
        """Store the API key and forward every other option to the base class.

        Args:
            api_key: Key sent as the Bearer token on every request.
            **kwargs: Passed unchanged to ``EmbeddingEndpoint.__init__``.
        """
        super().__init__(**kwargs)
        self._api_key = api_key

    @staticmethod
    def _extract_embedding(payload: Any) -> List[float]:
        """Pull the first embedding vector out of a decoded JSON response.

        Args:
            payload: Decoded JSON body; expected to look like
                ``{"data": [{"embedding": [...]}]}``.

        Returns:
            The value stored at ``payload["data"][0]["embedding"]``.

        Raises:
            TypeError: If ``payload`` is not a dict or does not have the
                expected nested shape.
        """
        if not isinstance(payload, dict):
            raise TypeError(f"Unexpected embedding response type: {payload}")
        try:
            return payload["data"][0]["embedding"]
        except (KeyError, IndexError, TypeError) as exc:
            # TypeError covers shapes such as {"data": None} or
            # {"data": ["text"]}, which would otherwise surface as an opaque
            # "object is not subscriptable" error without the response body.
            raise TypeError(f"Unexpected embedding response: {payload}") from exc

    async def _call_api(self, text: str) -> List[float]:
        """Request an embedding for ``text`` and return the vector.

        Args:
            text: The input string to embed.

        Returns:
            The embedding from the first item of the response's ``data`` list.

        Raises:
            TypeError: If the response JSON does not have the expected shape
                (for example an ``{"error": ...}`` body).
        """
        # Local import, as in the original, so httpx is only needed when this
        # endpoint type is actually used.
        import httpx

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }
        # model_name, _base_url and _timeout are not set in this class;
        # presumably they come from EmbeddingEndpoint.__init__ - verify there.
        json_data = {"input": text, "model": self.model_name}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self._base_url}/embeddings",
                headers=headers,
                json=json_data,
                timeout=self._timeout,
            )

        return self._extract_embedding(response.json())
class OpenRouterClient(OpenAIClient):
    """LLM client for the OpenRouter API.

    Inherits from OpenAIClient and overrides ``build_request_data`` so the
    request stays compatible with non-OpenAI models routed through
    OpenRouter:

    1. ``convert_to_structured_output`` is never applied (``strict`` mode is
       OpenAI-specific and breaks Anthropic / Gemini models).
    2. ``tool_choice`` is ``"required"`` whenever tools are supplied, a
       specific function when ``force_tool_call`` is given, and is removed
       from the request when no tools are present.
    """

    async def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        existing_file_uris: Optional[List[str]] = None,
    ) -> dict:
        """Build the chat-completions request body as a plain dict.

        Args:
            messages: Conversation messages; each is converted with
                ``to_openai_dict`` and cast to its OpenAI message subtype.
            llm_config: Supplies the model name, ``max_tokens`` and
                ``temperature``.
            tools: Optional function definitions, each wrapped as an
                OpenAI ``function`` tool.
            force_tool_call: If given, ``tool_choice`` names this function.
            existing_file_uris: Accepted for signature compatibility; not
                used by this method.

        Returns:
            ``ChatCompletionRequest.model_dump(exclude_unset=True)``.
        """
        # NOTE(review): OpenRouter model ids are typically provider-prefixed
        # (e.g. "openai/o1"), in which case this prefix test never matches.
        # Confirm whether the provider prefix should be stripped first.
        use_developer_message = llm_config.model.startswith(("o1", "o3"))

        openai_message_list = [
            cast_message_to_subtype(
                m.to_openai_dict(use_developer_message=use_developer_message)
            )
            for m in messages
        ]

        # An empty model string is sent as None.
        model = llm_config.model or None

        # An explicitly forced function wins over the generic "required".
        if force_tool_call is not None:
            tool_choice = ToolFunctionChoice(
                type="function",
                function=ToolFunctionChoiceFunctionCall(name=force_tool_call),
            )
        elif tools:
            tool_choice = "required"
        else:
            tool_choice = None

        data = ChatCompletionRequest(
            model=model,
            messages=await self.fill_image_content_in_messages(openai_message_list),
            tools=(
                [OpenAITool(type="function", function=f) for f in tools]
                if tools
                else None
            ),
            tool_choice=tool_choice,
            user="",
            max_completion_tokens=llm_config.max_tokens,
            temperature=llm_config.temperature,
        )

        # Without tools, tool_choice is removed from the request altogether.
        if not data.tools:
            delattr(data, "tool_choice")

        # Skip convert_to_structured_output entirely - non-OpenAI models
        # do not support strict mode / additionalProperties.

        return data.model_dump(exclude_unset=True)
class TestOpenRouterClient:
    """Checks on the payload built by ``OpenRouterClient.build_request_data``."""

    @pytest.fixture
    def llm_config(self):
        """LLMConfig targeting OpenRouter with a provider-prefixed model id."""
        settings = dict(
            model="anthropic/claude-haiku-4.5",
            model_endpoint_type="openrouter",
            model_endpoint="https://openrouter.ai/api/v1",
            context_window=128000,
            max_tokens=1024,
            temperature=0.7,
        )
        return LLMConfig(**settings)

    @pytest.fixture
    def client(self, llm_config):
        """Client under test, built from the shared config."""
        return OpenRouterClient(llm_config=llm_config)

    @pytest.fixture
    def mock_message(self):
        """Stand-in message whose OpenAI dict form is a simple user turn."""
        message = MagicMock()
        message.to_openai_dict.return_value = dict(role="user", content="Hello")
        return message

    @pytest.fixture
    def sample_tools(self):
        """A single function definition with one required string argument."""
        query_property = {"type": "string", "description": "Search query"}
        parameters = {
            "type": "object",
            "properties": {"query": query_property},
            "required": ["query"],
        }
        search_memory = {
            "name": "search_memory",
            "description": "Search memories",
            "parameters": parameters,
        }
        return [search_memory]

    @pytest.mark.asyncio
    async def test_no_strict_mode_in_tools(self, client, llm_config, mock_message, sample_tools):
        """OpenRouter must NOT add strict/additionalProperties to tool schemas."""
        payload = await client.build_request_data(
            messages=[mock_message],
            llm_config=llm_config,
            tools=sample_tools,
        )
        for entry in payload.get("tools", []):
            function_spec = entry.get("function", {})
            assert "strict" not in function_spec, "strict should not be set for OpenRouter"
            schema = function_spec.get("parameters", {})
            assert "additionalProperties" not in schema, "additionalProperties should not be set for OpenRouter"

    @pytest.mark.asyncio
    async def test_tool_choice_required_with_tools(self, client, llm_config, mock_message, sample_tools):
        """tool_choice should be 'required' when tools are provided."""
        payload = await client.build_request_data(
            messages=[mock_message],
            llm_config=llm_config,
            tools=sample_tools,
        )
        assert payload.get("tool_choice") == "required"

    @pytest.mark.asyncio
    async def test_no_tool_choice_without_tools(self, client, llm_config, mock_message):
        """tool_choice should be absent when no tools are provided."""
        payload = await client.build_request_data(
            messages=[mock_message],
            llm_config=llm_config,
            tools=None,
        )
        assert "tool_choice" not in payload

    @pytest.mark.asyncio
    async def test_force_tool_call(self, client, llm_config, mock_message, sample_tools):
        """force_tool_call should set tool_choice to specific function."""
        payload = await client.build_request_data(
            messages=[mock_message],
            llm_config=llm_config,
            tools=sample_tools,
            force_tool_call="search_memory",
        )
        choice = payload.get("tool_choice", {})
        assert choice.get("type") == "function"
        chosen_function = choice.get("function", {})
        assert chosen_function.get("name") == "search_memory"

    @pytest.mark.asyncio
    async def test_model_passed_through(self, client, llm_config, mock_message):
        """Model name should be passed as-is (e.g. 'anthropic/claude-haiku-4.5')."""
        payload = await client.build_request_data(
            messages=[mock_message],
            llm_config=llm_config,
        )
        assert payload["model"] == "anthropic/claude-haiku-4.5"

    @pytest.mark.asyncio
    async def test_inherits_request_method(self, client):
        """OpenRouterClient should inherit the request() method from OpenAIClient."""
        assert hasattr(client, "request")
        assert hasattr(client, "convert_response_to_chat_completion")
class TestOpenAIClientNotAffected:
    """Guards that the plain OpenAI client keeps its structured-output step."""

    @pytest.mark.asyncio
    async def test_openai_client_adds_structured_output(self):
        """Verify the original OpenAIClient still converts to structured output
        (adds additionalProperties: false to parameters)."""
        from mirix.llm_api.openai_client import OpenAIClient

        openai_config = LLMConfig(
            model="gpt-4o-mini",
            model_endpoint_type="openai",
            context_window=128000,
            max_tokens=1024,
            temperature=0.7,
        )
        openai_client = OpenAIClient(llm_config=openai_config)

        user_message = MagicMock()
        user_message.to_openai_dict.return_value = {"role": "user", "content": "Hi"}

        parameters = {
            "type": "object",
            "properties": {"q": {"type": "string", "description": "query"}},
            "required": ["q"],
        }
        function_defs = [
            {
                "name": "test_fn",
                "description": "A test function",
                "parameters": parameters,
            }
        ]

        payload = await openai_client.build_request_data(
            messages=[user_message],
            llm_config=openai_config,
            tools=function_defs,
        )

        # OpenAI client should have additionalProperties (from convert_to_structured_output)
        for entry in payload.get("tools", []):
            function_spec = entry.get("function", {})
            schema = function_spec.get("parameters", {})
            assert schema.get("additionalProperties") is False, \
                "OpenAI client should add additionalProperties=False"