Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 130 additions & 40 deletions responses_api_models/vllm_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,90 @@ def _post_init(self) -> None:

self._converter = self.get_converter()

@staticmethod
def _coerce_token_id_list(value: Any, field_name: str) -> List[int]:
    """Validate ``value`` as a list of token IDs and return it as a list of ints.

    Args:
        value: Candidate payload; must be a ``list``.
        field_name: Name of the field being validated, interpolated into the
            error messages.

    Returns:
        A new list with every element converted through ``int()``.

    Raises:
        ValueError: If ``value`` is not a list, or if any element is a bool,
            a non-integral or non-finite float, or otherwise not convertible
            to an integer.
    """
    if not isinstance(value, list):
        raise ValueError(f"Expected {field_name} to be a list.")
    token_ids: List[int] = []
    try:
        for raw_id in value:
            # bool is an int subclass, so int(True) would quietly become token 1.
            if isinstance(raw_id, bool):
                raise TypeError("bool is not a valid token ID")
            token_id = int(raw_id)
            # int() truncates floats; 1.9 must not be accepted as token 1.
            if isinstance(raw_id, float) and token_id != raw_id:
                raise ValueError(f"non-integral token ID: {raw_id!r}")
            token_ids.append(token_id)
    except (TypeError, ValueError, OverflowError) as e:
        # OverflowError covers int(float("inf")), which is not a ValueError.
        raise ValueError(f"Expected {field_name} to contain only integer token IDs.") from e
    return token_ids

@staticmethod
def _extract_logprob_content(choice_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the ``logprobs.content`` list of a chat-completion choice.

    Args:
        choice_dict: A single choice taken from a chat-completion response.

    Returns:
        The list stored under ``choice_dict["logprobs"]["content"]`` (the
        same object, not a copy).

    Raises:
        ValueError: If ``logprobs`` is not a dict or its ``content`` is not
            a list.
    """
    logprobs_payload = choice_dict.get("logprobs")
    if isinstance(logprobs_payload, dict):
        entries = logprobs_payload.get("content")
        if isinstance(entries, list):
            return entries
        raise ValueError("Token ID information requires choice.logprobs.content to be a list.")
    raise ValueError("Token ID information requires choice.logprobs to be present.")

@staticmethod
def _generation_token_ids_from_logprobs(log_probs: List[Dict[str, Any]]) -> List[int]:
    """Parse generated token IDs out of logprob entries.

    Each entry is expected to carry a ``token`` string of the form
    ``token_id:<id>``; the numeric suffix is returned as an int.

    Args:
        log_probs: Logprob entries, one per generated token.

    Returns:
        The token IDs, in the same order as ``log_probs``.

    Raises:
        ValueError: If an entry is not a dict, its ``token`` is missing or
            not a ``token_id:``-prefixed string, or the suffix is not an
            integer.
    """
    prefix = "token_id:"
    error_message = "Expected logprob token strings in the form `token_id:<id>`."
    generation_token_ids: List[int] = []
    for log_prob in log_probs:
        # Non-dict entries are reported as ValueError like every other malformed
        # entry, instead of leaking an AttributeError from `.get`.
        token = log_prob.get("token") if isinstance(log_prob, dict) else None
        if not isinstance(token, str) or not token.startswith(prefix):
            raise ValueError(error_message)
        try:
            generation_token_ids.append(int(token.removeprefix(prefix)))
        except ValueError as e:
            # Replace int()'s generic "invalid literal" text with the format hint.
            raise ValueError(error_message) from e
    return generation_token_ids

@staticmethod
def _generation_log_probs_from_logprobs(log_probs: List[Dict[str, Any]]) -> List[float]:
    """Collect the ``logprob`` value of every logprob entry as a float.

    Args:
        log_probs: Logprob entries, each expected to hold a numeric
            ``logprob`` key.

    Returns:
        The log probabilities, in the same order as ``log_probs``.

    Raises:
        ValueError: If any entry lacks ``logprob`` or its value cannot be
            converted with ``float()``.
    """
    try:
        return [float(log_prob["logprob"]) for log_prob in log_probs]
    except (KeyError, TypeError, ValueError, OverflowError) as e:
        # OverflowError: float() of an int too large for a double is not a ValueError.
        raise ValueError("Every logprob entry must include numeric `logprob`.") from e

@staticmethod
def _coerce_float_list(value: Any, field_name: str) -> List[float]:
    """Validate ``value`` as a list of logprobs and return it as a list of floats.

    Args:
        value: Candidate payload; must be a ``list``.
        field_name: Name of the field being validated, interpolated into the
            error messages.

    Returns:
        A new list with every element converted through ``float()``.

    Raises:
        ValueError: If ``value`` is not a list or any element cannot be
            converted with ``float()``.
    """
    if not isinstance(value, list):
        raise ValueError(f"Expected {field_name} to be a list.")
    try:
        return [float(logprob) for logprob in value]
    except (TypeError, ValueError, OverflowError) as e:
        # OverflowError: float() of an int too large for a double is not a ValueError.
        raise ValueError(f"Expected {field_name} to contain only numeric logprobs.") from e

def _attach_native_token_information(
    self,
    chat_completion_dict: Dict[str, Any],
    choice_dict: Dict[str, Any],
) -> bool:
    """Copy token information from ``nvext.engine_data`` onto the choice message.

    Args:
        chat_completion_dict: The chat-completion response as a dict.
        choice_dict: The choice whose ``message`` is updated in place.

    Returns:
        ``False`` when the response carries no ``nvext.engine_data`` dict
        (nothing is modified); ``True`` once ``prompt_token_ids``,
        ``generation_token_ids`` and ``generation_log_probs`` have been
        written to ``choice_dict["message"]``.

    Raises:
        ValueError: Propagated from the coercion helpers when an
            ``engine_data`` field is malformed, or raised here when the
            number of completion token IDs and logprobs differ.
    """
    nvext = chat_completion_dict.get("nvext")
    engine_data = nvext.get("engine_data") if isinstance(nvext, dict) else None
    if not isinstance(engine_data, dict):
        return False

    # All three fields are validated before the message is touched, so a
    # malformed payload never leaves a half-updated message behind.
    prompt_ids = self._coerce_token_id_list(
        engine_data.get("prompt_token_ids"), "nvext.engine_data.prompt_token_ids"
    )
    completion_ids = self._coerce_token_id_list(
        engine_data.get("completion_token_ids"), "nvext.engine_data.completion_token_ids"
    )
    completion_log_probs = self._coerce_float_list(
        engine_data.get("completion_logprobs"), "nvext.engine_data.completion_logprobs"
    )

    id_count, log_prob_count = len(completion_ids), len(completion_log_probs)
    if id_count != log_prob_count:
        raise ValueError(f"Received mismatched completion token IDs ({id_count}) and logprobs ({log_prob_count}).")

    choice_dict["message"].update(
        {
            "prompt_token_ids": prompt_ids,
            "generation_token_ids": completion_ids,
            "generation_log_probs": completion_log_probs,
        }
    )
    return True

async def responses(
self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body()
) -> NeMoGymResponse:
Expand Down Expand Up @@ -433,6 +517,25 @@ def _preprocess_chat_completion_create_params(self, request: Request, body_dict:
# No user message found — create one with just the audio blocks.
body_dict.setdefault("messages", []).append({"role": "user", "content": list(audio_blocks)})

# Auto-derive `required_prefix_token_ids` from the latest assistant
# message that has per-message token IDs attached. Both Dynamo's
# Rust preprocessor and NeMo-RL's custom vLLM serving mixin honor
# this field to splice verbatim model-emitted tokens into the
# template-tokenized prefix, preserving byte-level token continuity
# across multi-turn replays. The vLLM mixin auto-derives from
# per-message `prompt_token_ids` itself (see
# `nemo_rl/models/generation/vllm/vllm_worker_async.py`
# `NeMoRLOpenAIChatRequestMixin.model_post_init`); Dynamo does not,
# so we set it server-agnostically here. When the vLLM mixin sees
# the field already populated, its auto-derive short-circuits.
if "required_prefix_token_ids" not in body_dict:
for message in reversed(body_dict.get("messages", [])):
if "prompt_token_ids" in message:
body_dict["required_prefix_token_ids"] = list(message["prompt_token_ids"]) + list(
message["generation_token_ids"]
)
break

return body_dict

async def chat_completions(
Expand Down Expand Up @@ -498,50 +601,37 @@ async def chat_completions(
)

if self.config.return_token_id_information and "prompt_token_ids" not in choice_dict["message"]:
log_probs = choice_dict["logprobs"]["content"]
generation_log_probs = [log_prob["logprob"] for log_prob in log_probs]

"""
START TODO remove this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
"""
# Looks like `"token_id:151667"`
generation_token_ids = [log_prob["token"].removeprefix("token_id:") for log_prob in log_probs]

# The tokenize endpoint doesn't accept any sampling parameters
# The only relevant params are model, messages, and tools.
#
# IMPORTANT: pass through chat-template knobs (e.g. enable_thinking)
# when tokenizing, otherwise `prompt_token_ids` (and therefore logged
# `prompt_str`) can be built with different chat template settings than
# the actual generation request.
tokenize_body_dict = dict()
for key in ("model", "messages", "tools", "chat_template_kwargs"):
if key in body_dict:
tokenize_body_dict[key] = body_dict[key]

# The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the ..
tokenize_response = await client.create_tokenize(**tokenize_body_dict)
"""
END
"""

message_dict = choice_dict["message"]
message_dict.update(
dict(
# TODO add this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
# prompt_token_ids=chat_completion_dict["prompt_token_ids"],
prompt_token_ids=tokenize_response["tokens"],
# generation_token_ids=choice_dict["token_ids"],
generation_token_ids=generation_token_ids,
generation_log_probs=generation_log_probs,
)
has_native_tokens = self._attach_native_token_information(
chat_completion_dict,
choice_dict,
)
if not has_native_tokens:
log_probs = self._extract_logprob_content(choice_dict)
generation_log_probs = self._generation_log_probs_from_logprobs(log_probs)
message_dict = choice_dict["message"]
generation_token_ids = self._generation_token_ids_from_logprobs(log_probs)

# The tokenize endpoint doesn't accept any sampling parameters.
# The only relevant params are model, messages, tools, chat-template
# knobs, and the prefix splice metadata used by the vLLM server.
tokenize_body_dict = dict()
for key in ("model", "messages", "tools", "chat_template_kwargs", "required_prefix_token_ids"):
if key in body_dict:
tokenize_body_dict[key] = body_dict[key]

# The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the ..
tokenize_response = await client.create_tokenize(**tokenize_body_dict)

message_dict.update(
dict(
prompt_token_ids=tokenize_response["tokens"],
generation_token_ids=generation_token_ids,
generation_log_probs=generation_log_probs,
)
)

# Clean the duplicated information
choice_dict.pop("logprobs")
# TODO add this when NeMo RL upgrades to vLLM 0.10.2 support for prompt token ids
# chat_completion_dict.pop("prompt_token_ids")
# choice_dict.pop("token_ids")

return NeMoGymChatCompletion.model_validate(chat_completion_dict)

Expand Down
178 changes: 177 additions & 1 deletion responses_api_models/vllm_model/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
NeMoGymResponseReasoningItem,
NeMoGymSummary,
)
from nemo_gym.server_utils import ServerClient
from nemo_gym.server_utils import SESSION_ID_KEY, ServerClient
from responses_api_models.vllm_model.app import (
VLLMConverter,
VLLMModel,
Expand Down Expand Up @@ -3311,6 +3311,182 @@ async def mock_create_chat_completion(**kwargs):
assert captured_kwargs["new_param"] == "value"


def _make_token_information_model() -> VLLMModel:
    """Build a ``VLLMModel`` with ``return_token_id_information`` enabled and a mocked ``ServerClient``."""
    settings = {
        "host": "0.0.0.0",
        "port": 8080,
        "entrypoint": "",
        "name": "vllm_model",
        "base_url": "http://localhost:9999/v1",
        "api_key": "dummy_key",  # pragma: allowlist secret
        "model": "dummy-model",
        "return_token_id_information": True,
        "uses_reasoning_parser": False,
        "uses_interleaved_reasoning": False,
    }
    model_config = VLLMModelConfig(**settings)
    return VLLMModel(config=model_config, server_client=MagicMock(spec=ServerClient))


def _make_chat_request() -> MagicMock:
    """Build a stand-in request whose ``session`` holds a fixed session ID."""
    session_state = {SESSION_ID_KEY: "session-1"}
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    return MagicMock(session=session_state)


class TestTokenIDInformation:
    """Checks which token IDs and logprobs ``chat_completions`` attaches to the reply message.

    The private helpers below build the payload and client scaffolding that
    every test shares; each test supplies only the parts that differ.
    """

    @staticmethod
    def _native_nvext() -> dict:
        """Return a fresh, well-formed ``nvext`` payload carrying engine token data."""
        return {
            "engine_data": {
                "prompt_token_ids": [1, 2, 3],
                "completion_token_ids": [11, 12],
                "completion_logprobs": [-0.1, -0.2],
                "finished": True,
            },
        }

    @staticmethod
    def _chat_completion(logprob_content, *, finish_reason="stop", content="hi", nvext=None) -> dict:
        """Build a single-choice chat-completion payload; ``nvext`` is omitted when ``None``."""
        payload = {
            "id": "chtcmpl-123",
            "object": "chat.completion",
            "created": FIXED_TIME,
            "model": "dummy-model",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": finish_reason,
                    "message": {"role": "assistant", "content": content},
                    "logprobs": {"content": logprob_content},
                }
            ],
        }
        if nvext is not None:
            payload["nvext"] = nvext
        return payload

    @staticmethod
    def _client_returning(chat_completion, tokenize) -> MagicMock:
        """Build a mocked client that returns ``chat_completion`` and uses ``tokenize`` for ``create_tokenize``."""
        client = MagicMock(spec=NeMoGymAsyncOpenAI)
        client.create_chat_completion = AsyncMock(return_value=chat_completion)
        client.create_tokenize = tokenize
        return client

    @staticmethod
    def _forbidden_tokenize() -> AsyncMock:
        """Return a tokenize mock that raises if it is ever awaited."""
        return AsyncMock(side_effect=AssertionError("must not tokenize"))

    @staticmethod
    async def _complete(client):
        """Run ``chat_completions`` on a fresh model wired to ``client`` with a one-message request."""
        model = _make_token_information_model()
        model._clients = [client]
        body = NeMoGymChatCompletionCreateParamsNonStreaming(messages=[{"role": "user", "content": "hello"}])
        return await model.chat_completions(_make_chat_request(), body)

    @staticmethod
    def _assert_token_information(response, prompt_token_ids, generation_token_ids, generation_log_probs) -> None:
        """Assert the three token-information fields on the first choice's message."""
        message = response.choices[0].message
        assert message.prompt_token_ids == prompt_token_ids
        assert message.generation_token_ids == generation_token_ids
        assert message.generation_log_probs == generation_log_probs

    async def test_uses_native_engine_data_without_tokenize(self) -> None:
        client = self._client_returning(
            self._chat_completion(
                [
                    {"token": "token_id:11", "logprob": -9.1},
                    {"token": "token_id:12", "logprob": -9.2},
                ],
                nvext=self._native_nvext(),
            ),
            self._forbidden_tokenize(),
        )

        response = await self._complete(client)

        # The values must come from engine_data, not from the choice logprobs.
        self._assert_token_information(response, [1, 2, 3], [11, 12], [-0.1, -0.2])
        client.create_tokenize.assert_not_awaited()

    async def test_native_engine_data_ignores_postprocessed_choice_logprobs(self) -> None:
        client = self._client_returning(
            self._chat_completion(
                [
                    {"token": "token_id:12", "logprob": -9.2},
                ],
                finish_reason="tool_calls",
                content=None,
                nvext=self._native_nvext(),
            ),
            self._forbidden_tokenize(),
        )

        response = await self._complete(client)

        # The choice logprobs hold fewer entries than engine_data; engine_data wins.
        self._assert_token_information(response, [1, 2, 3], [11, 12], [-0.1, -0.2])
        client.create_tokenize.assert_not_awaited()

    async def test_missing_engine_data_keeps_vllm_tokenize_fallback(self) -> None:
        client = self._client_returning(
            self._chat_completion(
                [
                    {"token": "token_id:21", "logprob": -0.3},
                    {"token": "token_id:22", "logprob": -0.4},
                ],
            ),
            AsyncMock(return_value={"tokens": [5, 6, 7]}),
        )

        response = await self._complete(client)

        self._assert_token_information(response, [5, 6, 7], [21, 22], [-0.3, -0.4])
        client.create_tokenize.assert_awaited_once()

    async def test_malformed_engine_data_raises_without_tokenize(self) -> None:
        client = self._client_returning(
            self._chat_completion(
                [{"token": "token_id:11", "logprob": -0.1}],
                nvext={"engine_data": {"completion_token_ids": [11]}},
            ),
            self._forbidden_tokenize(),
        )

        try:
            await self._complete(client)
        except ValueError as e:
            assert "prompt_token_ids" in str(e)
        else:
            raise AssertionError("expected malformed native token metadata to raise")

        client.create_tokenize.assert_not_awaited()


# ──────────────────────────────────────────────────────────────────────────────
# Audio sidechannel splice (metadata.audio_data → user-message content block)
#
Expand Down