diff --git a/responses_api_models/vllm_model/app.py b/responses_api_models/vllm_model/app.py index 03ff090f85..48cda7a042 100644 --- a/responses_api_models/vllm_model/app.py +++ b/responses_api_models/vllm_model/app.py @@ -129,6 +129,90 @@ def _post_init(self) -> None: self._converter = self.get_converter() + @staticmethod + def _coerce_token_id_list(value: Any, field_name: str) -> List[int]: + if not isinstance(value, list): + raise ValueError(f"Expected {field_name} to be a list.") + try: + return [int(token_id) for token_id in value] + except (TypeError, ValueError) as e: + raise ValueError(f"Expected {field_name} to contain only integer token IDs.") from e + + @staticmethod + def _extract_logprob_content(choice_dict: Dict[str, Any]) -> List[Dict[str, Any]]: + logprobs = choice_dict.get("logprobs") + if not isinstance(logprobs, dict): + raise ValueError("Token ID information requires choice.logprobs to be present.") + content = logprobs.get("content") + if not isinstance(content, list): + raise ValueError("Token ID information requires choice.logprobs.content to be a list.") + return content + + @staticmethod + def _generation_token_ids_from_logprobs(log_probs: List[Dict[str, Any]]) -> List[int]: + generation_token_ids: List[int] = [] + for log_prob in log_probs: + token = log_prob.get("token") + if not isinstance(token, str) or not token.startswith("token_id:"): + raise ValueError("Expected logprob token strings in the form `token_id:`.") + generation_token_ids.append(int(token.removeprefix("token_id:"))) + return generation_token_ids + + @staticmethod + def _generation_log_probs_from_logprobs(log_probs: List[Dict[str, Any]]) -> List[float]: + try: + return [float(log_prob["logprob"]) for log_prob in log_probs] + except (KeyError, TypeError, ValueError) as e: + raise ValueError("Every logprob entry must include numeric `logprob`.") from e + + @staticmethod + def _coerce_float_list(value: Any, field_name: str) -> List[float]: + if not isinstance(value, list): + 
raise ValueError(f"Expected {field_name} to be a list.") + try: + return [float(logprob) for logprob in value] + except (TypeError, ValueError) as e: + raise ValueError(f"Expected {field_name} to contain only numeric logprobs.") from e + + def _attach_native_token_information( + self, + chat_completion_dict: Dict[str, Any], + choice_dict: Dict[str, Any], + ) -> bool: + response_nvext = chat_completion_dict.get("nvext") + if not isinstance(response_nvext, dict): + return False + engine_data = response_nvext.get("engine_data") + if not isinstance(engine_data, dict): + return False + + prompt_token_ids = self._coerce_token_id_list( + engine_data.get("prompt_token_ids"), + "nvext.engine_data.prompt_token_ids", + ) + generation_token_ids = self._coerce_token_id_list( + engine_data.get("completion_token_ids"), + "nvext.engine_data.completion_token_ids", + ) + generation_log_probs = self._coerce_float_list( + engine_data.get("completion_logprobs"), + "nvext.engine_data.completion_logprobs", + ) + if len(generation_token_ids) != len(generation_log_probs): + raise ValueError( + "Received mismatched completion token IDs " + f"({len(generation_token_ids)}) and logprobs ({len(generation_log_probs)})." + ) + + choice_dict["message"].update( + dict( + prompt_token_ids=prompt_token_ids, + generation_token_ids=generation_token_ids, + generation_log_probs=generation_log_probs, + ) + ) + return True + async def responses( self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body() ) -> NeMoGymResponse: @@ -433,6 +517,25 @@ def _preprocess_chat_completion_create_params(self, request: Request, body_dict: # No user message found — create one with just the audio blocks. body_dict.setdefault("messages", []).append({"role": "user", "content": list(audio_blocks)}) + # Auto-derive `required_prefix_token_ids` from the latest assistant + # message that has per-message token IDs attached. 
Both Dynamo's + # Rust preprocessor and NeMo-RL's custom vLLM serving mixin honor + # this field to splice verbatim model-emitted tokens into the + # template-tokenized prefix, preserving byte-level token continuity + # across multi-turn replays. The vLLM mixin auto-derives from + # per-message `prompt_token_ids` itself (see + # `nemo_rl/models/generation/vllm/vllm_worker_async.py` + # `NeMoRLOpenAIChatRequestMixin.model_post_init`); Dynamo does not, + # so we set it server-agnostically here. When the vLLM mixin sees + # the field already populated, its auto-derive short-circuits. + if "required_prefix_token_ids" not in body_dict: + for message in reversed(body_dict.get("messages", [])): + if "prompt_token_ids" in message: + body_dict["required_prefix_token_ids"] = list(message["prompt_token_ids"]) + list( + message["generation_token_ids"] + ) + break + return body_dict async def chat_completions( @@ -498,50 +601,37 @@ async def chat_completions( ) if self.config.return_token_id_information and "prompt_token_ids" not in choice_dict["message"]: - log_probs = choice_dict["logprobs"]["content"] - generation_log_probs = [log_prob["logprob"] for log_prob in log_probs] - - """ - START TODO remove this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids - """ - # Looks like `"token_id:151667"` - generation_token_ids = [log_prob["token"].removeprefix("token_id:") for log_prob in log_probs] - - # The tokenize endpoint doesn't accept any sampling parameters - # The only relevant params are model, messages, and tools. - # - # IMPORTANT: pass through chat-template knobs (e.g. enable_thinking) - # when tokenizing, otherwise `prompt_token_ids` (and therefore logged - # `prompt_str`) can be built with different chat template settings than - # the actual generation request. 
- tokenize_body_dict = dict() - for key in ("model", "messages", "tools", "chat_template_kwargs"): - if key in body_dict: - tokenize_body_dict[key] = body_dict[key] - - # The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the .. - tokenize_response = await client.create_tokenize(**tokenize_body_dict) - """ - END - """ - - message_dict = choice_dict["message"] - message_dict.update( - dict( - # TODO add this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids - # prompt_token_ids=chat_completion_dict["prompt_token_ids"], - prompt_token_ids=tokenize_response["tokens"], - # generation_token_ids=choice_dict["token_ids"], - generation_token_ids=generation_token_ids, - generation_log_probs=generation_log_probs, - ) + has_native_tokens = self._attach_native_token_information( + chat_completion_dict, + choice_dict, ) + if not has_native_tokens: + log_probs = self._extract_logprob_content(choice_dict) + generation_log_probs = self._generation_log_probs_from_logprobs(log_probs) + message_dict = choice_dict["message"] + generation_token_ids = self._generation_token_ids_from_logprobs(log_probs) + + # The tokenize endpoint doesn't accept any sampling parameters. + # The only relevant params are model, messages, tools, chat-template + # knobs, and the prefix splice metadata used by the vLLM server. + tokenize_body_dict = dict() + for key in ("model", "messages", "tools", "chat_template_kwargs", "required_prefix_token_ids"): + if key in body_dict: + tokenize_body_dict[key] = body_dict[key] + + # The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the .. 
from nemo_gym.server_utils import SESSION_ID_KEY, ServerClient


def _make_token_information_model() -> VLLMModel:
    """Build a ``VLLMModel`` with ``return_token_id_information`` enabled and a mocked server client."""
    model_config = VLLMModelConfig(
        host="0.0.0.0",
        port=8080,
        entrypoint="",
        name="vllm_model",
        base_url="http://localhost:9999/v1",
        api_key="dummy_key",  # pragma: allowlist secret
        model="dummy-model",
        return_token_id_information=True,
        uses_reasoning_parser=False,
        uses_interleaved_reasoning=False,
    )
    mocked_server_client = MagicMock(spec=ServerClient)
    return VLLMModel(config=model_config, server_client=mocked_server_client)


def _make_chat_request() -> MagicMock:
    """Return a mock request whose ``session`` carries a fixed session ID."""
    chat_request = MagicMock()
    chat_request.session = {SESSION_ID_KEY: "session-1"}
    return chat_request


class TestTokenIDInformation:
    """Covers how ``chat_completions`` fills token IDs and logprobs on the returned message.

    Three tests supply ``nvext.engine_data`` and require that tokenize is
    never awaited; one omits ``nvext`` and expects the tokenize fallback.
    """

    @staticmethod
    def _request_body():
        # Single-turn user prompt shared by every test.
        return NeMoGymChatCompletionCreateParamsNonStreaming(messages=[{"role": "user", "content": "hello"}])

    @staticmethod
    def _logprob_entries(*pairs):
        # (token_id, logprob) pairs -> the `token_id:<n>` logprob entries used in the payloads.
        return [{"token": f"token_id:{token_id}", "logprob": logprob} for token_id, logprob in pairs]

    @staticmethod
    def _chat_completion(*, finish_reason, content, logprob_entries, nvext=None):
        # Key order mirrors the hand-written payloads: `nvext` is appended last, only when given.
        payload = {
            "id": "chtcmpl-123",
            "object": "chat.completion",
            "created": FIXED_TIME,
            "model": "dummy-model",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": finish_reason,
                    "message": {"role": "assistant", "content": content},
                    "logprobs": {"content": logprob_entries},
                }
            ],
        }
        if nvext is not None:
            payload["nvext"] = nvext
        return payload

    @staticmethod
    def _complete_nvext():
        # Fresh dict per call so tests never share mutable state.
        return {
            "engine_data": {
                "prompt_token_ids": [1, 2, 3],
                "completion_token_ids": [11, 12],
                "completion_logprobs": [-0.1, -0.2],
                "finished": True,
            },
        }

    @staticmethod
    def _install_client(model, chat_completion, create_tokenize):
        # Replace the model's clients with a single mock serving `chat_completion`.
        client = MagicMock(spec=NeMoGymAsyncOpenAI)
        client.create_chat_completion = AsyncMock(return_value=chat_completion)
        client.create_tokenize = create_tokenize
        model._clients = [client]
        return client

    @staticmethod
    def _forbidden_tokenize():
        # Awaiting this mock raises, so any tokenize call fails the test.
        return AsyncMock(side_effect=AssertionError("must not tokenize"))

    async def test_uses_native_engine_data_without_tokenize(self) -> None:
        model = _make_token_information_model()
        client = self._install_client(
            model,
            self._chat_completion(
                finish_reason="stop",
                content="hi",
                logprob_entries=self._logprob_entries((11, -9.1), (12, -9.2)),
                nvext=self._complete_nvext(),
            ),
            self._forbidden_tokenize(),
        )

        response = await model.chat_completions(_make_chat_request(), self._request_body())

        message = response.choices[0].message
        assert message.prompt_token_ids == [1, 2, 3]
        assert message.generation_token_ids == [11, 12]
        assert message.generation_log_probs == [-0.1, -0.2]
        client.create_tokenize.assert_not_awaited()

    async def test_native_engine_data_ignores_postprocessed_choice_logprobs(self) -> None:
        model = _make_token_information_model()
        client = self._install_client(
            model,
            self._chat_completion(
                finish_reason="tool_calls",
                content=None,
                # Only one logprob entry, while engine_data reports two tokens.
                logprob_entries=self._logprob_entries((12, -9.2)),
                nvext=self._complete_nvext(),
            ),
            self._forbidden_tokenize(),
        )

        response = await model.chat_completions(_make_chat_request(), self._request_body())

        message = response.choices[0].message
        assert message.prompt_token_ids == [1, 2, 3]
        assert message.generation_token_ids == [11, 12]
        assert message.generation_log_probs == [-0.1, -0.2]
        client.create_tokenize.assert_not_awaited()

    async def test_missing_engine_data_keeps_vllm_tokenize_fallback(self) -> None:
        model = _make_token_information_model()
        client = self._install_client(
            model,
            self._chat_completion(
                finish_reason="stop",
                content="hi",
                logprob_entries=self._logprob_entries((21, -0.3), (22, -0.4)),
            ),
            AsyncMock(return_value={"tokens": [5, 6, 7]}),
        )

        response = await model.chat_completions(_make_chat_request(), self._request_body())

        message = response.choices[0].message
        assert message.prompt_token_ids == [5, 6, 7]
        assert message.generation_token_ids == [21, 22]
        assert message.generation_log_probs == [-0.3, -0.4]
        client.create_tokenize.assert_awaited_once()

    async def test_malformed_engine_data_raises_without_tokenize(self) -> None:
        model = _make_token_information_model()
        client = self._install_client(
            model,
            self._chat_completion(
                finish_reason="stop",
                content="hi",
                logprob_entries=self._logprob_entries((11, -0.1)),
                # engine_data is present but lacks prompt_token_ids.
                nvext={"engine_data": {"completion_token_ids": [11]}},
            ),
            self._forbidden_tokenize(),
        )

        raised = None
        try:
            await model.chat_completions(_make_chat_request(), self._request_body())
        except ValueError as e:
            raised = e

        if raised is None:
            raise AssertionError("expected malformed native token metadata to raise")
        assert "prompt_token_ids" in str(raised)

        client.create_tokenize.assert_not_awaited()