1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@
.venv/
venv/
pr_agent/settings/.secrets.toml
pr_agent/settings_prod/.secrets.toml
__pycache__
dist/
*.egg-info/
48 changes: 46 additions & 2 deletions pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -152,6 +152,25 @@ def __init__(self):

# Models that require streaming
self.streaming_required_models = STREAMING_REQUIRED_MODELS
self.force_streaming_provider = str(
getattr(get_settings().litellm, "force_streaming_custom_llm_provider", "") or ""
).strip().lower()
raw_force_streaming_api_base_substrings = getattr(
get_settings().litellm, "force_streaming_api_base_substrings", []
)
if isinstance(raw_force_streaming_api_base_substrings, (list, tuple, set)):
self.force_streaming_api_base_substrings = [
str(value).strip().lower()
for value in raw_force_streaming_api_base_substrings
if value is not None and str(value).strip()
]
else:
if raw_force_streaming_api_base_substrings:
get_logger().warning(
"LITELLM.FORCE_STREAMING_API_BASE_SUBSTRINGS must be a list, tuple, or set. "
"Ignoring invalid value."
)
self.force_streaming_api_base_substrings = []

def prepare_logs(self, response, system, user, resp, finish_reason):
response_log = response.dict().copy()
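
The normalization in the __init__ hunk above can be exercised in isolation. A minimal sketch, assuming example setting values; the helper name normalize_substrings is hypothetical, and the filtering mirrors the hunk:

def normalize_substrings(raw):
    # Accept only list/tuple/set; strip, lowercase, and drop None or
    # empty entries, as the __init__ code above does.
    if isinstance(raw, (list, tuple, set)):
        return [
            str(value).strip().lower()
            for value in raw
            if value is not None and str(value).strip()
        ]
    # Any other truthy value (e.g., a bare string) is warned about and ignored.
    return []

assert normalize_substrings(["  SnowflakeComputing.com ", None, ""]) == ["snowflakecomputing.com"]
assert normalize_substrings("snowflakecomputing.com") == []  # a bare string is not a collection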
@@ -395,6 +414,12 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
# Support for custom OpenAI body fields (e.g., Flex Processing)
kwargs = _process_litellm_extra_body(kwargs)

custom_llm_provider = str(
getattr(get_settings().litellm, "custom_llm_provider", "") or ""
).strip().lower()
if custom_llm_provider:
kwargs["custom_llm_provider"] = custom_llm_provider

# Support for Bedrock custom inference profile via model_id
model_id = get_settings().get("litellm.model_id")
if model_id and 'bedrock/' in model:
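
For illustration, the forwarding rule in the hunk above reduces to: normalize the setting, then attach it to kwargs only when non-empty. A short sketch with a hypothetical helper name and example values:

def forward_provider(kwargs, provider_setting):
    # Mirror chat_completion: strip + lowercase, and omit the key entirely
    # when the setting is empty or unset.
    provider = str(provider_setting or "").strip().lower()
    if provider:
        kwargs["custom_llm_provider"] = provider
    return kwargs

assert forward_provider({}, " OpenAI ") == {"custom_llm_provider": "openai"}
assert forward_provider({}, "") == {}    # empty setting: key omitted
assert forward_provider({}, None) == {}  # unset setting: key omitted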
@@ -442,9 +467,28 @@ async def _get_completion(self, **kwargs):
Wrapper that automatically handles streaming for required models.
"""
model = kwargs["model"]
custom_llm_provider = str(kwargs.get("custom_llm_provider") or "").strip().lower()
api_base_value = kwargs.get("api_base")
api_base = api_base_value.strip().lower() if isinstance(api_base_value, str) else ""
force_streaming = (
bool(self.force_streaming_provider)
and custom_llm_provider == self.force_streaming_provider
and bool(self.force_streaming_api_base_substrings)
and any(substring in api_base for substring in self.force_streaming_api_base_substrings)
)

# Some OpenAI-compatible endpoints can return an empty-string
# finish_reason on non-streaming responses, which LiteLLM rejects during
# response normalization. Streaming avoids that conversion path.
if model in self.streaming_required_models or force_streaming:
kwargs["stream"] = True
if force_streaming and model not in self.streaming_required_models:
get_logger().info(
f"Using streaming mode for model {model} "
"due to OpenAI-compatible endpoint compatibility"
)
else:
get_logger().info(f"Using streaming mode for model {model}")
response = await acompletion(**kwargs)
resp, finish_reason = await _handle_streaming_response(response)
# Create MockResponse for streaming since we don't have the full response object
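
Taken together, the force-streaming decision in _get_completion is a conjunction of four conditions. A standalone sketch, with a hypothetical helper name and example values taken from the tests below:

def should_force_streaming(custom_llm_provider, api_base, force_provider, substrings):
    # All four conditions must hold, matching the force_streaming expression above.
    provider = str(custom_llm_provider or "").strip().lower()
    base = api_base.strip().lower() if isinstance(api_base, str) else ""
    return (
        bool(force_provider)
        and provider == force_provider
        and bool(substrings)
        and any(substring in base for substring in substrings)
    )

# Provider matches and the api_base contains a configured substring: force streaming.
assert should_force_streaming(
    " OpenAI ",
    "https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
    "openai",
    ["snowflakecomputing.com"],
)
# A non-string api_base normalizes to "" and can never match a substring.
assert not should_force_streaming("openai", 123, "openai", ["snowflakecomputing.com"])
# An empty force_streaming_custom_llm_provider disables the feature entirely.
assert not should_force_streaming("openai", "https://example-account.snowflakecomputing.com", "", ["snowflakecomputing.com"])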
3 changes: 3 additions & 0 deletions pr_agent/settings/configuration.toml
@@ -323,6 +323,9 @@ enable_callbacks = false
success_callback = []
failure_callback = []
service_callback = []
custom_llm_provider = ""
force_streaming_custom_llm_provider = ""
force_streaming_api_base_substrings = []
# model_id = "" # Optional: Custom inference profile ID for Amazon Bedrock

[pr_similar_issue]
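
As a purely illustrative example (not part of this diff), the three new keys might be set together to force streaming for an OpenAI-compatible endpoint; the values mirror the tests below and are hypothetical:

[litellm]
custom_llm_provider = "openai"
force_streaming_custom_llm_provider = "openai"
force_streaming_api_base_substrings = ["snowflakecomputing.com"]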
306 changes: 306 additions & 0 deletions tests/unittest/test_litellm_custom_provider.py
@@ -0,0 +1,306 @@
from unittest.mock import AsyncMock, patch

import pytest

import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler


def create_mock_settings(
custom_llm_provider=None,
force_streaming_custom_llm_provider="openai",
force_streaming_api_base_substrings=None,
):
if force_streaming_api_base_substrings is None:
force_streaming_api_base_substrings = ["snowflakecomputing.com"]

litellm_settings = type("", (), {"get": lambda self, key, default=None: default})()
if custom_llm_provider is not None:
litellm_settings.custom_llm_provider = custom_llm_provider
litellm_settings.force_streaming_custom_llm_provider = force_streaming_custom_llm_provider
litellm_settings.force_streaming_api_base_substrings = force_streaming_api_base_substrings

def get_value(key, default=None):
values = {
"LITELLM.CUSTOM_LLM_PROVIDER": custom_llm_provider,
"litellm.custom_llm_provider": custom_llm_provider,
"LITELLM.FORCE_STREAMING_CUSTOM_LLM_PROVIDER": force_streaming_custom_llm_provider,
"litellm.force_streaming_custom_llm_provider": force_streaming_custom_llm_provider,
"LITELLM.FORCE_STREAMING_API_BASE_SUBSTRINGS": force_streaming_api_base_substrings,
"litellm.force_streaming_api_base_substrings": force_streaming_api_base_substrings,
}
return values.get(key, default)

return type(
"",
(),
{
"config": type(
"",
(),
{
"ai_timeout": 120,
"custom_reasoning_model": False,
"verbosity_level": 0,
"get": lambda self, key, default=None: default,
},
)(),
"litellm": litellm_settings,
"get": staticmethod(get_value),
},
)()


def create_mock_acompletion_response():
response_payload = {
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}]
}

class MockCompletionResponse(dict):
def dict(self):
return dict(self)

return MockCompletionResponse(response_payload)


@pytest.mark.asyncio
async def test_custom_llm_provider_is_forwarded_without_rewriting_model(monkeypatch):
fake_settings = create_mock_settings(" OpenAI ")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()

handler = LiteLLMAIHandler()
await handler.chat_completion(
model="claude-sonnet-4-5",
system="test system",
user="test user",
)

call_kwargs = mock_completion.call_args[1]
assert call_kwargs["model"] == "claude-sonnet-4-5"
assert call_kwargs["custom_llm_provider"] == "openai"


@pytest.mark.asyncio
async def test_custom_llm_provider_is_omitted_when_unset(monkeypatch):
fake_settings = create_mock_settings()
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()

handler = LiteLLMAIHandler()
await handler.chat_completion(
model="claude-sonnet-4-5",
system="test system",
user="test user",
)

call_kwargs = mock_completion.call_args[1]
assert "custom_llm_provider" not in call_kwargs


@pytest.mark.asyncio
async def test_openai_compatible_endpoint_calls_force_streaming(monkeypatch):
fake_settings = create_mock_settings("openai")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with (
patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion,
patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler._handle_streaming_response",
new_callable=AsyncMock,
) as mock_stream_handler,
):
mock_stream_handler.return_value = ("test", "stop")
handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
custom_llm_provider="openai",
)

call_kwargs = mock_completion.call_args[1]
assert call_kwargs["stream"] is True


@pytest.mark.asyncio
async def test_openai_compatible_endpoint_normalizes_custom_provider_for_streaming(monkeypatch):
fake_settings = create_mock_settings(" OpenAI ")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with (
patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion,
patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler._handle_streaming_response",
new_callable=AsyncMock,
) as mock_stream_handler,
):
mock_stream_handler.return_value = ("test", "stop")
handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
custom_llm_provider=" OpenAI ",
)

call_kwargs = mock_completion.call_args[1]
assert call_kwargs["stream"] is True


@pytest.mark.asyncio
async def test_openai_compatible_endpoint_ignores_non_string_api_base(monkeypatch):
fake_settings = create_mock_settings("openai")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()

handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base=123,
custom_llm_provider="openai",
)

call_kwargs = mock_completion.call_args[1]
assert "stream" not in call_kwargs


@pytest.mark.asyncio
async def test_force_streaming_is_settings_driven(monkeypatch):
fake_settings = create_mock_settings(
"openai",
force_streaming_custom_llm_provider="openai",
force_streaming_api_base_substrings=["example-gateway.local"],
)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()

handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
custom_llm_provider="openai",
)

call_kwargs = mock_completion.call_args[1]
assert "stream" not in call_kwargs


@pytest.mark.asyncio
async def test_force_streaming_requires_non_empty_provider_setting(monkeypatch):
fake_settings = create_mock_settings(
"openai",
force_streaming_custom_llm_provider="",
force_streaming_api_base_substrings=["snowflakecomputing.com"],
)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()

handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
custom_llm_provider="",
)

call_kwargs = mock_completion.call_args[1]
assert "stream" not in call_kwargs


@pytest.mark.asyncio
async def test_force_streaming_ignores_non_collection_substring_setting(monkeypatch):
fake_settings = create_mock_settings(
"openai",
force_streaming_custom_llm_provider="openai",
force_streaming_api_base_substrings="snowflakecomputing.com",
)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()

handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
custom_llm_provider="openai",
)

call_kwargs = mock_completion.call_args[1]
assert "stream" not in call_kwargs


@pytest.mark.asyncio
async def test_force_streaming_warns_on_invalid_substring_setting(monkeypatch):
fake_settings = create_mock_settings(
"openai",
force_streaming_custom_llm_provider="openai",
force_streaming_api_base_substrings="snowflakecomputing.com",
)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)

with (
patch(
"pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock,
) as mock_completion,
patch("pr_agent.algo.ai_handlers.litellm_ai_handler.get_logger") as mock_logger,
):
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
await handler._get_completion(
model="claude-sonnet-4-5",
messages=[],
timeout=120,
api_base="https://example-account.snowflakecomputing.com/api/v2/cortex/v1",
custom_llm_provider="openai",
)

mock_logger.return_value.warning.assert_called_once_with(
"LITELLM.FORCE_STREAMING_API_BASE_SUBSTRINGS must be a list, tuple, or set. "
"Ignoring invalid value."
)