Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mirix/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,39 @@ async def get_text_embedding(self, text: str) -> List[float]:
return await embedding_with_retry(lambda: self._call_api(text))


class OpenRouterEmbedding(EmbeddingEndpoint):
    """EmbeddingEndpoint subclass that adds Bearer auth for OpenRouter.

    Only ``_call_api`` is overridden; it sends an ``Authorization: Bearer``
    header built from the API key given at construction time. All other
    keyword arguments are forwarded untouched to ``EmbeddingEndpoint``.
    """

    def __init__(self, api_key: str, **kwargs: Any):
        """Store ``api_key`` and pass every other keyword to the parent class.

        Args:
            api_key: Token placed in the ``Authorization: Bearer`` header.
            **kwargs: Forwarded verbatim to ``EmbeddingEndpoint.__init__``.
        """
        super().__init__(**kwargs)
        self._api_key = api_key

    async def _call_api(self, text: str) -> List[float]:
        """POST ``text`` to ``<base_url>/embeddings`` and return one embedding.

        Args:
            text: The input string to embed.

        Returns:
            The value found at ``data[0].embedding`` in the JSON response.

        Raises:
            TypeError: If the response body is not a JSON object, or if it
                does not contain ``data[0]["embedding"]``.
        """
        # Imported lazily so httpx is only needed when this endpoint is used.
        import httpx

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }
        # model_name / _base_url / _timeout are not set in this class;
        # presumably they come from EmbeddingEndpoint.__init__ - TODO confirm.
        json_data = {"input": text, "model": self.model_name}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self._base_url}/embeddings",
                headers=headers,
                json=json_data,
                timeout=self._timeout,
            )

        # NOTE(review): the HTTP status is not checked here, so an error body
        # without a usable "data" entry surfaces as the TypeError below rather
        # than as an HTTP error. Left as-is so callers see the same exception
        # type as before.
        response_json = response.json()
        if not isinstance(response_json, dict):
            raise TypeError(f"Unexpected embedding response type: {response_json}")

        try:
            return response_json["data"][0]["embedding"]
        except (KeyError, IndexError, TypeError) as err:
            # TypeError covers non-subscriptable shapes such as
            # {"data": null} or {"data": [null]}; chaining keeps the cause.
            raise TypeError(f"Unexpected embedding response: {response_json}") from err


class AzureOpenAIEmbedding:
def __init__(
self,
Expand Down Expand Up @@ -526,5 +559,14 @@ async def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID]
)
return model

elif endpoint_type == "openrouter":
api_key = config.api_key or model_settings.openai_api_key
return OpenRouterEmbedding(
api_key=api_key,
model=config.embedding_model,
base_url=config.embedding_endpoint,
user=str(user_id) if user_id else "",
)

else:
raise ValueError(f"Unknown endpoint type {endpoint_type}")
6 changes: 6 additions & 0 deletions mirix/llm_api/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,11 @@ def create(
return GoogleAIClient(
llm_config=llm_config,
)
case "openrouter":
from mirix.llm_api.openrouter_client import OpenRouterClient

return OpenRouterClient(
llm_config=llm_config,
)
case _:
return None
71 changes: 71 additions & 0 deletions mirix/llm_api/openrouter_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import List, Optional

from mirix.llm_api.openai_client import OpenAIClient
from mirix.log import get_logger
from mirix.schemas.llm_config import LLMConfig
from mirix.schemas.message import Message as PydanticMessage
from mirix.schemas.openai.chat_completion_request import (
ChatCompletionRequest,
Tool as OpenAITool,
ToolFunctionChoice,
cast_message_to_subtype,
)
from mirix.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall

logger = get_logger(__name__)


class OpenRouterClient(OpenAIClient):
    """LLM client for OpenRouter API.

    Inherits from OpenAIClient and overrides behaviour that is
    incompatible with non-OpenAI models routed through OpenRouter:

    1. Skips ``convert_to_structured_output`` (``strict`` mode is
       OpenAI-specific and breaks Anthropic / Gemini models).
    2. Defaults ``tool_choice`` to ``"auto"`` instead of ``"required"``
       so that models can emit reasoning text between tool calls.
    """

    async def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        existing_file_uris: Optional[List[str]] = None,
    ) -> dict:
        """Build the chat-completion request payload as a plain dict.

        Args:
            messages: Conversation messages; each is converted with
                ``to_openai_dict`` and ``cast_message_to_subtype``.
            llm_config: Supplies ``model``, ``max_tokens`` and ``temperature``.
            tools: Function definitions to expose as tools, or ``None``.
            force_tool_call: When given, ``tool_choice`` names this function
                instead of using the default.
            existing_file_uris: Accepted for signature compatibility; not
                used by this override.

        Returns:
            ``ChatCompletionRequest.model_dump(exclude_unset=True)``. When no
            tools are present, ``tool_choice`` is removed from the request
            before dumping.
        """
        # Guard against an empty/None model (already anticipated by the
        # ``or None`` below) and test both prefixes in one call.
        use_developer_message = (llm_config.model or "").startswith(("o1", "o3"))

        openai_message_list = [
            cast_message_to_subtype(m.to_openai_dict(use_developer_message=use_developer_message))
            for m in messages
        ]

        model = llm_config.model or None

        # "auto" (not "required") as stated in the class docstring: the model
        # may emit text between tool calls.
        tool_choice = "auto" if tools else None

        if force_tool_call is not None:
            tool_choice = ToolFunctionChoice(
                type="function",
                function=ToolFunctionChoiceFunctionCall(name=force_tool_call),
            )

        data = ChatCompletionRequest(
            model=model,
            messages=await self.fill_image_content_in_messages(openai_message_list),
            tools=([OpenAITool(type="function", function=f) for f in tools] if tools else None),
            tool_choice=tool_choice,
            user=str(),
            max_completion_tokens=llm_config.max_tokens,
            temperature=llm_config.temperature,
        )

        # Without tools, drop tool_choice from the request object.
        if not data.tools:
            delattr(data, "tool_choice")

        # Skip convert_to_structured_output entirely — non-OpenAI models
        # do not support strict mode / additionalProperties.

        return data.model_dump(exclude_unset=True)
1 change: 1 addition & 0 deletions mirix/schemas/embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class EmbeddingConfig(BaseModel):
"hugging-face",
"mistral",
"together", # completions endpoint
"openrouter",
] = Field(..., description="The endpoint type for the model.")
embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).")
embedding_model: str = Field(..., description="The model for the embedding.")
Expand Down
1 change: 1 addition & 0 deletions mirix/schemas/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
"bedrock",
"deepseek",
"xai",
"openrouter",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
Expand Down
Loading
Loading