diff --git a/responses_api_models/vllm_model/app.py b/responses_api_models/vllm_model/app.py index a6709218fd..9d4c44e1e7 100644 --- a/responses_api_models/vllm_model/app.py +++ b/responses_api_models/vllm_model/app.py @@ -49,6 +49,16 @@ from nemo_gym.server_utils import SESSION_ID_KEY, is_nemo_gym_fastapi_entrypoint +CONTEXT_LENGTH_ERROR_SUBSTRINGS = ( + "context length", + "max_model_len", + "max model len", + "max_tokens", + "maximum context length", + "no room for output tokens", +) + + class VLLMModelConfig(BaseResponsesAPIModelConfig): base_url: Union[str, List[str]] api_key: str @@ -432,6 +442,22 @@ def _preprocess_chat_completion_create_params(self, request: Request, body_dict: return body_dict + @staticmethod + def _is_context_length_error(error: ClientResponseError) -> bool: + if error.status != 400: + return False + + response_content = getattr(error, "response_content", b"") + if isinstance(response_content, bytes): + response_content_text = response_content.decode(errors="replace") + elif response_content is None: + response_content_text = "" + else: + response_content_text = str(response_content) + + error_text = f"{error.message} {response_content_text}".lower() + return any(substring in error_text for substring in CONTEXT_LENGTH_ERROR_SUBSTRINGS) + async def chat_completions( self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body() ) -> NeMoGymChatCompletion: @@ -461,12 +487,7 @@ async def chat_completions( 3. https://github.com/vllm-project/vllm/blob/685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec/vllm/entrypoints/openai/serving_engine.py#L948 4. https://github.com/vllm-project/vllm/blob/685c99ee77b4818dcdd15b30fe0e0eff0d5d22ec/vllm/sampling_params.py#L463 """ - result_content_str = e.response_content.decode() - - is_out_of_context_length = e.status == 400 and ( - "context length" in result_content_str or "max_tokens" in result_content_str - ) - if is_out_of_context_length: + if self._is_context_length_error(e): res = self._create_empty_chat_completion() res.choices[0].finish_reason = "length" return res @@ -529,7 +550,14 @@ async def chat_completions( tokenize_body_dict[key] = body_dict[key] # The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the .. - tokenize_response = await client.create_tokenize(**tokenize_body_dict) + try: + tokenize_response = await client.create_tokenize(**tokenize_body_dict) + except ClientResponseError as e: + if self._is_context_length_error(e): + res = self._create_empty_chat_completion() + res.choices[0].finish_reason = "length" + return res + raise """ END """ diff --git a/responses_api_models/vllm_model/tests/test_app.py b/responses_api_models/vllm_model/tests/test_app.py index 176a5b682a..b38c73559b 100644 --- a/responses_api_models/vllm_model/tests/test_app.py +++ b/responses_api_models/vllm_model/tests/test_app.py @@ -16,6 +16,7 @@ from typing import Any, Union from unittest.mock import AsyncMock, MagicMock +from aiohttp.client_exceptions import ClientResponseError from fastapi.testclient import TestClient from pytest import MonkeyPatch, mark, raises @@ -52,7 +53,7 @@ NeMoGymResponseReasoningItem, NeMoGymSummary, ) -from nemo_gym.server_utils import ServerClient +from nemo_gym.server_utils import SESSION_ID_KEY, ServerClient from responses_api_models.vllm_model.app import ( VLLMConverter, VLLMModel, @@ -65,6 +66,18 @@ FIXED_UUID = "123" +def _make_client_response_error(status: int, content: bytes) -> ClientResponseError: + error = ClientResponseError( + request_info=MagicMock(), + history=(), + status=status, + message="Bad Request", + headers=None, + ) + error.response_content = content + return error + + class FakeUUID: """Used for mocking UUIDs""" @@ -661,7 +674,7 @@ class FakeUUID: class TestApp: - def _setup_server(self, monkeypatch: MonkeyPatch): + def _setup_server(self, monkeypatch: MonkeyPatch, return_token_id_information: bool = False): config = VLLMModelConfig( host="0.0.0.0", port=8081, @@ -670,7 +683,7 @@ def _setup_server(self, monkeypatch: MonkeyPatch): model="dummy_model", entrypoint="", name="", - return_token_id_information=False, + return_token_id_information=return_token_id_information, uses_reasoning_parser=False, ) @@ -683,6 +696,76 @@ def _setup_server(self, monkeypatch: MonkeyPatch): async def test_sanity(self, monkeypatch: MonkeyPatch) -> None: self._setup_server(monkeypatch) + async def test_chat_completions_converts_max_model_len_400_to_length( + self, monkeypatch: MonkeyPatch + ) -> None: + server = self._setup_server(monkeypatch) + mock_client = MagicMock(spec=NeMoGymAsyncOpenAI) + mock_client.create_chat_completion = AsyncMock( + side_effect=_make_client_response_error( + 400, + b'{"error":{"message":"Prompt length (8192) fills or exceeds max_model_len (8192). ' + b'No room for output tokens."}}', + ) + ) + server._clients = [mock_client] + + result = await server.chat_completions( + MagicMock(session={SESSION_ID_KEY: "test-session"}), + NeMoGymChatCompletionCreateParamsNonStreaming( + messages=[{"role": "user", "content": "hello"}] + ), + ) + + assert result.choices[0].finish_reason == "length" + + async def test_tokenize_context_length_400_returns_length( + self, monkeypatch: MonkeyPatch + ) -> None: + server = self._setup_server(monkeypatch, return_token_id_information=True) + mock_client = MagicMock(spec=NeMoGymAsyncOpenAI) + mock_client.create_chat_completion = AsyncMock( + return_value={ + "id": "chtcmpl-123", + "object": "chat.completion", + "created": FIXED_TIME, + "model": "dummy_model", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"role": "assistant", "content": "hello"}, + "logprobs": { + "content": [ + { + "token": "token_id:1", + "bytes": None, + "logprob": -0.1, + "top_logprobs": [], + } + ] + }, + } + ], + } + ) + mock_client.create_tokenize = AsyncMock( + side_effect=_make_client_response_error( + 400, + b'{"error":{"message":"This model\'s maximum context length is 8192 tokens."}}', + ) + ) + server._clients = [mock_client] + + result = await server.chat_completions( + MagicMock(session={SESSION_ID_KEY: "test-session"}), + NeMoGymChatCompletionCreateParamsNonStreaming( + messages=[{"role": "user", "content": "hello"}] + ), + ) + + assert result.choices[0].finish_reason == "length" + def test_responses_multistep(self, monkeypatch: MonkeyPatch): server = self._setup_server(monkeypatch) app = server.setup_webserver()