diff --git a/nemo_gym/openai_utils.py b/nemo_gym/openai_utils.py index d5c7ebc4a..5408569ad 100644 --- a/nemo_gym/openai_utils.py +++ b/nemo_gym/openai_utils.py @@ -560,3 +560,20 @@ async def create_tokenize(self, **kwargs): await self._raise_for_status(response, request_kwargs) return await get_response_json(response) + + async def create_generate(self, **kwargs): + # SGLang's native (non-OpenAI) generation endpoint. Used by the SGLang + # engine path because, on the pinned SGLang v0.5.10, /v1/chat/completions + # does not expose the exact sampled integer token ids (its logprobs.token + # is a decoded string and there is no return_tokens_as_token_ids), whereas + # /generate with return_logprob=True returns meta_info.output_token_logprobs + # whose tuples are (logprob, token_id, ...). Lives at the server root, not /v1. + base_url = self.base_url.removesuffix("/v1") + request_kwargs = dict( + url=f"{base_url}/generate", + json=kwargs, + ) + response = await self._request(method="POST", **request_kwargs) + + await self._raise_for_status(response, request_kwargs) + return await get_response_json(response) diff --git a/responses_api_models/vllm_model/app.py b/responses_api_models/vllm_model/app.py index a6709218f..b764cd99d 100644 --- a/responses_api_models/vllm_model/app.py +++ b/responses_api_models/vllm_model/app.py @@ -15,9 +15,10 @@ import base64 import json import os +import re from copy import deepcopy from time import time -from typing import Any, ClassVar, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union from uuid import uuid4 from aiohttp.client_exceptions import ClientResponseError @@ -78,6 +79,26 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig): # small without depending on vLLM's ``--allowed-local-media-path``. audio_root: Optional[str] = None + # Generation engine. "vllm" (default) keeps the original OpenAI /v1/chat/completions + # marshaling path unchanged. 
"sglang" switches to SGLang's native /generate endpoint + # (see VLLMModel._sglang_chat_completion): on the pinned SGLang v0.5.10 the chat endpoint + # cannot return the exact sampled integer token ids (decoded-string logprobs, no + # return_tokens_as_token_ids) and /tokenize only accepts a raw prompt string, so the + # proxy tokenizes locally and reads token ids from /generate's meta_info.output_token_logprobs. + engine: Literal["vllm", "sglang"] = "vllm" + + # Path to the Jinja chat template the SGLang server was launched with (--chat-template). + # Used (engine == "sglang") so the proxy renders prompts with the same template the + # model was served/trained with. If None, the model tokenizer's built-in template is used. + sglang_chat_template_path: Optional[str] = None + + # Max sequence length of the SGLang server (engine == "sglang"). When a request does not + # specify a positive max_tokens (e.g. the SWE agent sends max_output_tokens=0 = "unlimited"), + # the proxy fills the remaining context as max_new_tokens. Without this, SGLang /generate + # falls back to its default max_new_tokens=128, truncating reasoning before </think> and + # breaking multi-turn contiguity. + sglang_max_total_sequence_length: Optional[int] = None + def model_post_init(self, context): if isinstance(self.base_url, str): self.base_url = [self.base_url] @@ -115,6 +136,18 @@ def _post_init(self) -> None: self._converter = self.get_converter() + # Lazily-initialised state for the SGLang engine path (see _sglang_chat_completion). + self._sglang_tokenizer: Any = None + self._sglang_chat_template: Optional[str] = None + # Contiguity fix: per-session running token sequence. Each multi-turn rollout's prompt + # is built by splicing the prior assistant turn's EXACT sampled generation_token_ids + # (never re-tokenizing them), so nemo_gym.py's `seen == prompt[:len(seen)]` holds by + # construction. 
Re-tokenizing prior turns broke this two ways: proxy parse drift + # (dropped multi-line tool calls, mangled <think>) and BPE retokenization (identical + # text, different token split). Keyed by SESSION_ID_KEY; cache-miss -> full tokenize. + self._sglang_session_seq: Dict[str, Dict[str, Any]] = dict() + self._sglang_eos_nl_ids: Optional[List[int]] = None + async def responses( self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body() ) -> NeMoGymResponse: @@ -436,6 +469,14 @@ async def chat_completions( self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body() ) -> NeMoGymChatCompletion: body_dict = body.model_dump(exclude_unset=True) + + # SGLang engine path: handled entirely by _sglang_chat_completion, which renders the + # prompt locally (keeping <think> embedded in assistant content, as the SWE chat + # template expects) and generates via /generate. Dispatched BEFORE the vLLM-specific + # _preprocess (which would split reasoning out of assistant content for the chat API). + if self.config.engine == "sglang": + return await self._sglang_chat_completion(request, body_dict) + body_dict = self._preprocess_chat_completion_create_params(request, body_dict) client = self._resolve_client(request) @@ -584,6 +625,332 @@ def _resolve_client(self, request: Request) -> NeMoGymAsyncOpenAI: return client + # ======================================================= + # SGLang engine path (see VLLMModelConfig.engine == "sglang") + # ======================================================= + + # Hermes tool-call format emitted by SGLang's --tool-call-parser hermes and the SWE + # chat template: one or more <tool_call>\n{"name": ..., "arguments": ...}\n</tool_call> + # blocks. The capture is the JSON object, anchored by the closing tag (so nested braces + # in arguments are handled without brace-balancing). 
+ _SGLANG_TOOL_CALL_PATTERN: ClassVar = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL) + _SGLANG_ARGS_PATTERN: ClassVar = re.compile(r"\"arguments\"\s*:\s*(.*)\}\s*$", re.DOTALL) + # Turn-terminating special tokens the chat template re-emits after an assistant message. + # We strip them from the decoded generation so they are not doubled on history re-render. + _SGLANG_EOS_MARKERS: ClassVar = ("<|im_end|>", "<|endoftext|>") + + def _get_sglang_tokenizer(self) -> Any: + if self._sglang_tokenizer is None: + from transformers import AutoTokenizer + + self._sglang_tokenizer = AutoTokenizer.from_pretrained(self.config.model) + return self._sglang_tokenizer + + def _get_sglang_chat_template(self) -> Optional[str]: + if self._sglang_chat_template is None and self.config.sglang_chat_template_path: + with open(self.config.sglang_chat_template_path) as f: + self._sglang_chat_template = f.read() + return self._sglang_chat_template + + def _full_sglang_tokenize( + self, messages: List[Any], tools: Any, chat_template_kwargs: Dict[str, Any] + ) -> List[int]: + """Tokenize the full chat prompt via the chat template (the original, non-spliced path).""" + tokenizer = self._get_sglang_tokenizer() + encoded = tokenizer.apply_chat_template( + messages, + tools=tools, + chat_template=self._get_sglang_chat_template(), + add_generation_prompt=True, + tokenize=True, + **chat_template_kwargs, + ) + # transformers v5's apply_chat_template(tokenize=True) returns a BatchEncoding + # (dict-like), not a flat list; normalize to a JSON-serializable List[int]. 
+ if isinstance(encoded, dict) or hasattr(encoded, "input_ids"): + encoded = encoded["input_ids"] + if hasattr(encoded, "tolist"): + encoded = encoded.tolist() + if encoded and isinstance(encoded[0], (list, tuple)): + encoded = encoded[0] + return [int(t) for t in encoded] + + def _sglang_eos_nl(self) -> List[int]: + if self._sglang_eos_nl_ids is None: + enc = self._get_sglang_tokenizer()("<|im_end|>\n", add_special_tokens=False) + self._sglang_eos_nl_ids = [int(t) for t in enc["input_ids"]] + return self._sglang_eos_nl_ids + + def _sglang_followup_fragment_ids( + self, new_msgs: List[Any], chat_template_kwargs: Dict[str, Any] + ) -> Optional[List[int]]: + """Token ids for the new (non-assistant) messages + the next generation-prompt header, + rendered as a standalone fragment that follows a prior assistant turn. Derived by + differencing two template renders against an anchor assistant turn, then tokenizing the + suffix. Safe to splice onto the running sequence because the splice boundary is the + assistant turn's ``<|im_end|>\\n`` (a special token), across which byte-level BPE does + not merge (validated). 
Returns None if the template is not splice-friendly -> caller + falls back to a full re-tokenize.""" + tokenizer = self._get_sglang_tokenizer() + ct = self._get_sglang_chat_template() + anchor = [{"role": "assistant", "content": "X"}] + try: + full = tokenizer.apply_chat_template( + anchor + list(new_msgs), tools=None, chat_template=ct, + add_generation_prompt=True, tokenize=False, **chat_template_kwargs, + ) + base = tokenizer.apply_chat_template( + anchor, tools=None, chat_template=ct, + add_generation_prompt=False, tokenize=False, **chat_template_kwargs, + ) + except Exception: + return None + if not isinstance(full, str) or not isinstance(base, str) or not full.startswith(base): + return None + enc = tokenizer(full[len(base):], add_special_tokens=False) + return [int(t) for t in enc["input_ids"]] + + @staticmethod + def _sglang_msg_sig(m: Dict[str, Any]) -> Tuple[Any, str, str]: + return ( + m.get("role"), + json.dumps(m.get("content"), sort_keys=True, default=str), + json.dumps(m.get("tool_calls"), sort_keys=True, default=str), + ) + + @classmethod + def _sglang_messages_match(cls, a: List[Any], b: List[Any]) -> bool: + return len(a) == len(b) and all( + cls._sglang_msg_sig(x) == cls._sglang_msg_sig(y) for x, y in zip(a, b) + ) + + def _build_sglang_prompt_ids( + self, request: Request, messages: List[Any], tools: Any, + chat_template_kwargs: Dict[str, Any], + ) -> Tuple[List[int], Optional[str]]: + """Return (prompt_token_ids, session_id). 
Splices the prior assistant turn's exact + generation tokens when this is a continuation of a cached session; else full tokenize.""" + try: + sid = request.session.get(SESSION_ID_KEY) + except Exception: + sid = None + if sid is not None: + state = self._sglang_session_seq.get(sid) + if state is not None: + prev = state["messages"] + n = len(prev) + if ( + len(messages) > n + and messages[n].get("role") == "assistant" + and all(m.get("role") != "assistant" for m in messages[n + 1:]) + and self._sglang_messages_match(messages[:n], prev) + ): + frag = self._sglang_followup_fragment_ids(messages[n + 1:], chat_template_kwargs) + if frag is not None: + return state["seq"] + frag, sid + return self._full_sglang_tokenize(messages, tools, chat_template_kwargs), sid + + def _update_sglang_session_seq( + self, sid: Optional[str], messages: List[Any], + prompt_token_ids: List[int], generation_token_ids: List[int], + ) -> None: + """Cache the running sequence through this assistant turn (prompt + gen + ``<|im_end|>\\n``) + for the next turn's splice.""" + if sid is None: + return + eos_nl = self._sglang_eos_nl() # e.g. [151645, 198] + seq = list(prompt_token_ids) + list(generation_token_ids) + if not seq or seq[-1] != eos_nl[0]: + seq = seq + eos_nl + else: + seq = seq + eos_nl[1:] # gen already ended with <|im_end|>; just add the trailing \n + # Bound memory: refresh this sid's insertion order, evict oldest beyond the cap. + # Evicted sessions simply fall back to a full tokenize on their next turn (safe). + self._sglang_session_seq.pop(sid, None) + while len(self._sglang_session_seq) >= 8192: + self._sglang_session_seq.pop(next(iter(self._sglang_session_seq)), None) + self._sglang_session_seq[sid] = {"messages": list(messages), "seq": seq} + + def _parse_sglang_generation(self, text: str) -> Tuple[Optional[str], str, List[Dict[str, Any]]]: + """Reconstruct (reasoning_content, content, tool_calls) from SGLang /generate raw text. 
+ + The qwen3-thinking generation prompt ends with ``<think>\\n``, so the generated text + begins INSIDE the reasoning block (no opening ``<think>``). Everything up to the first + ``</think>`` is therefore reasoning; hermes tool calls are parsed out of the remainder. + This mirrors what the SGLang server's reasoning_parser=qwen3-thinking + + tool_call_parser=hermes would have produced on /v1/chat/completions, so downstream + Responses marshaling is identical to the vLLM path. + """ + reasoning_content: Optional[str] = None + if self.config.uses_reasoning_parser and "</think>" in text: + reasoning_content, _, remainder = text.partition("</think>") + else: + remainder = text + + tool_calls: List[Dict[str, Any]] = [] + for match in self._SGLANG_TOOL_CALL_PATTERN.finditer(remainder): + block = match.group(1) + try: + parsed = json.loads(block) + except json.JSONDecodeError: + continue + # Preserve the model's EXACT arguments serialization (function.arguments is a JSON + # string in the OpenAI schema). Keeping the raw substring -- rather than + # re-serializing the parsed dict -- means the chat template re-renders the assistant + # turn byte-identically, which the nemo_gym.py contiguity assert depends on. + args_match = self._SGLANG_ARGS_PATTERN.search(block) + if args_match is not None: + arguments = args_match.group(1).strip() + else: + arguments = json.dumps(parsed.get("arguments", {})) + tool_calls.append( + dict( + id=f"call_{uuid4().hex}", + type="function", + function=dict(name=parsed.get("name"), arguments=arguments), + ) + ) + + content = self._SGLANG_TOOL_CALL_PATTERN.sub("", remainder).strip() + return reasoning_content, content, tool_calls + + async def _sglang_chat_completion( + self, request: Request, body_dict: Dict[str, Any] + ) -> NeMoGymChatCompletion: + """SGLang v0.5.10 generation path (see VLLMModelConfig.engine). 
+ + Tokenizes the chat-templated prompt locally and generates via SGLang's native + /generate (return_logprob=True) -- the only v0.5.10 source of the exact sampled + integer token ids AND their logprobs (needed for token-level RL). The decoded text + is re-parsed into reasoning + hermes tool_calls so the returned object is shaped + exactly like the vLLM /v1/chat/completions response, keeping every downstream + Responses-API conversion identical. + """ + client = self._resolve_client(request) + + messages = body_dict["messages"] + if self.config.replace_developer_role_with_system: + for message_dict in messages: + if message_dict.get("role") == "developer": + message_dict["role"] = "system" + tools = body_dict.get("tools") + + # Merge config chat_template_kwargs with per-request metadata overrides (mirrors + # _preprocess_chat_completion_create_params so reasoning toggles behave identically). + chat_template_kwargs: Dict[str, Any] = {} + if self.config.chat_template_kwargs: + chat_template_kwargs = deepcopy(self.config.chat_template_kwargs) + metadata = body_dict.get("metadata", dict()) + chat_template_kwargs.update(json.loads(metadata.get("chat_template_kwargs", "{}"))) + + tokenizer = self._get_sglang_tokenizer() # used below to decode generation_token_ids + # Build prompt token ids with contiguity-preserving splicing across turns (falls back + # to a full chat-template tokenize on the first turn / cache miss / history condensation). + prompt_token_ids, _splice_sid = self._build_sglang_prompt_ids( + request, messages, tools, chat_template_kwargs + ) + + # Map the OpenAI sampling knobs onto SGLang /generate sampling_params. + # spaces_between_special_tokens=False mirrors the NeMo-RL SGLang backend and keeps + # special tokens (<think>, </think>) tight in the decoded text we parse below. + sampling_params: Dict[str, Any] = {"spaces_between_special_tokens": False} + # max_tokens of None OR 0 means "unlimited" (the SWE agent sends max_output_tokens=0). 
+ # SGLang /generate would otherwise fall back to max_new_tokens=128 and truncate reasoning + # before </think>; fill the remaining context instead (matches vLLM + the recipe). + max_new_tokens = body_dict.get("max_tokens") or None + if max_new_tokens is None and self.config.sglang_max_total_sequence_length: + # Fill the remaining context, leaving a small margin: SGLang /generate rejects a + # request whose input + max_new_tokens >= context_length (it requires strictly less, + # unlike vLLM which allows ==). Reserve a few tokens to stay safely under the limit. + max_new_tokens = self.config.sglang_max_total_sequence_length - len(prompt_token_ids) - 8 + max_new_tokens = max(1, max_new_tokens) + if max_new_tokens: + sampling_params["max_new_tokens"] = max_new_tokens + for key in ("temperature", "top_p", "top_k", "stop"): + if body_dict.get(key) is not None: + sampling_params[key] = body_dict[key] + + gen = await client.create_generate( + input_ids=prompt_token_ids, + sampling_params=sampling_params, + return_logprob=True, + ) + + meta_info = gen.get("meta_info") or {} + # Each tuple is (logprob, token_id, ...). Sourcing both ids and logprobs from the SAME + # list guarantees they are 1:1 aligned in count and order -- mirrors + # nemo_rl/models/generation/sglang/sglang_generation.py:generate_one_sample. + output_token_logprobs = meta_info.get("output_token_logprobs") or [] + generation_token_ids = [item[1] for item in output_token_logprobs] + generation_log_probs = [item[0] for item in output_token_logprobs] + + # Contiguity fix: cache the running token sequence through this assistant turn so the + # next turn splices these EXACT generation_token_ids instead of re-tokenizing them. + self._update_sglang_session_seq( + _splice_sid, messages, prompt_token_ids, generation_token_ids + ) + + # Decode the EXACT sampled ids with skip_special_tokens=False so reasoning markers + # (</think> is a SPECIAL token, id 151668) survive. 
SGLang's gen["text"] decodes with + # skip_special_tokens=True by default, which STRIPS </think> -> the reasoning never gets + # wrapped -> the wrapper is dropped on history re-render -> the nemo_gym.py + # contiguity assert fires on every multi-turn rollout. spaces_between_special_tokens=False + # keeps the markers tight. Then strip the trailing EOS the chat template re-adds. + generated_text = tokenizer.decode( + generation_token_ids, skip_special_tokens=False, spaces_between_special_tokens=False + ) + _stripped = True + while _stripped: + _stripped = False + generated_text = generated_text.rstrip("\n") + for _eos in self._SGLANG_EOS_MARKERS: + if generated_text.endswith(_eos): + generated_text = generated_text[: -len(_eos)] + _stripped = True + reasoning_content, content, tool_calls = self._parse_sglang_generation(generated_text) + + if (meta_info.get("finish_reason") or {}).get("type") == "length": + finish_reason = "length" + elif tool_calls: + finish_reason = "tool_calls" + else: + finish_reason = "stop" + + # Re-embed reasoning into <think> tags and prepend to content, identical to the vLLM + # reasoning-parser branch, so postprocess_assistant_message_dict re-extracts it. 
+ if self.config.uses_reasoning_parser and reasoning_content: + content = self._converter._wrap_reasoning_in_think_tags([reasoning_content]) + (content or "") + + message_dict: Dict[str, Any] = dict( + role="assistant", + content=content or None, + tool_calls=tool_calls or None, + ) + if self.config.return_token_id_information: + message_dict.update( + dict( + prompt_token_ids=prompt_token_ids, + generation_token_ids=generation_token_ids, + generation_log_probs=generation_log_probs, + ) + ) + + chat_completion_dict = dict( + id=f"chtcmpl-{uuid4().hex}", + object="chat.completion", + created=int(time()), + model=self.config.model, + choices=[ + dict(index=0, finish_reason=finish_reason, message=message_dict, logprobs=None) + ], + usage=dict( + prompt_tokens=len(prompt_token_ids), + completion_tokens=len(generation_token_ids), + total_tokens=len(prompt_token_ids) + len(generation_token_ids), + ), + ) + return NeMoGymChatCompletion.model_validate(chat_completion_dict) + if __name__ == "__main__": VLLMModel.run_webserver() diff --git a/responses_api_models/vllm_model/pyproject.toml b/responses_api_models/vllm_model/pyproject.toml index 8e52b8b7f..35cfe2c76 100644 --- a/responses_api_models/vllm_model/pyproject.toml +++ b/responses_api_models/vllm_model/pyproject.toml @@ -19,6 +19,10 @@ version = "0.2.0rc0" requires-python = ">=3.12" dependencies = [ "nemo-gym[dev]", + # Used by the SGLang engine path (VLLMModelConfig.engine == "sglang") to render the + # chat template + tokenize prompts locally. Pinned to match the version the SWE recipe + # was validated with; relax/adjust as upstream dependency policy requires. + "transformers==5.3.0", ] [build-system]