pr_agent/algo/ai_handlers/litellm_ai_handler.py (7 additions, 1 deletion)

@@ -33,6 +33,7 @@ def __init__(self):
         self.azure = False
         self.api_base = None
         self.repetition_penalty = None
+        self.ollama_api_key = None
 
         if get_settings().get("LITELLM.DISABLE_AIOHTTP", False):
             litellm.disable_aiohttp_transport = True
@@ -84,7 +85,7 @@ def __init__(self):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
             if get_settings().get("OLLAMA.API_KEY", None):
-                litellm.api_key = get_settings().ollama.api_key
+                self.ollama_api_key = get_settings().ollama.api_key
         if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
             self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
         if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
@@ -412,6 +413,11 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
         if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY:
             kwargs["api_key"] = litellm.api_key
 
+        # Inject Ollama API key per-request instead of globally, to avoid
+        # polluting litellm.api_key which is shared across all providers.
+        if self.ollama_api_key and model.startswith(("ollama/", "ollama_chat/")):
+            kwargs["api_key"] = self.ollama_api_key
+
         # Get completion with automatic streaming detection
         resp, finish_reason, response_obj = await self._get_completion(**kwargs)
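Taken out of the diff, the change amounts to the following pattern. This is a minimal standalone sketch, not the real pr-agent handler: the class name, constructor, and build_kwargs helper are invented for illustration, and the sentinel value is a stand-in, while the key-guard condition and the ollama/ prefix check mirror the lines above.

import litellm

DUMMY_LITELLM_API_KEY = "dummy-key"  # stand-in; the real sentinel constant lives in pr-agent

class Handler:  # hypothetical stand-in for LiteLLMAIHandler
    def __init__(self, ollama_api_key=None):
        # The Ollama key lives on the instance, not in litellm.api_key,
        # so it cannot clobber keys that other providers set globally.
        self.ollama_api_key = ollama_api_key

    def build_kwargs(self, model: str) -> dict:  # hypothetical helper
        kwargs = {"model": model}
        # Forward the global key only when it is a real key, not the sentinel.
        if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY:
            kwargs["api_key"] = litellm.api_key
        # Override with the Ollama key only for Ollama models.
        if self.ollama_api_key and model.startswith(("ollama/", "ollama_chat/")):
            kwargs["api_key"] = self.ollama_api_key
        return kwargs

# With a Groq key in the global slot, an Ollama model still gets its own key,
# while every other model keeps the global one.
litellm.api_key = "gsk-groq-key"
handler = Handler(ollama_api_key="ollama-key")
assert handler.build_kwargs("ollama/mistral")["api_key"] == "ollama-key"
assert handler.build_kwargs("gpt-4o")["api_key"] == "gsk-groq-key"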
tests/unittest/test_litellm_api_key_guard.py (10 additions, 10 deletions)

@@ -277,14 +277,13 @@ async def test_xai_key_forwarded_for_non_ollama_model(self, monkeypatch):
     async def test_ollama_and_groq_coexist(self, monkeypatch):
         """Verify both Ollama and Groq keys can coexist and be forwarded correctly.
 
-        When multiple providers are configured, litellm.api_key gets overwritten
-        sequentially during __init__. The sentinel guard should still forward
-        whatever real key is currently in litellm.api_key.
+        Ollama key is stored on the handler instance (not globally) and only
+        injected for ollama/ models. Groq key stays in litellm.api_key and is
+        forwarded for non-Ollama models.
         """
         groq_key = "gsk-groq-key"
         ollama_key = "ollama-key"
 
-        # Simulate: Groq key set first, then Ollama overwrites litellm.api_key
         mixed_settings = type("Settings", (), {
             "config": type("Config", (), {
                 "reasoning_effort": None,
@@ -320,14 +319,15 @@ async def test_ollama_and_groq_coexist(self, monkeypatch):
             mock_call.return_value = _mock_response()
             handler = LiteLLMAIHandler()
 
-            # After init, litellm.api_key should be Ollama (last assignment)
-            assert litellm.api_key == ollama_key
+            # After init, litellm.api_key should be Groq (Ollama no longer pollutes global)
+            assert litellm.api_key == groq_key
+            # Ollama key stored on the instance
+            assert handler.ollama_api_key == ollama_key
 
-            # Call with Ollama model — should get Ollama key
+            # Call with Ollama model — should get Ollama key (per-request)
             await handler.chat_completion(model="ollama/mistral", system="sys", user="usr")
             assert mock_call.call_args[1]["api_key"] == ollama_key
 
-            # Call with non-Ollama model — should still forward the key
-            # (which is Ollama in this case, but the guard correctly allows real keys through)
+            # Call with non-Ollama model — should get Groq key (from litellm.api_key)
             await handler.chat_completion(model="gpt-4o", system="sys", user="usr")
-            assert mock_call.call_args[1]["api_key"] == ollama_key
+            assert mock_call.call_args[1]["api_key"] == groq_key
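One pattern in the test worth noting for readers who have not seen it: mixed_settings is built with type(name, bases, attrs), which creates a throwaway class whose class attributes emulate nested config values. A self-contained sketch, with illustrative field names only (the real test builds a much larger object):

# type() with three arguments builds a class on the fly; instantiating it
# yields an object whose attribute access behaves like a hand-written
# settings object, which monkeypatch can then swap in for get_settings().
fake_settings = type("Settings", (), {
    "config": type("Config", (), {
        "reasoning_effort": None,
    }),
})()

assert fake_settings.config.reasoning_effort is None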