diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index de9993284d..77a0996884 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -33,6 +33,7 @@ def __init__(self):
         self.azure = False
         self.api_base = None
         self.repetition_penalty = None
+        self.ollama_api_key = None
         if get_settings().get("LITELLM.DISABLE_AIOHTTP", False):
            litellm.disable_aiohttp_transport = True
 
@@ -84,7 +85,7 @@ def __init__(self):
            litellm.api_base = get_settings().ollama.api_base
            self.api_base = get_settings().ollama.api_base
        if get_settings().get("OLLAMA.API_KEY", None):
-           litellm.api_key = get_settings().ollama.api_key
+           self.ollama_api_key = get_settings().ollama.api_key
        if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
@@ -412,6 +413,11 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
        if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY:
            kwargs["api_key"] = litellm.api_key
 
+       # Inject Ollama API key per-request instead of globally, to avoid
+       # polluting litellm.api_key which is shared across all providers.
+       if self.ollama_api_key and model.startswith(("ollama/", "ollama_chat/")):
+           kwargs["api_key"] = self.ollama_api_key
+
        # Get completion with automatic streaming detection
        resp, finish_reason, response_obj = await self._get_completion(**kwargs)
 
diff --git a/tests/unittest/test_litellm_api_key_guard.py b/tests/unittest/test_litellm_api_key_guard.py
index b451097ca7..43663077ee 100644
--- a/tests/unittest/test_litellm_api_key_guard.py
+++ b/tests/unittest/test_litellm_api_key_guard.py
@@ -277,14 +277,13 @@ async def test_xai_key_forwarded_for_non_ollama_model(self, monkeypatch):
 
     async def test_ollama_and_groq_coexist(self, monkeypatch):
         """Verify both Ollama and Groq keys can coexist and be forwarded correctly.
 
-        When multiple providers are configured, litellm.api_key gets overwritten
-        sequentially during __init__. The sentinel guard should still forward
-        whatever real key is currently in litellm.api_key.
+        Ollama key is stored on the handler instance (not globally) and only
+        injected for ollama/ models. Groq key stays in litellm.api_key and is
+        forwarded for non-Ollama models.
         """
         groq_key = "gsk-groq-key"
         ollama_key = "ollama-key"
-        # Simulate: Groq key set first, then Ollama overwrites litellm.api_key
         mixed_settings = type("Settings", (), {
             "config": type("Config", (), {
                 "reasoning_effort": None,
@@ -320,14 +319,15 @@ async def test_ollama_and_groq_coexist(self, monkeypatch):
             mock_call.return_value = _mock_response()
             handler = LiteLLMAIHandler()
 
-            # After init, litellm.api_key should be Ollama (last assignment)
-            assert litellm.api_key == ollama_key
+            # After init, litellm.api_key should be Groq (Ollama no longer pollutes global)
+            assert litellm.api_key == groq_key
+            # Ollama key stored on the instance
+            assert handler.ollama_api_key == ollama_key
 
-            # Call with Ollama model — should get Ollama key
+            # Call with Ollama model — should get Ollama key (per-request)
             await handler.chat_completion(model="ollama/mistral", system="sys", user="usr")
             assert mock_call.call_args[1]["api_key"] == ollama_key
 
-            # Call with non-Ollama model — should still forward the key
-            # (which is Ollama in this case, but the guard correctly allows real keys through)
+            # Call with non-Ollama model — should get Groq key (from litellm.api_key)
             await handler.chat_completion(model="gpt-4o", system="sys", user="usr")
-            assert mock_call.call_args[1]["api_key"] == ollama_key
+            assert mock_call.call_args[1]["api_key"] == groq_key