Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ async def step(self, state: State) -> 'Action':
)
}

response = await self.nemo_gym_client.model_call(messages, params['tools'])
response = await self.nemo_gym_client.model_call(
messages,
params['tools'],
request_kwargs={'extra_body': params['extra_body']},
)

ng_openhands_should_log = os.environ.get("NG_OPENHANDS_SHOULD_LOG", "").lower() == "true"
if ng_openhands_should_log:
Expand Down
6 changes: 5 additions & 1 deletion openhands/agenthub/codex_agent/codex_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,11 @@ async def step(self, state: State) -> "Action":
model_name=self.llm.config.model, agent_name=self.name
)
}
response = await self.nemo_gym_client.model_call(messages, params["tools"])
response = await self.nemo_gym_client.model_call(
messages,
params["tools"],
request_kwargs={"extra_body": params["extra_body"]},
)
logger.debug(f"Response from LLM: {response}")
actions = self.response_to_actions(response)
logger.debug(f"Actions after response_to_actions: {actions}")
Expand Down
97 changes: 73 additions & 24 deletions openhands/agenthub/nemo_gym_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import tempfile
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from openhands.core.logger import openhands_logger as logger
from nemo_gym.global_config import get_global_config_dict
Expand All @@ -35,6 +35,20 @@ class NemoGymClient:
response = await self.nemo_gym_client.model_call(messages, tools)
"""

_CORE_TOKEN_FIELD_KEYS = (
"prompt_token_ids",
"generation_token_ids",
"generation_log_probs",
)

_MOE_FIELD_KEYS = (
"prompt_moe_topk_indices",
"generation_moe_topk_indices",
"moe_metadata",
)

_PROVIDER_SPECIFIC_FIELD_KEYS = _CORE_TOKEN_FIELD_KEYS + _MOE_FIELD_KEYS

def __init__(self, llm: "LLM") -> None:
self.ng_server_client = ServerClient(
head_server_config=ServerClient.load_head_server_config(),
Expand All @@ -47,54 +61,91 @@ async def model_call(
self,
messages: list["Message"],
tools: "list[ChatCompletionToolParam] | None" = None,
request_kwargs: dict[str, Any] | None = None,
) -> "ModelResponse":
"""Make a model call via the NeMo Gym server, with automatic metrics tracking.

Args:
messages: Conversation messages (OpenHands Message objects).
tools: Optional list of tool definitions for function calling.
request_kwargs: Optional extra chat completion fields to forward.

Returns:
A validated ModelResponse from the server.
"""
start_time = time.time()
response = await self._post_completion(messages, tools)
response = await self._post_completion(messages, tools, request_kwargs=request_kwargs)
self._update_model_call_time(start_time)
return response

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

@staticmethod
def _as_moe_history_elems(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
return [value]

@classmethod
def _normalize_request_messages(
cls, message_dicts: list[dict[str, Any]]
) -> list[dict[str, Any]]:
anchor_idx = -1
for idx in range(len(message_dicts) - 1, -1, -1):
if all(field in message_dicts[idx] for field in cls._CORE_TOKEN_FIELD_KEYS):
anchor_idx = idx
break

if anchor_idx < 0:
return message_dicts

moe_history: dict[str, list[Any]] = {field: [] for field in cls._MOE_FIELD_KEYS}
moe_field_seen = {field: False for field in cls._MOE_FIELD_KEYS}
for idx in range(anchor_idx + 1):
message = message_dicts[idx]
is_anchor = idx == anchor_idx
if not is_anchor:
for field in cls._CORE_TOKEN_FIELD_KEYS:
message.pop(field, None)
for field in cls._MOE_FIELD_KEYS:
if field not in message:
continue
moe_field_seen[field] = True
moe_history[field].extend(cls._as_moe_history_elems(message[field]))
if not is_anchor:
del message[field]

anchor_message = message_dicts[anchor_idx]
for field in cls._MOE_FIELD_KEYS:
if moe_field_seen[field]:
anchor_message[field] = moe_history[field]

return message_dicts

async def _post_completion(
self,
messages: list["Message"],
tools: "list[ChatCompletionToolParam] | None" = None,
request_kwargs: dict[str, Any] | None = None,
) -> "ModelResponse":
from openhands.llm.llm import ModelResponse

message_dicts = [m.model_dump() for m in messages]
message_dicts = self._normalize_request_messages(
[m.model_dump() for m in messages]
)

params: dict = {
"messages": message_dicts,
**self.llm._nemo_gym_llm_kwargs,
}
if tools:
params["tools"] = tools

fields_to_remove = [
"prompt_token_ids",
"generation_token_ids",
"generation_log_probs",
]
last_occurrence_idx_seen = False
for message in reversed(message_dicts):
if last_occurrence_idx_seen:
for field in fields_to_remove:
if field in message:
del message[field]
elif all(field in message for field in fields_to_remove):
last_occurrence_idx_seen = True
if request_kwargs:
params.update({k: v for k, v in request_kwargs.items() if k not in ("messages", "tools")})

# Measure per-call round-trip latency so it's surfaced in
# `Metrics.response_latencies` (and therefore in the eval output.jsonl
Expand All @@ -118,13 +169,10 @@ async def _post_completion(
response: ModelResponse = ModelResponse.model_validate(model_response_json)

response_message_dict = model_response_json["choices"][0]["message"]
provider_specific_fields: dict = {}
if response_message_dict.get("prompt_token_ids"):
provider_specific_fields = {
"prompt_token_ids": response_message_dict["prompt_token_ids"],
"generation_token_ids": response_message_dict["generation_token_ids"],
"generation_log_probs": response_message_dict["generation_log_probs"],
}
provider_specific_fields = {
key: response_message_dict[key] for key in self._PROVIDER_SPECIFIC_FIELD_KEYS if key in response_message_dict
}
if provider_specific_fields:
response._provider_specific_fields = provider_specific_fields

self._log_completion(
Expand All @@ -146,6 +194,7 @@ def _log_completion(
)
_d = {
"messages": [m.model_dump() for m in messages],
"request_messages": params.get("messages"),
"response": model_response_json,
"provider_specific_fields": provider_specific_fields,
"kwargs": {
Expand Down
6 changes: 5 additions & 1 deletion openhands/agenthub/opencode_agent/opencode_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ async def step(self, state: State) -> "Action":
model_name=self.llm.config.model, agent_name=self.name
)
}
response = await self.nemo_gym_client.model_call(messages, params["tools"])
response = await self.nemo_gym_client.model_call(
messages,
params["tools"],
request_kwargs={"extra_body": params["extra_body"]},
)
logger.debug(f"Response from LLM: {response}")
actions = self.response_to_actions(response)
logger.debug(f"Actions after response_to_actions: {actions}")
Expand Down
9 changes: 9 additions & 0 deletions openhands/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class Message(BaseModel):
prompt_token_ids: list[int] | None = None
generation_token_ids: list[int] | None = None
generation_log_probs: list[float] | None = None
prompt_moe_topk_indices: dict[str, Any] | list[Any] | None = None
generation_moe_topk_indices: dict[str, Any] | list[Any] | None = None
moe_metadata: dict[str, Any] | list[Any] | None = None

@property
def contains_image(self) -> bool:
Expand Down Expand Up @@ -166,5 +169,11 @@ def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]:
message_dict['generation_token_ids'] = self.generation_token_ids
if self.generation_log_probs is not None:
message_dict['generation_log_probs'] = self.generation_log_probs
if self.prompt_moe_topk_indices is not None:
message_dict['prompt_moe_topk_indices'] = self.prompt_moe_topk_indices
if self.generation_moe_topk_indices is not None:
message_dict['generation_moe_topk_indices'] = self.generation_moe_topk_indices
if self.moe_metadata is not None:
message_dict['moe_metadata'] = self.moe_metadata

return message_dict
22 changes: 10 additions & 12 deletions openhands/memory/conversation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,11 @@ def _process_action(
else [],
tool_calls=assistant_msg.tool_calls,
prompt_token_ids=provider_specific_fields.get('prompt_token_ids'),
generation_token_ids=provider_specific_fields.get(
'generation_token_ids'
),
generation_log_probs=provider_specific_fields.get(
'generation_log_probs'
),
generation_token_ids=provider_specific_fields.get('generation_token_ids'),
generation_log_probs=provider_specific_fields.get('generation_log_probs'),
prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'),
generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'),
moe_metadata=provider_specific_fields.get('moe_metadata'),
)
return []
elif isinstance(action, AgentFinishAction):
Expand Down Expand Up @@ -390,12 +389,11 @@ def _process_action(
role=role, # type: ignore[arg-type]
content=content,
prompt_token_ids=provider_specific_fields.get('prompt_token_ids'),
generation_token_ids=provider_specific_fields.get(
'generation_token_ids'
),
generation_log_probs=provider_specific_fields.get(
'generation_log_probs'
),
generation_token_ids=provider_specific_fields.get('generation_token_ids'),
generation_log_probs=provider_specific_fields.get('generation_log_probs'),
prompt_moe_topk_indices=provider_specific_fields.get('prompt_moe_topk_indices'),
generation_moe_topk_indices=provider_specific_fields.get('generation_moe_topk_indices'),
moe_metadata=provider_specific_fields.get('moe_metadata'),
)
]
elif isinstance(action, CmdRunAction) and action.source == 'user':
Expand Down
119 changes: 119 additions & 0 deletions tests/unit/agenthub/test_nemo_gym_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import copy
import json
from types import SimpleNamespace

from openhands.agenthub.nemo_gym_client import NemoGymClient
from openhands.core.message import Message, TextContent


def test_normalize_request_messages_accumulates_moe_history_on_anchor():
messages = [
{
"role": "assistant",
"content": "older",
"prompt_token_ids": [1],
"generation_token_ids": [2],
"generation_log_probs": [-0.1],
"prompt_moe_topk_indices": {"source": "older"},
"generation_moe_topk_indices": [{"source": "older-gen"}],
"moe_metadata": {"source": "older-meta"},
},
{
"role": "user",
"content": "follow-up",
"prompt_moe_topk_indices": {"source": "user-older"},
},
{
"role": "assistant",
"content": "anchor",
"prompt_token_ids": [3],
"generation_token_ids": [4],
"generation_log_probs": [-0.2],
"prompt_moe_topk_indices": [{"source": "anchor"}],
"generation_moe_topk_indices": {"source": "anchor-gen"},
},
{
"role": "assistant",
"content": "newer-partial",
"prompt_moe_topk_indices": {"source": "newer"},
},
]

actual = NemoGymClient._normalize_request_messages(copy.deepcopy(messages))

assert "prompt_token_ids" not in actual[0]
assert "generation_token_ids" not in actual[0]
assert "generation_log_probs" not in actual[0]
assert "prompt_moe_topk_indices" not in actual[0]
assert "generation_moe_topk_indices" not in actual[0]
assert "moe_metadata" not in actual[0]

assert "prompt_moe_topk_indices" not in actual[1]

assert actual[2]["prompt_token_ids"] == [3]
assert actual[2]["generation_token_ids"] == [4]
assert actual[2]["generation_log_probs"] == [-0.2]
assert actual[2]["prompt_moe_topk_indices"] == [
{"source": "older"},
{"source": "user-older"},
{"source": "anchor"},
]
assert actual[2]["generation_moe_topk_indices"] == [
{"source": "older-gen"},
{"source": "anchor-gen"},
]
assert actual[2]["moe_metadata"] == [{"source": "older-meta"}]

assert actual[3]["prompt_moe_topk_indices"] == {"source": "newer"}


def test_normalize_request_messages_without_anchor_leaves_messages_unchanged():
messages = [
{
"role": "assistant",
"content": "older",
"prompt_moe_topk_indices": {"source": "older"},
},
{
"role": "assistant",
"content": "newer",
"generation_moe_topk_indices": [{"source": "newer"}],
},
]

expected = copy.deepcopy(messages)
actual = NemoGymClient._normalize_request_messages(messages)

assert actual == expected


def test_log_completion_writes_request_messages(tmp_path):
client = NemoGymClient.__new__(NemoGymClient)
client.llm = SimpleNamespace(
config=SimpleNamespace(
log_completions_folder=str(tmp_path),
model="test-model",
)
)

request_messages = [
{
"role": "assistant",
"content": "anchor",
"prompt_moe_topk_indices": [{"source": "older"}, {"source": "anchor"}],
}
]

client._log_completion(
messages=[Message(role="user", content=[TextContent(text="hello")])],
model_response_json={
"choices": [{"message": {"role": "assistant", "content": "done"}}]
},
provider_specific_fields={},
params={"messages": request_messages},
)

[log_file] = list(tmp_path.iterdir())
logged = json.loads(log_file.read_text())
assert logged["request_messages"] == request_messages