Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions openhands/agenthub/codeact_agent/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,19 @@ def response_to_actions(
)
)

# Add response id to actions
# This will ensure we can match both actions without tool calls (e.g. MessageAction)
# and actions with tool calls (e.g. CmdRunAction, IPythonRunCellAction, etc.)
# with the token usage data
# Add response id and provider-specific fields to actions
# Extract provider_specific_fields from the response if available
provider_specific_fields = getattr(response, '_provider_specific_fields', {})

for action in actions:
action.response_id = response.id
# Set provider-specific fields if they exist
if 'prompt_token_ids' in provider_specific_fields:
action.prompt_token_ids = provider_specific_fields['prompt_token_ids']
if 'generation_token_ids' in provider_specific_fields:
action.generation_token_ids = provider_specific_fields['generation_token_ids']
if 'generation_log_probs' in provider_specific_fields:
action.generation_log_probs = provider_specific_fields['generation_log_probs']

assert len(actions) >= 1
return actions
15 changes: 11 additions & 4 deletions openhands/agenthub/loc_agent/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,19 @@ def response_to_actions(
)
)

# Add response id to actions
# This will ensure we can match both actions without tool calls (e.g. MessageAction)
# and actions with tool calls (e.g. CmdRunAction, IPythonRunCellAction, etc.)
# with the token usage data
# Add response id and provider-specific fields to actions
# Extract provider_specific_fields from the response if available
provider_specific_fields = getattr(response, '_provider_specific_fields', {})

for action in actions:
action.response_id = response.id
# Set provider-specific fields if they exist
if 'prompt_token_ids' in provider_specific_fields:
action.prompt_token_ids = provider_specific_fields['prompt_token_ids']
if 'generation_token_ids' in provider_specific_fields:
action.generation_token_ids = provider_specific_fields['generation_token_ids']
if 'generation_log_probs' in provider_specific_fields:
action.generation_log_probs = provider_specific_fields['generation_log_probs']

assert len(actions) >= 1
return actions
Expand Down
15 changes: 11 additions & 4 deletions openhands/agenthub/readonly_agent/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,19 @@ def response_to_actions(
)
)

# Add response id to actions
# This will ensure we can match both actions without tool calls (e.g. MessageAction)
# and actions with tool calls (e.g. CmdRunAction, IPythonRunCellAction, etc.)
# with the token usage data
# Add response id and provider-specific fields to actions
# Extract provider_specific_fields from the response if available
provider_specific_fields = getattr(response, '_provider_specific_fields', {})

for action in actions:
action.response_id = response.id
# Set provider-specific fields if they exist
if 'prompt_token_ids' in provider_specific_fields:
action.prompt_token_ids = provider_specific_fields['prompt_token_ids']
if 'generation_token_ids' in provider_specific_fields:
action.generation_token_ids = provider_specific_fields['generation_token_ids']
if 'generation_log_probs' in provider_specific_fields:
action.generation_log_probs = provider_specific_fields['generation_log_probs']

assert len(actions) >= 1
return actions
Expand Down
31 changes: 31 additions & 0 deletions openhands/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,34 @@ def response_id(self) -> str | None:
@response_id.setter
def response_id(self, value: str) -> None:
self._response_id = value

# optional fields, provider-specific data from LLM response
@property
def prompt_token_ids(self) -> list[int] | None:
if hasattr(self, '_prompt_token_ids'):
return self._prompt_token_ids # type: ignore[attr-defined]
return None

@prompt_token_ids.setter
def prompt_token_ids(self, value: list[int]) -> None:
self._prompt_token_ids = value

@property
def generation_token_ids(self) -> list[int] | None:
if hasattr(self, '_generation_token_ids'):
return self._generation_token_ids # type: ignore[attr-defined]
return None

@generation_token_ids.setter
def generation_token_ids(self, value: list[int]) -> None:
self._generation_token_ids = value

@property
def generation_log_probs(self) -> list[float] | None:
if hasattr(self, '_generation_log_probs'):
return self._generation_log_probs # type: ignore[attr-defined]
return None

@generation_log_probs.setter
def generation_log_probs(self, value: list[float]) -> None:
self._generation_log_probs = value
6 changes: 5 additions & 1 deletion openhands/events/serialization/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
'cause',
'tool_call_metadata',
'llm_metrics',
'response_id',
'prompt_token_ids',
'generation_token_ids',
'generation_log_probs',
]

DELETE_FROM_TRAJECTORY_EXTRAS = {
Expand Down Expand Up @@ -71,7 +75,7 @@ def event_from_dict(data: dict[str, Any]) -> 'Event':
model_response_dict = value['model_response']
if isinstance(model_response_dict, dict) and 'provider_specific_fields' in model_response_dict:
provider_specific_fields = model_response_dict.pop('provider_specific_fields')

value = ToolCallMetadata(**value)

# Add provider_specific_fields back to the model_response
Expand Down
19 changes: 3 additions & 16 deletions openhands/memory/conversation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,26 +327,13 @@ def _process_action(
if role not in ('user', 'system', 'assistant', 'tool'):
raise ValueError(f'Invalid role: {role}')

# Extract provider_specific_fields if available (for assistant messages)
provider_specific_fields = {}
if role == 'assistant' and action.tool_call_metadata is not None:
provider_specific_fields = getattr(
action.tool_call_metadata.model_response,
'_provider_specific_fields',
{},
)

return [
Message(
role=role, # type: ignore[arg-type]
content=content,
prompt_token_ids=provider_specific_fields.get('prompt_token_ids'),
generation_token_ids=provider_specific_fields.get(
'generation_token_ids'
),
generation_log_probs=provider_specific_fields.get(
'generation_log_probs'
),
prompt_token_ids=action.prompt_token_ids,
generation_token_ids=action.generation_token_ids,
generation_log_probs=action.generation_log_probs,
)
]
elif isinstance(action, CmdRunAction) and action.source == 'user':
Expand Down
Loading