diff --git a/responses_api_models/inference_provider/app.py b/responses_api_models/inference_provider/app.py index 7f8fe13b6..943f0f083 100644 --- a/responses_api_models/inference_provider/app.py +++ b/responses_api_models/inference_provider/app.py @@ -20,12 +20,14 @@ For training workloads that require token IDs, use vllm_model instead. """ +import json from asyncio import Semaphore from time import time from typing import Any, Dict from uuid import uuid4 -from fastapi import Request +from aiohttp.client_exceptions import ClientResponseError +from fastapi import HTTPException, Request from pydantic import Field from nemo_gym.base_responses_api_model import ( @@ -59,6 +61,7 @@ class InferenceProviderConfig(BaseResponsesAPIModelConfig): class InferenceProvider(SimpleResponsesAPIModel): config: InferenceProviderConfig + _RETRYABLE_PROVIDER_STATUSES = {429, 500, 502, 503, 504, 520} def model_post_init(self, context): self._client = NeMoGymAsyncOpenAI( @@ -145,9 +148,14 @@ async def chat_completions( if isinstance(content, str): _, remaining_content = self._converter._extract_reasoning_from_content(content) message_dict["content"] = remaining_content - async with self._semaphore: - chat_completion_dict = await self._client.create_chat_completion(**body_dict) + try: + chat_completion_dict = await self._client.create_chat_completion(**body_dict) + except ClientResponseError as e: + normalized_payload = self._build_provider_error_payload(e) + raise HTTPException( + status_code=normalized_payload["provider_status"], detail=normalized_payload + ) from e choice_dict = chat_completion_dict["choices"][0] if self.config.uses_reasoning_parser: @@ -163,6 +171,76 @@ async def chat_completions( return NeMoGymChatCompletion.model_validate(chat_completion_dict) + def _build_provider_error_payload(self, error: ClientResponseError) -> Dict[str, Any]: + provider_status = error.status if error.status else 500 + message = self._extract_provider_error_message(error) + category = self._classify_provider_error(provider_status, message) + return { + "provider_status": provider_status, + "retryable": provider_status in self._RETRYABLE_PROVIDER_STATUSES, + "provider_context": {"base_url": self.config.base_url}, + "model": self.config.model, + "category": category, + "message": message, + } + + def _classify_provider_error(self, status: int, message: str) -> str: + message_lower = message.lower() + if status in {401, 403} or "api key" in message_lower or "auth" in message_lower: + return "authentication" + if status == 404 or ("model" in message_lower and "not found" in message_lower): + return "model_not_found" + if status == 429 or "rate limit" in message_lower: + return "rate_limit" + if status in {400, 422}: + return "request_error" + if status in self._RETRYABLE_PROVIDER_STATUSES: + return "transient_upstream_failure" + return "provider_error" + + def _extract_provider_error_message(self, error: ClientResponseError) -> str: + response_content = getattr(error, "response_content", b"") + if isinstance(response_content, bytes): + response_text = response_content.decode("utf-8", errors="replace").strip() + elif response_content: + response_text = str(response_content).strip() + else: + response_text = str(error) + + parsed_message = response_text + if response_text: + parsed_message = self._extract_error_message_from_response(response_text) + + if parsed_message: + return self._concise(parsed_message) + return "Provider request failed" + + @staticmethod + def _extract_error_message_from_response(response_text: str) -> str: + try: + payload = json.loads(response_text) + except json.JSONDecodeError: + return response_text + + if isinstance(payload, dict): + if isinstance(payload.get("error"), dict): + nested_error = payload["error"] + if nested_error.get("message"): + return str(nested_error["message"]) + + if payload.get("message"): + return str(payload["message"]) + + if payload.get("detail"): + return str(payload["detail"]) + + return response_text + + @staticmethod + def _concise(message: str) -> str: + compact = " ".join(message.strip().split()) + return compact if len(compact) <= 200 else compact[:197] + "..." + if __name__ == "__main__": InferenceProvider.run_webserver() diff --git a/responses_api_models/inference_provider/tests/test_app.py b/responses_api_models/inference_provider/tests/test_app.py index b95b6395a..da9ba47a9 100644 --- a/responses_api_models/inference_provider/tests/test_app.py +++ b/responses_api_models/inference_provider/tests/test_app.py @@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from unittest.mock import AsyncMock, MagicMock +from aiohttp.client_exceptions import ClientResponseError from fastapi.testclient import TestClient from pytest import MonkeyPatch @@ -72,6 +74,46 @@ def _mock_chat_response(content="Hello!", finish_reason="stop", tool_calls=None, return response +def _provider_error(status: int, message: str, response_body: str | None = None) -> ClientResponseError: + error = ClientResponseError(MagicMock(), (), status=status, message="provider request failed", headers=None) + error.message = message + payload = response_body if response_body is not None else message + error.response_content = payload.encode("utf-8") + return error + + +class _FakeResponseContent: + def __init__(self, response_body: str) -> None: + self._response_body = response_body.encode("utf-8") + self._consumed = False + + async def read(self) -> bytes: + if self._consumed: + return b"" + self._consumed = True + return self._response_body + + +class _FakeRetryResponse: + def __init__(self, status: int, response_body: str) -> None: + self.status = status + self.ok = status < 400 + self.content = _FakeResponseContent(response_body) + self.request_info = MagicMock() + self.request_info.real_url = "https://api.example.com/v1/chat/completions" + + def raise_for_status(self) -> None: + if self.ok: + return + raise ClientResponseError( + self.request_info, + (), + status=self.status, + message="provider request failed", + headers=None, + ) + + class TestSanity: async def test_server_instantiation(self) -> None: server = _make_server() @@ -257,6 +299,252 @@ async def mock_create_chat(**kwargs): assert "reasoning_content" not in data["choices"][0]["message"] +class TestProviderErrors: + async def test_chat_completion_provider_auth_error_is_structured(self, monkeypatch: MonkeyPatch) -> None: + server = _make_server() + app = server.setup_webserver() + client = TestClient(app) + + auth_error_content = json.dumps({"error": {"message": "Invalid API key", "type": "authentication_error"}}) + server._client = MagicMock(spec=NeMoGymAsyncOpenAI) + server._client.create_chat_completion = AsyncMock( + side_effect=_provider_error(401, "Authentication failed", auth_error_content) + ) + + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "whoami"}]}, + ) + assert response.status_code == 401 + detail = response.json()["detail"] + assert detail["provider_status"] == 401 + assert detail["category"] == "authentication" + assert detail["retryable"] is False + assert detail["provider_context"]["base_url"] == "https://api.example.com/v1" + assert detail["model"] == "test-model" + assert detail["message"] == "Invalid API key" + + async def test_chat_completion_provider_request_error_is_structured(self, monkeypatch: MonkeyPatch) -> None: + server = _make_server() + app = server.setup_webserver() + client = TestClient(app) + + request_error_content = json.dumps({"error": {"message": "Missing required field: messages"}}) + server._client = MagicMock(spec=NeMoGymAsyncOpenAI) + server._client.create_chat_completion = AsyncMock( + side_effect=_provider_error(400, "Bad request", request_error_content) + ) + + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "bad payload"}]}, + ) + assert response.status_code == 400 + detail = response.json()["detail"] + assert detail["provider_status"] == 400 + assert detail["category"] == "request_error" + assert detail["retryable"] is False + assert detail["message"] == "Missing required field: messages" + + async def test_chat_completion_provider_status_zero_falls_back_to_500(self, monkeypatch: MonkeyPatch) -> None: + server = _make_server() + app = server.setup_webserver() + client = TestClient(app) + + server._client = MagicMock(spec=NeMoGymAsyncOpenAI) + server._client.create_chat_completion = AsyncMock( + side_effect=_provider_error(0, "Connection closed by upstream", "Connection closed by upstream") + ) + + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert response.status_code == 500 + detail = response.json()["detail"] + assert detail["provider_status"] == 500 + assert detail["category"] == "transient_upstream_failure" + assert detail["retryable"] is True + assert detail["message"] == "Connection closed by upstream" + + async def test_chat_completion_provider_retry_exhausted_500_uses_default_message( + self, monkeypatch: MonkeyPatch + ) -> None: + server = _make_server() + app = server.setup_webserver() + client = TestClient(app) + + response_body = json.dumps({"error": {"message": "Upstream provider unavailable"}}) + retry_responses = [_FakeRetryResponse(500, response_body) for _ in range(3)] + + async def mock_request(**kwargs): + return retry_responses.pop(0) + + async def mock_sleep(seconds: float) -> None: + return None + + monkeypatch.setattr("nemo_gym.openai_utils.request", mock_request) + monkeypatch.setattr("nemo_gym.openai_utils.sleep", mock_sleep) + server._client = NeMoGymAsyncOpenAI(base_url=server.config.base_url, api_key=server.config.api_key) + + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert response.status_code == 500 + detail = response.json()["detail"] + assert detail["provider_status"] == 500 + assert detail["category"] == "transient_upstream_failure" + assert detail["retryable"] is True + assert detail["message"] == "Provider request failed" + + async def test_chat_completion_provider_error_message_is_truncated_and_compacted( + self, monkeypatch: MonkeyPatch + ) -> None: + server = _make_server() + app = server.setup_webserver() + client = TestClient(app) + + long_message = " ".join(["very-long-provider-error-message"] * 20) + error_content = json.dumps({"error": {"message": long_message}}) + server._client = MagicMock(spec=NeMoGymAsyncOpenAI) + server._client.create_chat_completion = AsyncMock( + side_effect=_provider_error(418, "Provider said no", error_content) + ) + + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert response.status_code == 418 + detail = response.json()["detail"] + assert detail["provider_status"] == 418 + assert detail["category"] == "provider_error" + assert detail["retryable"] is False + assert len(detail["message"]) == 200 + assert detail["message"].endswith("...") + + async def test_responses_provider_error_uses_same_structured_contract(self, monkeypatch: MonkeyPatch) -> None: + server = _make_server() + app = server.setup_webserver() + client = TestClient(app) + + model_error_content = json.dumps({"error": {"message": "Model test-model does not exist"}}) + server._client = MagicMock(spec=NeMoGymAsyncOpenAI) + server._client.create_chat_completion = AsyncMock( + side_effect=_provider_error(404, "Model not found", model_error_content) + ) + + response = client.post( + "/v1/responses", + json={"input": "whoami"}, + ) + assert response.status_code == 404 + detail = response.json()["detail"] + assert detail["provider_status"] == 404 + assert detail["category"] == "model_not_found" + assert detail["retryable"] is False + assert detail["provider_context"]["base_url"] == "https://api.example.com/v1" + assert detail["message"] == "Model test-model does not exist" + + +class TestProviderErrorHelpers: + def test_build_provider_error_payload_marks_rate_limit_as_retryable(self) -> None: + server = _make_server() + error = _provider_error(429, "Rate limit exceeded", json.dumps({"error": {"message": "Rate limit exceeded"}})) + + payload = server._build_provider_error_payload(error) + + assert payload["provider_status"] == 429 + assert payload["category"] == "rate_limit" + assert payload["retryable"] is True + assert payload["message"] == "Rate limit exceeded" + + def test_build_provider_error_payload_marks_transient_upstream_failures_as_retryable(self) -> None: + server = _make_server() + error = _provider_error( + 503, + "Service unavailable", + json.dumps({"error": {"message": "Upstream provider unavailable"}}), + ) + + payload = server._build_provider_error_payload(error) + + assert payload["provider_status"] == 503 + assert payload["category"] == "transient_upstream_failure" + assert payload["retryable"] is True + assert payload["message"] == "Upstream provider unavailable" + + def test_build_provider_error_payload_preserves_plain_text_provider_bodies(self) -> None: + server = _make_server() + plain_text_error = "gateway timeout while reading upstream response" + error = _provider_error(502, "Bad gateway", plain_text_error) + + payload = server._build_provider_error_payload(error) + + assert payload["provider_status"] == 502 + assert payload["category"] == "transient_upstream_failure" + assert payload["retryable"] is True + assert payload["message"] == plain_text_error + + def test_extract_provider_error_message_reads_top_level_message(self) -> None: + server = _make_server() + error = _provider_error(400, "Bad request", json.dumps({"message": "Top-level message"})) + + assert server._extract_provider_error_message(error) == "Top-level message" + + def test_extract_provider_error_message_reads_top_level_detail(self) -> None: + server = _make_server() + error = _provider_error(400, "Bad request", json.dumps({"detail": "Top-level detail"})) + + assert server._extract_provider_error_message(error) == "Top-level detail" + + def test_extract_provider_error_message_keeps_raw_json_when_shape_is_unknown(self) -> None: + server = _make_server() + response_body = json.dumps({"unexpected": "payload"}) + error = _provider_error(500, "Server error", response_body) + + assert server._extract_provider_error_message(error) == response_body + + def test_extract_provider_error_message_uses_string_response_content_when_non_bytes(self) -> None: + server = _make_server() + error = ClientResponseError(MagicMock(), (), status=502, message="provider request failed", headers=None) + error.response_content = "string-based provider body" + + assert server._extract_provider_error_message(error) == "string-based provider body" + + def test_extract_provider_error_message_falls_back_when_response_content_is_none(self) -> None: + server = _make_server() + request_info = MagicMock() + request_info.real_url = "https://api.example.com/v1/chat/completions" + error = ClientResponseError( + request_info, + (), + status=500, + message="provider request failed", + headers=None, + ) + error.response_content = None + + assert server._extract_provider_error_message(error) == str(error).strip() + + def test_classify_provider_error_uses_message_for_authentication(self) -> None: + server = _make_server() + + assert server._classify_provider_error(418, "api key invalid") == "authentication" + assert server._classify_provider_error(418, "auth token expired") == "authentication" + + def test_classify_provider_error_uses_message_for_model_not_found(self) -> None: + server = _make_server() + + assert server._classify_provider_error(418, "requested model was not found") == "model_not_found" + + def test_classify_provider_error_uses_message_for_rate_limit(self) -> None: + server = _make_server() + + assert server._classify_provider_error(418, "rate limit exceeded upstream") == "rate_limit" + + class TestResponses: async def test_basic_responses(self, monkeypatch: MonkeyPatch) -> None: server = _make_server()