diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py
index f2a727db..279d1055 100644
--- a/eval_protocol/pytest/tracing_utils.py
+++ b/eval_protocol/pytest/tracing_utils.py
@@ -28,7 +28,55 @@ def fetch_traces() -> List[EvaluationRow]:
             include_payloads=config.include_payloads,
         )
 
-    return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
+    def preprocess_traces(rows: List[EvaluationRow]) -> List[EvaluationRow]:
+        filtered_rows = filter_longest_conversation(rows)
+        if config.include_payloads and filtered_rows:
+            _merge_payloads_into_longest_row(filtered_rows[0], rows)
+        return filtered_rows
+
+    return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=preprocess_traces)
+
+
+def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[EvaluationRow]) -> None:
+    """
+    Preserve per-turn payload-derived metadata after selecting the longest trace row.
+
+    Each trace row carries payloads for its final assistant turn. The longest row
+    keeps the full conversation, while its top-level execution metadata remains
+    the payload metadata for the final completion for backward compatibility.
+    """
+    target_assistants = longest_row.get_assistant_messages()
+    assistant_turn_payloads = []
+
+    for row in sorted(rows, key=lambda item: len(item.messages)):
+        source = row.last_assistant_message()
+        source_turn_index = len(row.get_assistant_messages()) - 1
+        if source_turn_index < 0 or source_turn_index >= len(target_assistants):
+            continue
+
+        if source and source.logprobs and not target_assistants[source_turn_index].logprobs:
+            target_assistants[source_turn_index].logprobs = source.logprobs
+
+        extra = row.execution_metadata.extra or {}
+        turn_payload = {
+            key: extra[key]
+            for key in (
+                "completion_logprobs",
+                "completion_token_ids",
+                "logprobs_metadata",
+                "routing_matrices",
+                "routing_metadata",
+            )
+            if key in extra
+        }
+        if turn_payload:
+            turn_payload["assistant_turn_index"] = source_turn_index
+            assistant_turn_payloads.append(turn_payload)
+
+    if assistant_turn_payloads:
+        if longest_row.execution_metadata.extra is None:
+            longest_row.execution_metadata.extra = {}
+        longest_row.execution_metadata.extra["assistant_turn_payloads"] = assistant_turn_payloads
 
 
 def build_fireworks_tracing_url(
diff --git a/tests/pytest/test_tracing_utils.py b/tests/pytest/test_tracing_utils.py
new file mode 100644
index 00000000..58ec55c1
--- /dev/null
+++ b/tests/pytest/test_tracing_utils.py
@@ -0,0 +1,57 @@
+from eval_protocol.models import EvaluationRow, ExecutionMetadata, Message
+from eval_protocol.pytest.tracing_utils import _merge_payloads_into_longest_row
+
+
+def test_merge_payloads_into_longest_row_preserves_each_assistant_turn():
+    first_turn_logprobs = {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}
+    second_turn_logprobs = {"content": [{"logprob": -0.3}]}
+    first_turn = EvaluationRow(
+        messages=[
+            Message(role="user", content="What is 2+2?"),
+            Message(role="assistant", content="4", logprobs=first_turn_logprobs),
+        ],
+        execution_metadata=ExecutionMetadata(
+            extra={
+                "completion_logprobs": [-0.1, -0.2],
+                "routing_matrices": ["first-matrix"],
+                "routing_metadata": {"total_token_count": 1},
+            },
+        ),
+    )
+    second_turn = EvaluationRow(
+        messages=[
+            Message(role="user", content="What is 2+2?"),
+            Message(role="assistant", content="4"),
+            Message(role="user", content="Use that in a sentence."),
+            Message(role="assistant", content="4", logprobs=second_turn_logprobs),
+        ],
+        execution_metadata=ExecutionMetadata(
+            extra={
+                "completion_logprobs": [-0.3],
+                "routing_matrices": ["second-matrix"],
+                "routing_metadata": {"total_token_count": 1},
+            },
+        ),
+    )
+
+    _merge_payloads_into_longest_row(second_turn, [first_turn, second_turn])
+
+    assistant_messages = second_turn.get_assistant_messages()
+    assert assistant_messages[0].logprobs == first_turn_logprobs
+    assert assistant_messages[1].logprobs == second_turn_logprobs
+    assert second_turn.execution_metadata.extra is not None
+    assert second_turn.execution_metadata.extra["routing_matrices"] == ["second-matrix"]
+    assert second_turn.execution_metadata.extra["assistant_turn_payloads"] == [
+        {
+            "assistant_turn_index": 0,
+            "completion_logprobs": [-0.1, -0.2],
+            "routing_matrices": ["first-matrix"],
+            "routing_metadata": {"total_token_count": 1},
+        },
+        {
+            "assistant_turn_index": 1,
+            "completion_logprobs": [-0.3],
+            "routing_matrices": ["second-matrix"],
+            "routing_metadata": {"total_token_count": 1},
+        },
+    ]