Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,55 @@ def fetch_traces() -> List[EvaluationRow]:
include_payloads=config.include_payloads,
)

return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
def preprocess_traces(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """
    Reduce fetched trace rows to the longest conversation.

    When ``config.include_payloads`` is truthy and the filtered result is
    non-empty, payload data from every input row is merged into the first
    filtered row before it is returned.
    """
    longest_only = filter_longest_conversation(rows)
    if not config.include_payloads or not longest_only:
        return longest_only

    # The merge mutates the selected row in place, so the same list is returned.
    _merge_payloads_into_longest_row(longest_only[0], rows)
    return longest_only

return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=preprocess_traces)


def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[EvaluationRow]) -> None:
    """
    Copy per-turn payload data from every trace row onto the longest row.

    Each row in ``rows`` is matched to an assistant turn of ``longest_row`` by
    the position of its last assistant message (number of assistant messages
    minus one). For a matched turn:

    * the row's final-assistant ``logprobs`` are copied onto the corresponding
      assistant message of ``longest_row`` unless that message already has
      logprobs;
    * the payload keys present in the row's ``execution_metadata.extra`` are
      gathered into a dict tagged with ``assistant_turn_index``.

    The gathered dicts are stored under
    ``longest_row.execution_metadata.extra["assistant_turn_payloads"]``. Other
    keys already in ``extra`` are left untouched, so the top-level payload
    metadata of ``longest_row`` is preserved.

    Args:
        longest_row: Row holding the full conversation; mutated in place.
        rows: All trace rows to draw payloads from (may include ``longest_row``).
    """
    payload_keys = (
        "completion_logprobs",
        "completion_token_ids",
        "logprobs_metadata",
        "routing_matrices",
        "routing_metadata",
    )
    longest_assistants = longest_row.get_assistant_messages()
    turn_count = len(longest_assistants)
    collected = []

    # Visit shorter conversations first so payload entries come out in that order.
    ordered_rows = list(rows)
    ordered_rows.sort(key=lambda candidate: len(candidate.messages))

    for candidate in ordered_rows:
        final_assistant = candidate.last_assistant_message()
        turn_index = len(candidate.get_assistant_messages()) - 1
        # Skip rows with no assistant message or with more turns than the target.
        if not 0 <= turn_index < turn_count:
            continue

        destination = longest_assistants[turn_index]
        # Never overwrite logprobs the target message already carries.
        if final_assistant and final_assistant.logprobs and not destination.logprobs:
            destination.logprobs = final_assistant.logprobs

        row_extra = candidate.execution_metadata.extra or {}
        payload = {}
        for name in payload_keys:
            if name in row_extra:
                payload[name] = row_extra[name]
        if not payload:
            continue

        payload["assistant_turn_index"] = turn_index
        collected.append(payload)

    if not collected:
        return

    metadata = longest_row.execution_metadata
    if metadata.extra is None:
        metadata.extra = {}
    metadata.extra["assistant_turn_payloads"] = collected


def build_fireworks_tracing_url(
Expand Down
57 changes: 57 additions & 0 deletions tests/pytest/test_tracing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from eval_protocol.models import EvaluationRow, ExecutionMetadata, Message
from eval_protocol.pytest.tracing_utils import _merge_payloads_into_longest_row


def test_merge_payloads_into_longest_row_preserves_each_assistant_turn():
    """Merging keeps each turn's logprobs and records one payload entry per assistant turn."""
    logprobs_turn_one = {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}
    logprobs_turn_two = {"content": [{"logprob": -0.3}]}

    # Row whose only assistant message is turn 0; its extra describes that turn.
    short_row = EvaluationRow(
        messages=[
            Message(role="user", content="What is 2+2?"),
            Message(role="assistant", content="4", logprobs=logprobs_turn_one),
        ],
        execution_metadata=ExecutionMetadata(
            extra={
                "completion_logprobs": [-0.1, -0.2],
                "routing_matrices": ["first-matrix"],
                "routing_metadata": {"total_token_count": 1},
            },
        ),
    )

    # Row with two assistant turns; only the last one carries logprobs.
    long_row = EvaluationRow(
        messages=[
            Message(role="user", content="What is 2+2?"),
            Message(role="assistant", content="4"),
            Message(role="user", content="Use that in a sentence."),
            Message(role="assistant", content="4", logprobs=logprobs_turn_two),
        ],
        execution_metadata=ExecutionMetadata(
            extra={
                "completion_logprobs": [-0.3],
                "routing_matrices": ["second-matrix"],
                "routing_metadata": {"total_token_count": 1},
            },
        ),
    )

    expected_turn_zero = {
        "assistant_turn_index": 0,
        "completion_logprobs": [-0.1, -0.2],
        "routing_matrices": ["first-matrix"],
        "routing_metadata": {"total_token_count": 1},
    }
    expected_turn_one = {
        "assistant_turn_index": 1,
        "completion_logprobs": [-0.3],
        "routing_matrices": ["second-matrix"],
        "routing_metadata": {"total_token_count": 1},
    }

    _merge_payloads_into_longest_row(long_row, [short_row, long_row])

    merged_assistants = long_row.get_assistant_messages()
    assert merged_assistants[0].logprobs == logprobs_turn_one
    assert merged_assistants[1].logprobs == logprobs_turn_two

    merged_extra = long_row.execution_metadata.extra
    assert merged_extra is not None
    # Top-level metadata still describes the final completion.
    assert merged_extra["routing_matrices"] == ["second-matrix"]
    assert merged_extra["assistant_turn_payloads"] == [expected_turn_zero, expected_turn_one]
Loading