From 6725d0397c7070bfca8e794ea9c046ccf92983dc Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 15 Apr 2026 14:25:57 -0700 Subject: [PATCH 1/4] Fixes for gym/RL integration. Signed-off-by: Peter Jin --- .../agenthub/codeact_agent/codeact_agent.py | 6 ++- openhands/agenthub/codex_agent/codex_agent.py | 6 ++- openhands/agenthub/nemo_gym_client.py | 40 ++++++++++++++----- .../agenthub/opencode_agent/opencode_agent.py | 6 ++- openhands/core/message.py | 9 +++++ openhands/memory/conversation_memory.py | 14 +++++++ 6 files changed, 67 insertions(+), 14 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 66c1da6cce41..31fde668142d 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -233,7 +233,11 @@ async def step(self, state: State) -> 'Action': ) } - response = await self.nemo_gym_client.model_call(messages, params['tools']) + response = await self.nemo_gym_client.model_call( + messages, + params['tools'], + request_kwargs={'extra_body': params['extra_body']}, + ) ng_openhands_should_log = os.environ.get("NG_OPENHANDS_SHOULD_LOG", "").lower() == "true" if ng_openhands_should_log: diff --git a/openhands/agenthub/codex_agent/codex_agent.py b/openhands/agenthub/codex_agent/codex_agent.py index b0c22046eb1f..8eec6aa0e12f 100644 --- a/openhands/agenthub/codex_agent/codex_agent.py +++ b/openhands/agenthub/codex_agent/codex_agent.py @@ -179,7 +179,11 @@ async def step(self, state: State) -> "Action": model_name=self.llm.config.model, agent_name=self.name ) } - response = await self.nemo_gym_client.model_call(messages, params["tools"]) + response = await self.nemo_gym_client.model_call( + messages, + params["tools"], + request_kwargs={"extra_body": params["extra_body"]}, + ) logger.debug(f"Response from LLM: {response}") actions = self.response_to_actions(response) logger.debug(f"Actions after response_to_actions: {actions}") diff --git a/openhands/agenthub/nemo_gym_client.py b/openhands/agenthub/nemo_gym_client.py index 01cd1c92e06e..bc4099b4a274 100644 --- a/openhands/agenthub/nemo_gym_client.py +++ b/openhands/agenthub/nemo_gym_client.py @@ -9,7 +9,7 @@ import os import tempfile import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from openhands.core.logger import openhands_logger as logger from nemo_gym.global_config import get_global_config_dict @@ -35,6 +35,15 @@ class NemoGymClient: response = await self.nemo_gym_client.model_call(messages, tools) """ + _PROVIDER_SPECIFIC_FIELD_KEYS = ( + "prompt_token_ids", + "generation_token_ids", + "generation_log_probs", + "prompt_moe_topk_indices", + "generation_moe_topk_indices", + "moe_metadata", + ) + def __init__(self, llm: "LLM") -> None: self.ng_server_client = ServerClient( head_server_config=ServerClient.load_head_server_config(), @@ -47,18 +56,20 @@ async def model_call( self, messages: list["Message"], tools: "list[ChatCompletionToolParam] | None" = None, + request_kwargs: dict[str, Any] | None = None, ) -> "ModelResponse": """Make a model call via the NeMo Gym server, with automatic metrics tracking. Args: messages: Conversation messages (OpenHands Message objects). tools: Optional list of tool definitions for function calling. + request_kwargs: Optional extra chat completion fields to forward. Returns: A validated ModelResponse from the server. """ start_time = time.time() - response = await self._post_completion(messages, tools) + response = await self._post_completion(messages, tools, request_kwargs=request_kwargs) self._update_model_call_time(start_time) return response @@ -70,6 +81,7 @@ async def _post_completion( self, messages: list["Message"], tools: "list[ChatCompletionToolParam] | None" = None, + request_kwargs: dict[str, Any] | None = None, ) -> "ModelResponse": from openhands.llm.llm import ModelResponse @@ -81,19 +93,26 @@ async def _post_completion( } if tools: params["tools"] = tools + if request_kwargs: + params.update({k: v for k, v in request_kwargs.items() if k not in ("messages", "tools")}) - fields_to_remove = [ + core_token_fields = [ "prompt_token_ids", "generation_token_ids", "generation_log_probs", ] + fields_to_remove = core_token_fields + [ + "prompt_moe_topk_indices", + "generation_moe_topk_indices", + "moe_metadata", + ] last_occurrence_idx_seen = False for message in reversed(message_dicts): if last_occurrence_idx_seen: for field in fields_to_remove: if field in message: del message[field] - elif all(field in message for field in fields_to_remove): + elif all(field in message for field in core_token_fields): last_occurrence_idx_seen = True model_response = await self.ng_server_client.post( @@ -109,13 +128,12 @@ async def _post_completion( response: ModelResponse = ModelResponse.model_validate(model_response_json) response_message_dict = model_response_json["choices"][0]["message"] - provider_specific_fields: dict = {} - if response_message_dict.get("prompt_token_ids"): - provider_specific_fields = { - "prompt_token_ids": response_message_dict["prompt_token_ids"], - "generation_token_ids": response_message_dict["generation_token_ids"], - "generation_log_probs": response_message_dict["generation_log_probs"], - } + provider_specific_fields = { + key: response_message_dict[key] + for key in self._PROVIDER_SPECIFIC_FIELD_KEYS + if key in response_message_dict + } + if provider_specific_fields: response._provider_specific_fields = provider_specific_fields self._log_completion( diff --git a/openhands/agenthub/opencode_agent/opencode_agent.py b/openhands/agenthub/opencode_agent/opencode_agent.py index eae56244cb5b..ce4d19db6bd1 100644 --- a/openhands/agenthub/opencode_agent/opencode_agent.py +++ b/openhands/agenthub/opencode_agent/opencode_agent.py @@ -204,7 +204,11 @@ async def step(self, state: State) -> "Action": model_name=self.llm.config.model, agent_name=self.name ) } - response = await self.nemo_gym_client.model_call(messages, params["tools"]) + response = await self.nemo_gym_client.model_call( + messages, + params["tools"], + request_kwargs={"extra_body": params["extra_body"]}, + ) logger.debug(f"Response from LLM: {response}") actions = self.response_to_actions(response) logger.debug(f"Actions after response_to_actions: {actions}") diff --git a/openhands/core/message.py b/openhands/core/message.py index 431f5162cdab..9eb041aa3489 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -70,6 +70,9 @@ class Message(BaseModel): prompt_token_ids: list[int] | None = None generation_token_ids: list[int] | None = None generation_log_probs: list[float] | None = None + prompt_moe_topk_indices: dict[str, Any] | list[Any] | None = None + generation_moe_topk_indices: dict[str, Any] | list[Any] | None = None + moe_metadata: dict[str, Any] | None = None @property def contains_image(self) -> bool: @@ -166,5 +169,11 @@ def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]: message_dict['generation_token_ids'] = self.generation_token_ids if self.generation_log_probs is not None: message_dict['generation_log_probs'] = self.generation_log_probs + if self.prompt_moe_topk_indices is not None: + message_dict['prompt_moe_topk_indices'] = self.prompt_moe_topk_indices + if self.generation_moe_topk_indices is not None: + message_dict['generation_moe_topk_indices'] = self.generation_moe_topk_indices + if self.moe_metadata is not None: + message_dict['moe_metadata'] = self.moe_metadata return message_dict diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index b51810d3729f..ec6a858a3806 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -328,6 +328,13 @@ def _process_action( generation_log_probs=provider_specific_fields.get( 'generation_log_probs' ), + prompt_moe_topk_indices=provider_specific_fields.get( + 'prompt_moe_topk_indices' + ), + generation_moe_topk_indices=provider_specific_fields.get( + 'generation_moe_topk_indices' + ), + moe_metadata=provider_specific_fields.get('moe_metadata'), ) return [] elif isinstance(action, AgentFinishAction): @@ -396,6 +403,13 @@ def _process_action( generation_log_probs=provider_specific_fields.get( 'generation_log_probs' ), + prompt_moe_topk_indices=provider_specific_fields.get( + 'prompt_moe_topk_indices' + ), + generation_moe_topk_indices=provider_specific_fields.get( + 'generation_moe_topk_indices' + ), + moe_metadata=provider_specific_fields.get('moe_metadata'), ) ] elif isinstance(action, CmdRunAction) and action.source == 'user': From 08ccb742757aec08682c78c3eebe51fa599a2492 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Mon, 1 Jun 2026 15:03:41 -0700 Subject: [PATCH 2/4] Refactoring. Signed-off-by: Peter Jin --- openhands/agenthub/nemo_gym_client.py | 4 +--- openhands/core/message.py | 2 +- openhands/memory/conversation_memory.py | 24 ++++++------------------ 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/openhands/agenthub/nemo_gym_client.py b/openhands/agenthub/nemo_gym_client.py index bc4099b4a274..0655e49bc021 100644 --- a/openhands/agenthub/nemo_gym_client.py +++ b/openhands/agenthub/nemo_gym_client.py @@ -129,9 +129,7 @@ async def _post_completion( response_message_dict = model_response_json["choices"][0]["message"] provider_specific_fields = { - key: response_message_dict[key] - for key in self._PROVIDER_SPECIFIC_FIELD_KEYS - if key in response_message_dict + key: response_message_dict[key] for key in self._PROVIDER_SPECIFIC_FIELD_KEYS if key in response_message_dict } if provider_specific_fields: response._provider_specific_fields = provider_specific_fields diff --git a/openhands/core/message.py b/openhands/core/message.py index 9eb041aa3489..30b654d06721 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -72,7 +72,7 @@ class Message(BaseModel): generation_log_probs: list[float] | None = None prompt_moe_topk_indices: dict[str, Any] | list[Any] | None = None generation_moe_topk_indices: dict[str, Any] | list[Any] | None = None - moe_metadata: dict[str, Any] | None = None + moe_metadata: dict[str, Any] | list[Any] | None = None @property def contains_image(self) -> bool: diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index ec6a858a3806..014dcb68cefa 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -328,12 +328,8 @@ def _process_action( generation_log_probs=provider_specific_fields.get( 'generation_log_probs' ), - prompt_moe_topk_indices=provider_specific_fields.get( - 'prompt_moe_topk_indices' - ), - generation_moe_topk_indices=provider_specific_fields.get( - 'generation_moe_topk_indices' - ), + prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'), + generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'), moe_metadata=provider_specific_fields.get('moe_metadata'), ) return [] @@ -397,18 +393,10 @@ def _process_action( role=role, # type: ignore[arg-type] content=content, prompt_token_ids=provider_specific_fields.get('prompt_token_ids'), - generation_token_ids=provider_specific_fields.get( - 'generation_token_ids' - ), - generation_log_probs=provider_specific_fields.get( - 'generation_log_probs' - ), - prompt_moe_topk_indices=provider_specific_fields.get( - 'prompt_moe_topk_indices' - ), - generation_moe_topk_indices=provider_specific_fields.get( - 'generation_moe_topk_indices' - ), + generation_token_ids=provider_specific_fields.get('generation_token_ids'), + generation_log_probs=provider_specific_fields.get('generation_log_probs'), + prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'), + generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'), moe_metadata=provider_specific_fields.get('moe_metadata'), ) ] From 438cce121e582ca589f2a10aed963d2a02d0bbf7 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Mon, 1 Jun 2026 15:07:13 -0700 Subject: [PATCH 3/4] Formatting. Signed-off-by: Peter Jin --- openhands/memory/conversation_memory.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index 630000a6d81b..276783b39738 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -322,12 +322,8 @@ def _process_action( else [], tool_calls=assistant_msg.tool_calls, prompt_token_ids=provider_specific_fields.get('prompt_token_ids'), - generation_token_ids=provider_specific_fields.get( - 'generation_token_ids' - ), - generation_log_probs=provider_specific_fields.get( - 'generation_log_probs' - ), + generation_token_ids=provider_specific_fields.get('generation_token_ids'), + generation_log_probs=provider_specific_fields.get('generation_log_probs'), prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'), generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'), moe_metadata=provider_specific_fields.get('moe_metadata'), From 645917e605e8cc2802eece6e963a06975d2f65cc Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Mon, 1 Jun 2026 18:04:48 -0700 Subject: [PATCH 4/4] Accumulate MoE-related data. Signed-off-by: Peter Jin --- openhands/agenthub/nemo_gym_client.py | 75 ++++++++---- tests/unit/agenthub/test_nemo_gym_client.py | 119 ++++++++++++++++++++ 2 files changed, 173 insertions(+), 21 deletions(-) create mode 100644 tests/unit/agenthub/test_nemo_gym_client.py diff --git a/openhands/agenthub/nemo_gym_client.py b/openhands/agenthub/nemo_gym_client.py index 2ffaa0288565..3554125b6fff 100644 --- a/openhands/agenthub/nemo_gym_client.py +++ b/openhands/agenthub/nemo_gym_client.py @@ -35,15 +35,20 @@ class NemoGymClient: response = await self.nemo_gym_client.model_call(messages, tools) """ - _PROVIDER_SPECIFIC_FIELD_KEYS = ( + _CORE_TOKEN_FIELD_KEYS = ( "prompt_token_ids", "generation_token_ids", "generation_log_probs", + ) + + _MOE_FIELD_KEYS = ( "prompt_moe_topk_indices", "generation_moe_topk_indices", "moe_metadata", ) + _PROVIDER_SPECIFIC_FIELD_KEYS = _CORE_TOKEN_FIELD_KEYS + _MOE_FIELD_KEYS + def __init__(self, llm: "LLM") -> None: self.ng_server_client = ServerClient( head_server_config=ServerClient.load_head_server_config(), @@ -77,6 +82,50 @@ async def model_call( # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _as_moe_history_elems(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + @classmethod + def _normalize_request_messages( + cls, message_dicts: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + anchor_idx = -1 + for idx in range(len(message_dicts) - 1, -1, -1): + if all(field in message_dicts[idx] for field in cls._CORE_TOKEN_FIELD_KEYS): + anchor_idx = idx + break + + if anchor_idx < 0: + return message_dicts + + moe_history: dict[str, list[Any]] = {field: [] for field in cls._MOE_FIELD_KEYS} + moe_field_seen = {field: False for field in cls._MOE_FIELD_KEYS} + for idx in range(anchor_idx + 1): + message = message_dicts[idx] + is_anchor = idx == anchor_idx + if not is_anchor: + for field in cls._CORE_TOKEN_FIELD_KEYS: + message.pop(field, None) + for field in cls._MOE_FIELD_KEYS: + if field not in message: + continue + moe_field_seen[field] = True + moe_history[field].extend(cls._as_moe_history_elems(message[field])) + if not is_anchor: + del message[field] + + anchor_message = message_dicts[anchor_idx] + for field in cls._MOE_FIELD_KEYS: + if moe_field_seen[field]: + anchor_message[field] = moe_history[field] + + return message_dicts + async def _post_completion( self, messages: list["Message"], @@ -85,7 +134,9 @@ async def _post_completion( ) -> "ModelResponse": from openhands.llm.llm import ModelResponse - message_dicts = [m.model_dump() for m in messages] + message_dicts = self._normalize_request_messages( + [m.model_dump() for m in messages] + ) params: dict = { "messages": message_dicts, @@ -96,25 +147,6 @@ async def _post_completion( if request_kwargs: params.update({k: v for k, v in request_kwargs.items() if k not in ("messages", "tools")}) - core_token_fields = [ - "prompt_token_ids", - "generation_token_ids", - "generation_log_probs", - ] - fields_to_remove = core_token_fields + [ - "prompt_moe_topk_indices", - "generation_moe_topk_indices", - "moe_metadata", - ] - last_occurrence_idx_seen = False - for message in reversed(message_dicts): - if last_occurrence_idx_seen: - for field in fields_to_remove: - if field in message: - del message[field] - elif all(field in message for field in core_token_fields): - last_occurrence_idx_seen = True - # Measure per-call round-trip latency so it's surfaced in # `Metrics.response_latencies` (and therefore in the eval output.jsonl # via `get_metrics(state)`), mirroring the litellm path in @@ -162,6 +194,7 @@ def _log_completion( ) _d = { "messages": [m.model_dump() for m in messages], + "request_messages": params.get("messages"), "response": model_response_json, "provider_specific_fields": provider_specific_fields, "kwargs": { diff --git a/tests/unit/agenthub/test_nemo_gym_client.py b/tests/unit/agenthub/test_nemo_gym_client.py new file mode 100644 index 000000000000..9bbc2aa524be --- /dev/null +++ b/tests/unit/agenthub/test_nemo_gym_client.py @@ -0,0 +1,119 @@ +import copy +import json +from types import SimpleNamespace + +from openhands.agenthub.nemo_gym_client import NemoGymClient +from openhands.core.message import Message, TextContent + + +def test_normalize_request_messages_accumulates_moe_history_on_anchor(): + messages = [ + { + "role": "assistant", + "content": "older", + "prompt_token_ids": [1], + "generation_token_ids": [2], + "generation_log_probs": [-0.1], + "prompt_moe_topk_indices": {"source": "older"}, + "generation_moe_topk_indices": [{"source": "older-gen"}], + "moe_metadata": {"source": "older-meta"}, + }, + { + "role": "user", + "content": "follow-up", + "prompt_moe_topk_indices": {"source": "user-older"}, + }, + { + "role": "assistant", + "content": "anchor", + "prompt_token_ids": [3], + "generation_token_ids": [4], + "generation_log_probs": [-0.2], + "prompt_moe_topk_indices": [{"source": "anchor"}], + "generation_moe_topk_indices": {"source": "anchor-gen"}, + }, + { + "role": "assistant", + "content": "newer-partial", + "prompt_moe_topk_indices": {"source": "newer"}, + }, + ] + + actual = NemoGymClient._normalize_request_messages(copy.deepcopy(messages)) + + assert "prompt_token_ids" not in actual[0] + assert "generation_token_ids" not in actual[0] + assert "generation_log_probs" not in actual[0] + assert "prompt_moe_topk_indices" not in actual[0] + assert "generation_moe_topk_indices" not in actual[0] + assert "moe_metadata" not in actual[0] + + assert "prompt_moe_topk_indices" not in actual[1] + + assert actual[2]["prompt_token_ids"] == [3] + assert actual[2]["generation_token_ids"] == [4] + assert actual[2]["generation_log_probs"] == [-0.2] + assert actual[2]["prompt_moe_topk_indices"] == [ + {"source": "older"}, + {"source": "user-older"}, + {"source": "anchor"}, + ] + assert actual[2]["generation_moe_topk_indices"] == [ + {"source": "older-gen"}, + {"source": "anchor-gen"}, + ] + assert actual[2]["moe_metadata"] == [{"source": "older-meta"}] + + assert actual[3]["prompt_moe_topk_indices"] == {"source": "newer"} + + +def test_normalize_request_messages_without_anchor_leaves_messages_unchanged(): + messages = [ + { + "role": "assistant", + "content": "older", + "prompt_moe_topk_indices": {"source": "older"}, + }, + { + "role": "assistant", + "content": "newer", + "generation_moe_topk_indices": [{"source": "newer"}], + }, + ] + + expected = copy.deepcopy(messages) + actual = NemoGymClient._normalize_request_messages(messages) + + assert actual == expected + + +def test_log_completion_writes_request_messages(tmp_path): + client = NemoGymClient.__new__(NemoGymClient) + client.llm = SimpleNamespace( + config=SimpleNamespace( + log_completions_folder=str(tmp_path), + model="test-model", + ) + ) + + request_messages = [ + { + "role": "assistant", + "content": "anchor", + "prompt_moe_topk_indices": [{"source": "older"}, {"source": "anchor"}], + } + ] + + client._log_completion( + messages=[Message(role="user", content=[TextContent(text="hello")])], + model_response_json={ + "choices": [{"message": {"role": "assistant", "content": "done"}}] + }, + provider_specific_fields={}, + params={"messages": request_messages}, + ) + + [log_file] = list(tmp_path.iterdir()) + logged = json.loads(log_file.read_text()) + assert logged["request_messages"] == request_messages +