diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 66c1da6cce41..31fde668142d 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -233,7 +233,11 @@ async def step(self, state: State) -> 'Action': ) } - response = await self.nemo_gym_client.model_call(messages, params['tools']) + response = await self.nemo_gym_client.model_call( + messages, + params['tools'], + request_kwargs={'extra_body': params['extra_body']}, + ) ng_openhands_should_log = os.environ.get("NG_OPENHANDS_SHOULD_LOG", "").lower() == "true" if ng_openhands_should_log: diff --git a/openhands/agenthub/codex_agent/codex_agent.py b/openhands/agenthub/codex_agent/codex_agent.py index b0c22046eb1f..8eec6aa0e12f 100644 --- a/openhands/agenthub/codex_agent/codex_agent.py +++ b/openhands/agenthub/codex_agent/codex_agent.py @@ -179,7 +179,11 @@ async def step(self, state: State) -> "Action": model_name=self.llm.config.model, agent_name=self.name ) } - response = await self.nemo_gym_client.model_call(messages, params["tools"]) + response = await self.nemo_gym_client.model_call( + messages, + params["tools"], + request_kwargs={"extra_body": params["extra_body"]}, + ) logger.debug(f"Response from LLM: {response}") actions = self.response_to_actions(response) logger.debug(f"Actions after response_to_actions: {actions}") diff --git a/openhands/agenthub/nemo_gym_client.py b/openhands/agenthub/nemo_gym_client.py index 6e676cab3d68..3554125b6fff 100644 --- a/openhands/agenthub/nemo_gym_client.py +++ b/openhands/agenthub/nemo_gym_client.py @@ -9,7 +9,7 @@ import os import tempfile import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from openhands.core.logger import openhands_logger as logger from nemo_gym.global_config import get_global_config_dict @@ -35,6 +35,20 @@ class NemoGymClient: response = await self.nemo_gym_client.model_call(messages, tools) """ + _CORE_TOKEN_FIELD_KEYS = ( + "prompt_token_ids", + "generation_token_ids", + "generation_log_probs", + ) + + _MOE_FIELD_KEYS = ( + "prompt_moe_topk_indices", + "generation_moe_topk_indices", + "moe_metadata", + ) + + _PROVIDER_SPECIFIC_FIELD_KEYS = _CORE_TOKEN_FIELD_KEYS + _MOE_FIELD_KEYS + def __init__(self, llm: "LLM") -> None: self.ng_server_client = ServerClient( head_server_config=ServerClient.load_head_server_config(), @@ -47,18 +61,20 @@ async def model_call( self, messages: list["Message"], tools: "list[ChatCompletionToolParam] | None" = None, + request_kwargs: dict[str, Any] | None = None, ) -> "ModelResponse": """Make a model call via the NeMo Gym server, with automatic metrics tracking. Args: messages: Conversation messages (OpenHands Message objects). tools: Optional list of tool definitions for function calling. + request_kwargs: Optional extra chat completion fields to forward. Returns: A validated ModelResponse from the server. """ start_time = time.time() - response = await self._post_completion(messages, tools) + response = await self._post_completion(messages, tools, request_kwargs=request_kwargs) self._update_model_call_time(start_time) return response @@ -66,14 +82,61 @@ async def model_call( # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _as_moe_history_elems(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + @classmethod + def _normalize_request_messages( + cls, message_dicts: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + anchor_idx = -1 + for idx in range(len(message_dicts) - 1, -1, -1): + if all(field in message_dicts[idx] for field in cls._CORE_TOKEN_FIELD_KEYS): + anchor_idx = idx + break + + if anchor_idx < 0: + return message_dicts + + moe_history: dict[str, list[Any]] = {field: [] for field in cls._MOE_FIELD_KEYS} + moe_field_seen = {field: False for field in cls._MOE_FIELD_KEYS} + for idx in range(anchor_idx + 1): + message = message_dicts[idx] + is_anchor = idx == anchor_idx + if not is_anchor: + for field in cls._CORE_TOKEN_FIELD_KEYS: + message.pop(field, None) + for field in cls._MOE_FIELD_KEYS: + if field not in message: + continue + moe_field_seen[field] = True + moe_history[field].extend(cls._as_moe_history_elems(message[field])) + if not is_anchor: + del message[field] + + anchor_message = message_dicts[anchor_idx] + for field in cls._MOE_FIELD_KEYS: + if moe_field_seen[field]: + anchor_message[field] = moe_history[field] + + return message_dicts + async def _post_completion( self, messages: list["Message"], tools: "list[ChatCompletionToolParam] | None" = None, + request_kwargs: dict[str, Any] | None = None, ) -> "ModelResponse": from openhands.llm.llm import ModelResponse - message_dicts = [m.model_dump() for m in messages] + message_dicts = self._normalize_request_messages( + [m.model_dump() for m in messages] + ) params: dict = { "messages": message_dicts, @@ -81,20 +144,8 @@ async def _post_completion( } if tools: params["tools"] = tools - - fields_to_remove = [ - "prompt_token_ids", - "generation_token_ids", - "generation_log_probs", - ] - last_occurrence_idx_seen = False - for message in reversed(message_dicts): - if last_occurrence_idx_seen: - for field in fields_to_remove: - if field in message: - del message[field] - elif all(field in message for field in fields_to_remove): - last_occurrence_idx_seen = True + if request_kwargs: + params.update({k: v for k, v in request_kwargs.items() if k not in ("messages", "tools")}) # Measure per-call round-trip latency so it's surfaced in # `Metrics.response_latencies` (and therefore in the eval output.jsonl @@ -118,13 +169,10 @@ async def _post_completion( response: ModelResponse = ModelResponse.model_validate(model_response_json) response_message_dict = model_response_json["choices"][0]["message"] - provider_specific_fields: dict = {} - if response_message_dict.get("prompt_token_ids"): - provider_specific_fields = { - "prompt_token_ids": response_message_dict["prompt_token_ids"], - "generation_token_ids": response_message_dict["generation_token_ids"], - "generation_log_probs": response_message_dict["generation_log_probs"], - } + provider_specific_fields = { + key: response_message_dict[key] for key in self._PROVIDER_SPECIFIC_FIELD_KEYS if key in response_message_dict + } + if provider_specific_fields: response._provider_specific_fields = provider_specific_fields self._log_completion( @@ -146,6 +194,7 @@ def _log_completion( ) _d = { "messages": [m.model_dump() for m in messages], + "request_messages": params.get("messages"), "response": model_response_json, "provider_specific_fields": provider_specific_fields, "kwargs": { diff --git a/openhands/agenthub/opencode_agent/opencode_agent.py b/openhands/agenthub/opencode_agent/opencode_agent.py index eae56244cb5b..ce4d19db6bd1 100644 --- a/openhands/agenthub/opencode_agent/opencode_agent.py +++ b/openhands/agenthub/opencode_agent/opencode_agent.py @@ -204,7 +204,11 @@ async def step(self, state: State) -> "Action": model_name=self.llm.config.model, agent_name=self.name ) } - response = await self.nemo_gym_client.model_call(messages, params["tools"]) + response = await self.nemo_gym_client.model_call( + messages, + params["tools"], + request_kwargs={"extra_body": params["extra_body"]}, + ) logger.debug(f"Response from LLM: {response}") actions = self.response_to_actions(response) logger.debug(f"Actions after response_to_actions: {actions}") diff --git a/openhands/core/message.py b/openhands/core/message.py index 431f5162cdab..30b654d06721 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -70,6 +70,9 @@ class Message(BaseModel): prompt_token_ids: list[int] | None = None generation_token_ids: list[int] | None = None generation_log_probs: list[float] | None = None + prompt_moe_topk_indices: dict[str, Any] | list[Any] | None = None + generation_moe_topk_indices: dict[str, Any] | list[Any] | None = None + moe_metadata: dict[str, Any] | list[Any] | None = None @property def contains_image(self) -> bool: @@ -166,5 +169,11 @@ def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]: message_dict['generation_token_ids'] = self.generation_token_ids if self.generation_log_probs is not None: message_dict['generation_log_probs'] = self.generation_log_probs + if self.prompt_moe_topk_indices is not None: + message_dict['prompt_moe_topk_indices'] = self.prompt_moe_topk_indices + if self.generation_moe_topk_indices is not None: + message_dict['generation_moe_topk_indices'] = self.generation_moe_topk_indices + if self.moe_metadata is not None: + message_dict['moe_metadata'] = self.moe_metadata return message_dict diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index b6b6ac6c244a..276783b39738 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -322,12 +322,11 @@ def _process_action( else [], tool_calls=assistant_msg.tool_calls, prompt_token_ids=provider_specific_fields.get('prompt_token_ids'), - generation_token_ids=provider_specific_fields.get( - 'generation_token_ids' - ), - generation_log_probs=provider_specific_fields.get( - 'generation_log_probs' - ), + generation_token_ids=provider_specific_fields.get('generation_token_ids'), + generation_log_probs=provider_specific_fields.get('generation_log_probs'), + prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'), + generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'), + moe_metadata=provider_specific_fields.get('moe_metadata'), ) return [] elif isinstance(action, AgentFinishAction): @@ -390,12 +389,11 @@ def _process_action( role=role, # type: ignore[arg-type] content=content, prompt_token_ids=provider_specific_fields.get('prompt_token_ids'), - generation_token_ids=provider_specific_fields.get( - 'generation_token_ids' - ), - generation_log_probs=provider_specific_fields.get( - 'generation_log_probs' - ), + generation_token_ids=provider_specific_fields.get('generation_token_ids'), + generation_log_probs=provider_specific_fields.get('generation_log_probs'), + prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'), + generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'), + moe_metadata=provider_specific_fields.get('moe_metadata'), ) ] elif isinstance(action, CmdRunAction) and action.source == 'user': diff --git a/tests/unit/agenthub/test_nemo_gym_client.py b/tests/unit/agenthub/test_nemo_gym_client.py new file mode 100644 index 000000000000..9bbc2aa524be --- /dev/null +++ b/tests/unit/agenthub/test_nemo_gym_client.py @@ -0,0 +1,119 @@ +import copy +import json +from types import SimpleNamespace + +from openhands.agenthub.nemo_gym_client import NemoGymClient +from openhands.core.message import Message, TextContent + + +def test_normalize_request_messages_accumulates_moe_history_on_anchor(): + messages = [ + { + "role": "assistant", + "content": "older", + "prompt_token_ids": [1], + "generation_token_ids": [2], + "generation_log_probs": [-0.1], + "prompt_moe_topk_indices": {"source": "older"}, + "generation_moe_topk_indices": [{"source": "older-gen"}], + "moe_metadata": {"source": "older-meta"}, + }, + { + "role": "user", + "content": "follow-up", + "prompt_moe_topk_indices": {"source": "user-older"}, + }, + { + "role": "assistant", + "content": "anchor", + "prompt_token_ids": [3], + "generation_token_ids": [4], + "generation_log_probs": [-0.2], + "prompt_moe_topk_indices": [{"source": "anchor"}], + "generation_moe_topk_indices": {"source": "anchor-gen"}, + }, + { + "role": "assistant", + "content": "newer-partial", + "prompt_moe_topk_indices": {"source": "newer"}, + }, + ] + + actual = NemoGymClient._normalize_request_messages(copy.deepcopy(messages)) + + assert "prompt_token_ids" not in actual[0] + assert "generation_token_ids" not in actual[0] + assert "generation_log_probs" not in actual[0] + assert "prompt_moe_topk_indices" not in actual[0] + assert "generation_moe_topk_indices" not in actual[0] + assert "moe_metadata" not in actual[0] + + assert "prompt_moe_topk_indices" not in actual[1] + + assert actual[2]["prompt_token_ids"] == [3] + assert actual[2]["generation_token_ids"] == [4] + assert actual[2]["generation_log_probs"] == [-0.2] + assert actual[2]["prompt_moe_topk_indices"] == [ + {"source": "older"}, + {"source": "user-older"}, + {"source": "anchor"}, + ] + assert actual[2]["generation_moe_topk_indices"] == [ + {"source": "older-gen"}, + {"source": "anchor-gen"}, + ] + assert actual[2]["moe_metadata"] == [{"source": "older-meta"}] + + assert actual[3]["prompt_moe_topk_indices"] == {"source": "newer"} + + +def test_normalize_request_messages_without_anchor_leaves_messages_unchanged(): + messages = [ + { + "role": "assistant", + "content": "older", + "prompt_moe_topk_indices": {"source": "older"}, + }, + { + "role": "assistant", + "content": "newer", + "generation_moe_topk_indices": [{"source": "newer"}], + }, + ] + + expected = copy.deepcopy(messages) + actual = NemoGymClient._normalize_request_messages(messages) + + assert actual == expected + + +def test_log_completion_writes_request_messages(tmp_path): + client = NemoGymClient.__new__(NemoGymClient) + client.llm = SimpleNamespace( + config=SimpleNamespace( + log_completions_folder=str(tmp_path), + model="test-model", + ) + ) + + request_messages = [ + { + "role": "assistant", + "content": "anchor", + "prompt_moe_topk_indices": [{"source": "older"}, {"source": "anchor"}], + } + ] + + client._log_completion( + messages=[Message(role="user", content=[TextContent(text="hello")])], + model_response_json={ + "choices": [{"message": {"role": "assistant", "content": "done"}}] + }, + provider_specific_fields={}, + params={"messages": request_messages}, + ) + + [log_file] = list(tmp_path.iterdir()) + logged = json.loads(log_file.read_text()) + assert logged["request_messages"] == request_messages +