From 00b68f6d5638d94cf79491c40dbd2853c32fbdfb Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Fri, 24 Apr 2026 00:15:13 +0530 Subject: [PATCH 1/4] feat: add LiteLLM as AI gateway provider --- backend/director/constants.py | 2 + backend/director/llm/__init__.py | 5 +- backend/director/llm/litellm.py | 147 +++++++++++++++++++ backend/requirements.txt | 1 + backend/tests/__init__.py | 0 backend/tests/test_litellm.py | 237 +++++++++++++++++++++++++++++++ 6 files changed, 391 insertions(+), 1 deletion(-) create mode 100644 backend/director/llm/litellm.py create mode 100644 backend/tests/__init__.py create mode 100644 backend/tests/test_litellm.py diff --git a/backend/director/constants.py b/backend/director/constants.py index a398d134..c18fbf74 100644 --- a/backend/director/constants.py +++ b/backend/director/constants.py @@ -21,6 +21,7 @@ class LLMType(str, Enum): ANTHROPIC = "anthropic" GOOGLEAI = "googleai" VIDEODB_PROXY = "videodb_proxy" + LITELLM = "litellm" class EnvPrefix(str, Enum): @@ -29,5 +30,6 @@ class EnvPrefix(str, Enum): OPENAI_ = "OPENAI_" ANTHROPIC_ = "ANTHROPIC_" GOOGLEAI_ = "GOOGLEAI_" + LITELLM_ = "LITELLM_" DOWNLOADS_PATH="director/downloads" diff --git a/backend/director/llm/__init__.py b/backend/director/llm/__init__.py index 71e79c4c..898044fa 100644 --- a/backend/director/llm/__init__.py +++ b/backend/director/llm/__init__.py @@ -5,6 +5,7 @@ from director.llm.openai import OpenAI from director.llm.anthropic import AnthropicAI from director.llm.googleai import GoogleAI +from director.llm.litellm import LiteLLM from director.llm.videodb_proxy import VideoDBProxy @@ -17,7 +18,9 @@ def get_default_llm(): default_llm = os.getenv("DEFAULT_LLM") - if openai or default_llm == LLMType.OPENAI: + if default_llm == LLMType.LITELLM: + return LiteLLM() + elif openai or default_llm == LLMType.OPENAI: return OpenAI() elif anthropic or default_llm == LLMType.ANTHROPIC: return AnthropicAI() diff --git a/backend/director/llm/litellm.py b/backend/director/llm/litellm.py new file mode 100644 index 00000000..97436f40 --- /dev/null +++ b/backend/director/llm/litellm.py @@ -0,0 +1,147 @@ +import json + +from pydantic import Field +from pydantic_settings import SettingsConfigDict + +from director.llm.base import BaseLLM, BaseLLMConfig, LLMResponse, LLMResponseStatus +from director.constants import LLMType, EnvPrefix + + +class LiteLLMConfig(BaseLLMConfig): + """LiteLLM Config. + + Reads from LITELLM_ prefixed environment variables. + Set LITELLM_CHAT_MODEL to any LiteLLM-supported model string + (e.g. anthropic/claude-3-haiku, openai/gpt-4o, bedrock/anthropic.claude-v2). + + API keys are read from standard provider environment variables + automatically (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.). + Optionally set LITELLM_API_KEY to override. + """ + + model_config = SettingsConfigDict( + env_prefix=EnvPrefix.LITELLM_, + extra="ignore", + ) + + llm_type: str = LLMType.LITELLM + api_key: str = "" + api_base: str = "" + chat_model: str = Field(default="openai/gpt-4o") + max_tokens: int = 4096 + + +class LiteLLM(BaseLLM): + def __init__(self, config: LiteLLMConfig = None): + """ + :param config: LiteLLM Config + """ + if config is None: + config = LiteLLMConfig() + super().__init__(config=config) + + def _format_messages(self, messages: list): + """Format messages to OpenAI chat format. + + LiteLLM accepts OpenAI-format messages and translates + them for each provider internally. + """ + formatted_messages = [] + for message in messages: + if message["role"] == "assistant" and message.get("tool_calls"): + formatted_messages.append( + { + "role": message["role"], + "content": message["content"], + "tool_calls": [ + { + "id": tool_call["id"], + "function": { + "name": tool_call["tool"]["name"], + "arguments": json.dumps( + tool_call["tool"]["arguments"] + ), + }, + "type": tool_call["type"], + } + for tool_call in message["tool_calls"] + ], + } + ) + else: + formatted_messages.append(message) + return formatted_messages + + def _format_tools(self, tools: list): + """Format tools to OpenAI function-calling format.""" + formatted_tools = [] + for tool in tools: + formatted_tools.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": tool["parameters"], + }, + } + ) + return formatted_tools + + def chat_completions( + self, messages: list, tools: list = [], stop=None, response_format=None + ): + """Get chat completions via LiteLLM. + + Routes to 100+ providers (OpenAI, Anthropic, Azure, Bedrock, etc.) + based on the model string in LITELLM_CHAT_MODEL. + """ + import litellm + + params = { + "model": self.chat_model, + "messages": self._format_messages(messages), + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "stop": stop, + "timeout": self.timeout, + "drop_params": True, + } + + if self.api_key: + params["api_key"] = self.api_key + if self.api_base: + params["api_base"] = self.api_base + if tools: + params["tools"] = self._format_tools(tools) + params["tool_choice"] = "auto" + if response_format: + params["response_format"] = response_format + + try: + response = litellm.completion(**params) + except Exception as e: + print(f"Error: {e}") + return LLMResponse(content=f"Error: {e}") + + return LLMResponse( + content=response.choices[0].message.content or "", + tool_calls=[ + { + "id": tool_call.id, + "tool": { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + }, + "type": tool_call.type, + } + for tool_call in response.choices[0].message.tool_calls + ] + if response.choices[0].message.tool_calls + else [], + finish_reason=response.choices[0].finish_reason, + send_tokens=response.usage.prompt_tokens, + recv_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + status=LLMResponseStatus.SUCCESS, + ) diff --git a/backend/requirements.txt b/backend/requirements.txt index eed7c16c..0c2ecdab 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,6 +4,7 @@ composio_openai==0.5.50 elevenlabs==1.9.0 fal-client===0.5.8 Flask==3.0.3 +litellm>=1.60.0,<2.0.0 Flask-SocketIO==5.3.6 Flask-Cors==4.0.1 openai==1.55.3 diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/test_litellm.py b/backend/tests/test_litellm.py new file mode 100644 index 00000000..0f6f4dc8 --- /dev/null +++ b/backend/tests/test_litellm.py @@ -0,0 +1,237 @@ +"""Tests for the LiteLLM provider.""" + +import json +import types as builtin_types +from unittest import mock + +import pytest + +from director.llm.base import LLMResponse, LLMResponseStatus + + +# --------------------------------------------------------------------------- +# Fake response helpers (matches OpenAI response shape) +# --------------------------------------------------------------------------- + + +class _FnCall: + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + +class _ToolCall: + def __init__(self, id, name, arguments): + self.id = id + self.function = _FnCall(name, json.dumps(arguments)) + self.type = "function" + + +class _Msg: + def __init__(self, content="hello", tool_calls=None): + self.content = content + self.tool_calls = tool_calls + + +class _Usage: + def __init__(self, prompt=10, completion=5, total=15): + self.prompt_tokens = prompt + self.completion_tokens = completion + self.total_tokens = total + + +class _Choice: + def __init__(self, content="hello", finish_reason="stop", tool_calls=None): + self.message = _Msg(content=content, tool_calls=tool_calls) + self.finish_reason = finish_reason + + +class _Response: + def __init__(self, content="hello", finish_reason="stop", tool_calls=None): + self.choices = [_Choice(content=content, finish_reason=finish_reason, tool_calls=tool_calls)] + self.usage = _Usage() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _install_fake_litellm(response_content="hello"): + import sys + + fake = builtin_types.ModuleType("litellm") + fake.completion = mock.MagicMock(return_value=_Response(response_content)) + sys.modules["litellm"] = fake + return fake + + +def _uninstall_fake_litellm(): + import sys + + sys.modules.pop("litellm", None) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestLiteLLMChatCompletions: + def setup_method(self): + self.fake = _install_fake_litellm("test response") + + def teardown_method(self): + _uninstall_fake_litellm() + + def _make_llm(self, **overrides): + from director.llm.litellm import LiteLLMConfig, LiteLLM + + defaults = { + "chat_model": "openai/gpt-4o", + "api_key": "test-key", + } + defaults.update(overrides) + config = LiteLLMConfig(**defaults) + return LiteLLM(config=config) + + def test_basic_completion(self): + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "hi"}], + ) + assert isinstance(result, LLMResponse) + assert result.content == "test response" + assert result.status == LLMResponseStatus.SUCCESS + + def test_passes_drop_params(self): + llm = self._make_llm() + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["drop_params"] is True + + def test_passes_model(self): + llm = self._make_llm(chat_model="anthropic/claude-3-haiku") + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["model"] == "anthropic/claude-3-haiku" + + def test_forwards_api_key(self): + llm = self._make_llm(api_key="sk-test") + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["api_key"] == "sk-test" + + def test_omits_api_key_when_empty(self): + llm = self._make_llm(api_key="") + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert "api_key" not in call_kwargs + + def test_forwards_api_base(self): + llm = self._make_llm(api_base="http://localhost:4000") + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["api_base"] == "http://localhost:4000" + + def test_omits_api_base_when_empty(self): + llm = self._make_llm(api_base="") + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert "api_base" not in call_kwargs + + def test_passes_temperature(self): + llm = self._make_llm() + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["temperature"] == llm.temperature + + def test_tool_calls_returned(self): + tc = _ToolCall("tc1", "search", {"query": "test"}) + self.fake.completion.return_value = _Response( + content="", tool_calls=[tc] + ) + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "search"}], + tools=[{"name": "search", "description": "Search", "parameters": {}}], + ) + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["tool"]["name"] == "search" + + def test_token_usage_populated(self): + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "hi"}], + ) + assert result.send_tokens == 10 + assert result.recv_tokens == 5 + assert result.total_tokens == 15 + + def test_error_returns_llm_response(self): + self.fake.completion.side_effect = Exception("connection failed") + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "hi"}], + ) + assert "Error" in result.content + assert result.status == LLMResponseStatus.ERROR + + +class TestLiteLLMRegistration: + def setup_method(self): + _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def test_llm_type_exists(self): + from director.constants import LLMType + + assert hasattr(LLMType, "LITELLM") + assert LLMType.LITELLM == "litellm" + + def test_env_prefix_exists(self): + from director.constants import EnvPrefix + + assert hasattr(EnvPrefix, "LITELLM_") + assert EnvPrefix.LITELLM_ == "LITELLM_" + + def test_get_default_llm_returns_litellm(self): + from director.llm.litellm import LiteLLM + + with mock.patch.dict("os.environ", {"DEFAULT_LLM": "litellm"}, clear=False): + from director.llm import get_default_llm + + llm = get_default_llm() + assert isinstance(llm, LiteLLM) + + +class TestLiteLLMMessageFormatting: + def setup_method(self): + self.fake = _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def test_tool_call_messages_formatted(self): + from director.llm.litellm import LiteLLMConfig, LiteLLM + + config = LiteLLMConfig(chat_model="openai/gpt-4o", api_key="k") + llm = LiteLLM(config=config) + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "tc1", + "tool": {"name": "search", "arguments": {"q": "test"}}, + "type": "function", + } + ], + } + ] + formatted = llm._format_messages(messages) + assert formatted[0]["tool_calls"][0]["function"]["name"] == "search" + assert "arguments" in formatted[0]["tool_calls"][0]["function"] From 86e21aa4a1798a02120d064aadc592a46ecdb741 Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Fri, 24 Apr 2026 00:19:32 +0530 Subject: [PATCH 2/4] test: expand unit tests to 34 and add comprehensive E2E --- backend/tests/test_litellm.py | 348 +++++++++++++++++++++++++++++++--- 1 file changed, 320 insertions(+), 28 deletions(-) diff --git a/backend/tests/test_litellm.py b/backend/tests/test_litellm.py index 0f6f4dc8..d31bbc2a 100644 --- a/backend/tests/test_litellm.py +++ b/backend/tests/test_litellm.py @@ -47,9 +47,10 @@ def __init__(self, content="hello", finish_reason="stop", tool_calls=None): class _Response: - def __init__(self, content="hello", finish_reason="stop", tool_calls=None): + def __init__(self, content="hello", finish_reason="stop", tool_calls=None, + prompt_tokens=10, completion_tokens=5, total_tokens=15): self.choices = [_Choice(content=content, finish_reason=finish_reason, tool_calls=tool_calls)] - self.usage = _Usage() + self.usage = _Usage(prompt=prompt_tokens, completion=completion_tokens, total=total_tokens) # --------------------------------------------------------------------------- @@ -73,7 +74,7 @@ def _uninstall_fake_litellm(): # --------------------------------------------------------------------------- -# Tests +# Chat completions # --------------------------------------------------------------------------- @@ -95,7 +96,7 @@ def _make_llm(self, **overrides): config = LiteLLMConfig(**defaults) return LiteLLM(config=config) - def test_basic_completion(self): + def test_basic_completion_returns_content(self): llm = self._make_llm() result = llm.chat_completions( messages=[{"role": "user", "content": "hi"}], @@ -104,79 +105,245 @@ def test_basic_completion(self): assert result.content == "test response" assert result.status == LLMResponseStatus.SUCCESS - def test_passes_drop_params(self): + def test_drop_params_always_true(self): llm = self._make_llm() llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] assert call_kwargs["drop_params"] is True - def test_passes_model(self): + def test_model_forwarded(self): llm = self._make_llm(chat_model="anthropic/claude-3-haiku") llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] assert call_kwargs["model"] == "anthropic/claude-3-haiku" - def test_forwards_api_key(self): + def test_api_key_forwarded_when_set(self): llm = self._make_llm(api_key="sk-test") llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] assert call_kwargs["api_key"] == "sk-test" - def test_omits_api_key_when_empty(self): + def test_api_key_omitted_when_empty(self): llm = self._make_llm(api_key="") llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] assert "api_key" not in call_kwargs - def test_forwards_api_base(self): + def test_api_base_forwarded_when_set(self): llm = self._make_llm(api_base="http://localhost:4000") llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] assert call_kwargs["api_base"] == "http://localhost:4000" - def test_omits_api_base_when_empty(self): + def test_api_base_omitted_when_empty(self): llm = self._make_llm(api_base="") llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] assert "api_base" not in call_kwargs - def test_passes_temperature(self): + def test_temperature_forwarded(self): + llm = self._make_llm(temperature=0.7) + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["temperature"] == 0.7 + + def test_max_tokens_forwarded(self): + llm = self._make_llm(max_tokens=2048) + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["max_tokens"] == 2048 + + def test_timeout_forwarded(self): + llm = self._make_llm(timeout=60) + llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["timeout"] == 60 + + def test_stop_forwarded(self): + llm = self._make_llm() + llm.chat_completions( + messages=[{"role": "user", "content": "hi"}], + stop=["STOP"], + ) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["stop"] == ["STOP"] + + def test_response_format_forwarded(self): + llm = self._make_llm() + rf = {"type": "json_object"} + llm.chat_completions( + messages=[{"role": "user", "content": "hi"}], + response_format=rf, + ) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["response_format"] == rf + + def test_top_p_not_sent(self): + """top_p is omitted to avoid conflicts across providers.""" llm = self._make_llm() llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] - assert call_kwargs["temperature"] == llm.temperature + assert "top_p" not in call_kwargs - def test_tool_calls_returned(self): + +# --------------------------------------------------------------------------- +# Tool calling +# --------------------------------------------------------------------------- + + +class TestLiteLLMToolCalling: + def setup_method(self): + self.fake = _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def _make_llm(self): + from director.llm.litellm import LiteLLMConfig, LiteLLM + + config = LiteLLMConfig(chat_model="openai/gpt-4o", api_key="k") + return LiteLLM(config=config) + + def test_tool_calls_parsed_correctly(self): tc = _ToolCall("tc1", "search", {"query": "test"}) - self.fake.completion.return_value = _Response( - content="", tool_calls=[tc] - ) + self.fake.completion.return_value = _Response(content="", tool_calls=[tc]) + llm = self._make_llm() result = llm.chat_completions( messages=[{"role": "user", "content": "search"}], tools=[{"name": "search", "description": "Search", "parameters": {}}], ) assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["id"] == "tc1" + assert result.tool_calls[0]["tool"]["name"] == "search" + assert result.tool_calls[0]["tool"]["arguments"] == {"query": "test"} + assert result.tool_calls[0]["type"] == "function" + + def test_multiple_tool_calls(self): + tc1 = _ToolCall("tc1", "search", {"q": "a"}) + tc2 = _ToolCall("tc2", "fetch", {"url": "http://x"}) + self.fake.completion.return_value = _Response(content="", tool_calls=[tc1, tc2]) + + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "do stuff"}], + tools=[ + {"name": "search", "description": "S", "parameters": {}}, + {"name": "fetch", "description": "F", "parameters": {}}, + ], + ) + assert len(result.tool_calls) == 2 assert result.tool_calls[0]["tool"]["name"] == "search" + assert result.tool_calls[1]["tool"]["name"] == "fetch" - def test_token_usage_populated(self): + def test_no_tool_calls_returns_empty_list(self): llm = self._make_llm() result = llm.chat_completions( messages=[{"role": "user", "content": "hi"}], ) - assert result.send_tokens == 10 - assert result.recv_tokens == 5 - assert result.total_tokens == 15 + assert result.tool_calls == [] - def test_error_returns_llm_response(self): - self.fake.completion.side_effect = Exception("connection failed") + def test_tools_formatted_with_tool_choice_auto(self): llm = self._make_llm() - result = llm.chat_completions( + llm.chat_completions( + messages=[{"role": "user", "content": "hi"}], + tools=[{"name": "t", "description": "d", "parameters": {"type": "object"}}], + ) + call_kwargs = self.fake.completion.call_args[1] + assert call_kwargs["tool_choice"] == "auto" + assert call_kwargs["tools"][0]["type"] == "function" + assert call_kwargs["tools"][0]["function"]["name"] == "t" + + def test_no_tools_omits_tool_choice(self): + llm = self._make_llm() + llm.chat_completions( messages=[{"role": "user", "content": "hi"}], ) + call_kwargs = self.fake.completion.call_args[1] + assert "tools" not in call_kwargs + assert "tool_choice" not in call_kwargs + + +# --------------------------------------------------------------------------- +# Token usage and finish reason +# --------------------------------------------------------------------------- + + +class TestLiteLLMResponseFields: + def setup_method(self): + self.fake = _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def test_token_counts(self): + self.fake.completion.return_value = _Response( + content="ok", prompt_tokens=100, completion_tokens=50, total_tokens=150 + ) + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + result = llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + assert result.send_tokens == 100 + assert result.recv_tokens == 50 + assert result.total_tokens == 150 + + def test_finish_reason(self): + self.fake.completion.return_value = _Response( + content="ok", finish_reason="length" + ) + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + result = llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + assert result.finish_reason == "length" + + def test_none_content_becomes_empty_string(self): + self.fake.completion.return_value = _Response(content=None) + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + result = llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + assert result.content == "" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestLiteLLMErrorHandling: + def setup_method(self): + self.fake = _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def test_exception_returns_error_response(self): + self.fake.completion.side_effect = Exception("connection failed") + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + result = llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) assert "Error" in result.content + assert "connection failed" in result.content assert result.status == LLMResponseStatus.ERROR + def test_error_response_has_zero_tokens(self): + self.fake.completion.side_effect = Exception("fail") + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + result = llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + assert result.send_tokens == 0 + assert result.recv_tokens == 0 + assert result.total_tokens == 0 + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + class TestLiteLLMRegistration: def setup_method(self): @@ -206,6 +373,25 @@ def test_get_default_llm_returns_litellm(self): llm = get_default_llm() assert isinstance(llm, LiteLLM) + def test_litellm_not_default_when_openai_key_set(self): + """LiteLLM should only be selected when DEFAULT_LLM=litellm, not by key presence.""" + from director.llm.litellm import LiteLLM + + with mock.patch.dict( + "os.environ", + {"OPENAI_API_KEY": "sk-test", "DEFAULT_LLM": ""}, + clear=False, + ): + from director.llm import get_default_llm + + llm = get_default_llm() + assert not isinstance(llm, LiteLLM) + + +# --------------------------------------------------------------------------- +# Message formatting +# --------------------------------------------------------------------------- + class TestLiteLLMMessageFormatting: def setup_method(self): @@ -214,11 +400,23 @@ def setup_method(self): def teardown_method(self): _uninstall_fake_litellm() - def test_tool_call_messages_formatted(self): + def _make_llm(self): from director.llm.litellm import LiteLLMConfig, LiteLLM - config = LiteLLMConfig(chat_model="openai/gpt-4o", api_key="k") - llm = LiteLLM(config=config) + return LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + + def test_regular_messages_pass_through(self): + llm = self._make_llm() + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatted = llm._format_messages(messages) + assert formatted == messages + + def test_tool_call_messages_reformatted(self): + llm = self._make_llm() messages = [ { "role": "assistant", @@ -233,5 +431,99 @@ def test_tool_call_messages_formatted(self): } ] formatted = llm._format_messages(messages) - assert formatted[0]["tool_calls"][0]["function"]["name"] == "search" - assert "arguments" in formatted[0]["tool_calls"][0]["function"] + tc = formatted[0]["tool_calls"][0] + assert tc["function"]["name"] == "search" + assert json.loads(tc["function"]["arguments"]) == {"q": "test"} + assert tc["id"] == "tc1" + assert tc["type"] == "function" + + def test_tool_result_message_passes_through(self): + llm = self._make_llm() + messages = [ + {"role": "tool", "tool_call_id": "tc1", "content": '{"result": "ok"}'} + ] + formatted = llm._format_messages(messages) + assert formatted[0] == messages[0] + + +# --------------------------------------------------------------------------- +# Tool formatting +# --------------------------------------------------------------------------- + + +class TestLiteLLMToolFormatting: + def setup_method(self): + _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def test_tools_formatted_to_openai_spec(self): + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + tools = [ + { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + } + ] + formatted = llm._format_tools(tools) + assert len(formatted) == 1 + assert formatted[0]["type"] == "function" + assert formatted[0]["function"]["name"] == "get_weather" + assert formatted[0]["function"]["description"] == "Get weather for a city" + assert formatted[0]["function"]["parameters"]["properties"]["city"]["type"] == "string" + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +class TestLiteLLMConfig: + def setup_method(self): + _install_fake_litellm() + + def teardown_method(self): + _uninstall_fake_litellm() + + def test_default_config_values(self): + from director.llm.litellm import LiteLLMConfig + + config = LiteLLMConfig() + assert config.llm_type == "litellm" + assert config.chat_model == "openai/gpt-4o" + assert config.max_tokens == 4096 + assert config.api_key == "" + assert config.api_base == "" + + def test_config_reads_from_env(self): + from director.llm.litellm import LiteLLMConfig + + with mock.patch.dict( + "os.environ", + { + "LITELLM_CHAT_MODEL": "anthropic/claude-3-haiku", + "LITELLM_API_KEY": "sk-env-key", + "LITELLM_MAX_TOKENS": "8192", + }, + clear=False, + ): + config = LiteLLMConfig() + assert config.chat_model == "anthropic/claude-3-haiku" + assert config.api_key == "sk-env-key" + assert config.max_tokens == 8192 + + def test_config_no_api_key_required(self): + """Unlike OpenAI/GoogleAI configs, LiteLLM should not require api_key.""" + from director.llm.litellm import LiteLLMConfig, LiteLLM + + config = LiteLLMConfig(chat_model="openai/gpt-4o") + llm = LiteLLM(config=config) + assert llm.api_key == "" From a38e61ab525531c2878252791350f4ba793b5ab2 Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Fri, 24 Apr 2026 00:25:23 +0530 Subject: [PATCH 3/4] fix: forward top_p, defensive handling for None usage and empty tool args --- backend/director/llm/litellm.py | 53 +++++++++++++++++++-------------- backend/tests/test_litellm.py | 49 +++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 26 deletions(-) diff --git a/backend/director/llm/litellm.py b/backend/director/llm/litellm.py index 97436f40..4c9bb2a8 100644 --- a/backend/director/llm/litellm.py +++ b/backend/director/llm/litellm.py @@ -103,6 +103,7 @@ def chat_completions( "messages": self._format_messages(messages), "temperature": self.temperature, "max_tokens": self.max_tokens, + "top_p": self.top_p, "stop": stop, "timeout": self.timeout, "drop_params": True, @@ -120,28 +121,36 @@ def chat_completions( try: response = litellm.completion(**params) + + usage = getattr(response, "usage", None) + tool_calls = [] + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + args_raw = tool_call.function.arguments + try: + arguments = json.loads(args_raw) if args_raw else {} + except (json.JSONDecodeError, TypeError): + arguments = {} + tool_calls.append( + { + "id": tool_call.id, + "tool": { + "name": tool_call.function.name, + "arguments": arguments, + }, + "type": tool_call.type, + } + ) + + return LLMResponse( + content=response.choices[0].message.content or "", + tool_calls=tool_calls, + finish_reason=response.choices[0].finish_reason, + send_tokens=getattr(usage, "prompt_tokens", 0) or 0, + recv_tokens=getattr(usage, "completion_tokens", 0) or 0, + total_tokens=getattr(usage, "total_tokens", 0) or 0, + status=LLMResponseStatus.SUCCESS, + ) except Exception as e: print(f"Error: {e}") return LLMResponse(content=f"Error: {e}") - - return LLMResponse( - content=response.choices[0].message.content or "", - tool_calls=[ - { - "id": tool_call.id, - "tool": { - "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), - }, - "type": tool_call.type, - } - for tool_call in response.choices[0].message.tool_calls - ] - if response.choices[0].message.tool_calls - else [], - finish_reason=response.choices[0].finish_reason, - send_tokens=response.usage.prompt_tokens, - recv_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - status=LLMResponseStatus.SUCCESS, - ) diff --git a/backend/tests/test_litellm.py b/backend/tests/test_litellm.py index d31bbc2a..a658b7c3 100644 --- a/backend/tests/test_litellm.py +++ b/backend/tests/test_litellm.py @@ -178,12 +178,12 @@ def test_response_format_forwarded(self): call_kwargs = self.fake.completion.call_args[1] assert call_kwargs["response_format"] == rf - def test_top_p_not_sent(self): - """top_p is omitted to avoid conflicts across providers.""" - llm = self._make_llm() + def test_top_p_forwarded(self): + """top_p is forwarded; drop_params=True handles provider conflicts.""" + llm = self._make_llm(top_p=0.95) llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) call_kwargs = self.fake.completion.call_args[1] - assert "top_p" not in call_kwargs + assert call_kwargs["top_p"] == 0.95 # --------------------------------------------------------------------------- @@ -263,6 +263,32 @@ def test_no_tools_omits_tool_choice(self): assert "tools" not in call_kwargs assert "tool_choice" not in call_kwargs + def test_empty_tool_arguments_handled(self): + """Empty string arguments should not crash json.loads.""" + tc = _ToolCall("tc1", "ping", {}) + tc.function.arguments = "" + self.fake.completion.return_value = _Response(content="", tool_calls=[tc]) + + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "ping"}], + tools=[{"name": "ping", "description": "Ping", "parameters": {}}], + ) + assert result.tool_calls[0]["tool"]["arguments"] == {} + + def test_none_tool_arguments_handled(self): + """None arguments should not crash.""" + tc = _ToolCall("tc1", "ping", {}) + tc.function.arguments = None + self.fake.completion.return_value = _Response(content="", tool_calls=[tc]) + + llm = self._make_llm() + result = llm.chat_completions( + messages=[{"role": "user", "content": "ping"}], + tools=[{"name": "ping", "description": "Ping", "parameters": {}}], + ) + assert result.tool_calls[0]["tool"]["arguments"] == {} + # --------------------------------------------------------------------------- # Token usage and finish reason @@ -276,6 +302,21 @@ def setup_method(self): def teardown_method(self): _uninstall_fake_litellm() + def test_none_usage_returns_zero_tokens(self): + """Some providers return None for usage. Should not crash.""" + resp = _Response(content="ok") + resp.usage = None + self.fake.completion.return_value = resp + + from director.llm.litellm import LiteLLMConfig, LiteLLM + + llm = LiteLLM(config=LiteLLMConfig(chat_model="x", api_key="k")) + result = llm.chat_completions(messages=[{"role": "user", "content": "hi"}]) + assert result.send_tokens == 0 + assert result.recv_tokens == 0 + assert result.total_tokens == 0 + assert result.status == LLMResponseStatus.SUCCESS + def test_token_counts(self): self.fake.completion.return_value = _Response( content="ok", prompt_tokens=100, completion_tokens=50, total_tokens=150 From 56bc1fbdfa234830e2b602ad5de1ec5dab5f9605 Mon Sep 17 00:00:00 2001 From: Aarish Alam Date: Thu, 7 May 2026 01:19:28 +0530 Subject: [PATCH 4/4] fix: use logger instead of print, fix mutable default arg, isolate config tests --- backend/director/llm/litellm.py | 7 +++++-- backend/tests/test_litellm.py | 20 +++++++++++--------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/backend/director/llm/litellm.py b/backend/director/llm/litellm.py index 4c9bb2a8..44237516 100644 --- a/backend/director/llm/litellm.py +++ b/backend/director/llm/litellm.py @@ -1,4 +1,5 @@ import json +import logging from pydantic import Field from pydantic_settings import SettingsConfigDict @@ -6,6 +7,8 @@ from director.llm.base import BaseLLM, BaseLLMConfig, LLMResponse, LLMResponseStatus from director.constants import LLMType, EnvPrefix +logger = logging.getLogger(__name__) + class LiteLLMConfig(BaseLLMConfig): """LiteLLM Config. @@ -89,7 +92,7 @@ def _format_tools(self, tools: list): return formatted_tools def chat_completions( - self, messages: list, tools: list = [], stop=None, response_format=None + self, messages: list, tools: list | None = None, stop=None, response_format=None ): """Get chat completions via LiteLLM. @@ -152,5 +155,5 @@ def chat_completions( status=LLMResponseStatus.SUCCESS, ) except Exception as e: - print(f"Error: {e}") + logger.error("LiteLLM completion failed: %s", e) return LLMResponse(content=f"Error: {e}") diff --git a/backend/tests/test_litellm.py b/backend/tests/test_litellm.py index a658b7c3..cf22d035 100644 --- a/backend/tests/test_litellm.py +++ b/backend/tests/test_litellm.py @@ -537,12 +537,13 @@ def teardown_method(self): def test_default_config_values(self): from director.llm.litellm import LiteLLMConfig - config = LiteLLMConfig() - assert config.llm_type == "litellm" - assert config.chat_model == "openai/gpt-4o" - assert config.max_tokens == 4096 - assert config.api_key == "" - assert config.api_base == "" + with mock.patch.dict("os.environ", {}, clear=True): + config = LiteLLMConfig() + assert config.llm_type == "litellm" + assert config.chat_model == "openai/gpt-4o" + assert config.max_tokens == 4096 + assert config.api_key == "" + assert config.api_base == "" def test_config_reads_from_env(self): from director.llm.litellm import LiteLLMConfig @@ -565,6 +566,7 @@ def test_config_no_api_key_required(self): """Unlike OpenAI/GoogleAI configs, LiteLLM should not require api_key.""" from director.llm.litellm import LiteLLMConfig, LiteLLM - config = LiteLLMConfig(chat_model="openai/gpt-4o") - llm = LiteLLM(config=config) - assert llm.api_key == "" + with mock.patch.dict("os.environ", {}, clear=True): + config = LiteLLMConfig(chat_model="openai/gpt-4o") + llm = LiteLLM(config=config) + assert llm.api_key == ""