diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 45fc2697..62a632e6 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -16,6 +16,8 @@ from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message from .base import BaseAdapter +from .lp_deserializer import decompress_and_parse_lp +from .r3_deserializer import decompress_and_parse_r3 from .utils import extract_messages_from_data from ..common_utils import get_user_agent @@ -106,8 +108,6 @@ def convert_trace_dict_to_evaluation_row( router_replay = payloads.get("router_replay") if isinstance(router_replay, dict) and router_replay.get("data"): try: - from .r3_deserializer import decompress_and_parse_r3 - matrices, r3_meta = decompress_and_parse_r3(router_replay["data"]) if execution_metadata.extra is None: execution_metadata.extra = {} @@ -116,6 +116,32 @@ def convert_trace_dict_to_evaluation_row( except Exception as e: logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e) + logprobs_payload = payloads.get("logprobs") + if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"): + try: + logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"]) + if execution_metadata.extra is None: + execution_metadata.extra = {} + execution_metadata.extra["completion_logprobs"] = logprobs + if token_ids is not None: + execution_metadata.extra["completion_token_ids"] = token_ids + execution_metadata.extra["logprobs_metadata"] = lp_meta + + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == "assistant": + content_entries = [{"logprob": lp} for lp in logprobs] + if token_ids is not None: + for entry, tid in zip(content_entries, token_ids): + entry["token_id"] = tid + messages[i].logprobs = {"content": content_entries} + break + except Exception as e: + logger.warning( + "Failed to decompress logprobs payload 
for trace %s: %s", + trace.get("id"), + e, + ) + return EvaluationRow( messages=messages, tools=tools, diff --git a/eval_protocol/adapters/lp_deserializer.py b/eval_protocol/adapters/lp_deserializer.py new file mode 100644 index 00000000..57aa4f46 --- /dev/null +++ b/eval_protocol/adapters/lp_deserializer.py @@ -0,0 +1,109 @@ +"""LP/v1 binary deserializer for per-token logprobs payloads. + +Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``. +See that module for the full header specification. +""" + +from __future__ import annotations + +import base64 +import struct +from typing import Any, Dict, List, Optional, Tuple + +import zstandard as zstd + +MAGIC = b"LP01" +HEADER_VERSION = 1 +MISSING_TOKEN_ID = -1 +ENTRY_FORMAT = "<if" +ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT) +HEADER_FORMAT = "<4sBBHIIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) + + +def _parse_header(raw: bytes) -> Dict[str, Any]: + if len(raw) < HEADER_SIZE: + raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}") + + ( + magic, + version, + flags, + reserved_u16, + token_count, + body_byte_length, + reserved_u64, + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) + + if magic != MAGIC: + raise ValueError(f"Bad LP/v1 magic: {magic!r}") + if version != HEADER_VERSION: + raise ValueError(f"Unsupported lp/v1 header version: {version}") + + return { + "flags": flags, + "reserved_u16": reserved_u16, + "token_count": token_count, + "body_byte_length": body_byte_length, + "reserved_u64": reserved_u64, + } + + +def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]: + """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.""" + header = _parse_header(raw) + token_count = header["token_count"] + body_byte_length = header["body_byte_length"] + + if token_count == 0: + raise ValueError("LP/v1 token_count must be > 0") + if body_byte_length != token_count * ENTRY_SIZE: + raise ValueError( + f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} " + f"({token_count * ENTRY_SIZE})" + ) + + expected_len = HEADER_SIZE + 
body_byte_length + if len(raw) != expected_len: + raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}") + + logprobs: List[float] = [] + token_ids: List[int] = [] + all_token_ids_valid = True + offset = HEADER_SIZE + for _ in range(token_count): + wire_id, logprob = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE]) + offset += ENTRY_SIZE + logprobs.append(logprob) + if wire_id == MISSING_TOKEN_ID: + all_token_ids_valid = False + token_ids.append(wire_id) + else: + token_ids.append(wire_id) + + metadata: Dict[str, Any] = { + "scope": "completion_only", + "completion_token_count": token_count, + "all_token_ids_valid": all_token_ids_valid, + } + header.update(metadata) + ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None + return logprobs, ids_out, header + + +def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]: + """Decompress and unpack an LP/v1 payload into completion logprobs and token ids. + + Args: + data_b64: Base64-encoded zstd-compressed LP binary blob from + ``payloads.logprobs.data``. + + Returns: + ``(logprobs, token_ids, metadata)`` where ``logprobs`` is per-completion-token + scalars, ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``, + and ``metadata`` includes ``all_token_ids_valid`` and ``completion_token_count``. 
+ """ + compressed = base64.b64decode(data_b64) + decompressor = zstd.ZstdDecompressor() + raw = decompressor.decompress(compressed) + return parse_logprobs(raw) diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 1bbb4824..f2a727db 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -103,7 +103,7 @@ def build_init_request( if not completion_params_dict.get("model"): raise ValueError("Model must be provided in completion_params") - # Extract base_url from completion_params + # Extract base_url from completion_params for tracing-gateway URL encoding completion_params_base_url: Optional[str] = completion_params_dict.get("base_url") # Strip non-OpenAI fields from messages diff --git a/tests/adapters/test_fireworks_tracing_logprobs.py b/tests/adapters/test_fireworks_tracing_logprobs.py new file mode 100644 index 00000000..08dab60b --- /dev/null +++ b/tests/adapters/test_fireworks_tracing_logprobs.py @@ -0,0 +1,93 @@ +"""Tests for logprobs payload handling in fireworks_tracing adapter.""" + +from __future__ import annotations + +import base64 +import struct + +import pytest +import zstandard as zstd + +pytest.importorskip("mcp") + +from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row +from eval_protocol.adapters.lp_deserializer import ( + ENTRY_FORMAT, + ENTRY_SIZE, + HEADER_FORMAT, + MAGIC, + MISSING_TOKEN_ID, +) + + +def _lp_b64(tokens: list[tuple[int, float]]) -> str: + token_count = len(tokens) + body_byte_length = token_count * ENTRY_SIZE + header = struct.pack( + HEADER_FORMAT, + MAGIC, + 1, + 0, + 0, + token_count, + body_byte_length, + 0, + ) + body = b"".join(struct.pack(ENTRY_FORMAT, tid, lp) for tid, lp in tokens) + raw = header + body + compressed = zstd.ZstdCompressor().compress(raw) + return base64.b64encode(compressed).decode("ascii") + + +def _base_trace(*, with_token_ids: bool = True) -> dict: + tokens = [(10, -0.1), (11, 
-0.2)] if with_token_ids else [(MISSING_TOKEN_ID, -0.1), (12, -0.2)] + return { + "id": "trace-1", + "input": { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + }, + "output": {"role": "assistant", "content": "hello"}, + "payloads": { + "logprobs": { + "data": _lp_b64(tokens), + "manifest": {"PayloadVersion": "lp/v1"}, + }, + }, + } + + +class TestConvertTraceLogprobs: + def test_attaches_completion_logprobs_and_message_logprobs(self): + row = convert_trace_dict_to_evaluation_row(_base_trace()) + assert row is not None + + extra = row.execution_metadata.extra + assert extra is not None + assert extra["completion_logprobs"] == pytest.approx([-0.1, -0.2]) + assert extra["completion_token_ids"] == [10, 11] + + assistant = row.messages[-1] + assert assistant.role == "assistant" + content = assistant.logprobs["content"] + assert len(content) == len(extra["completion_logprobs"]) + assert content[0]["token_id"] == 10 + assert content[1]["token_id"] == 11 + assert content[0]["logprob"] == pytest.approx(-0.1) + assert content[1]["logprob"] == pytest.approx(-0.2) + + def test_omits_token_id_keys_when_any_missing(self): + row = convert_trace_dict_to_evaluation_row(_base_trace(with_token_ids=False)) + assert row is not None + + extra = row.execution_metadata.extra + assert "completion_logprobs" in extra + assert "completion_token_ids" not in extra + + content = row.messages[-1].logprobs["content"] + assert len(content) == 2 + assert all("token_id" not in entry for entry in content) + assert content[0]["logprob"] == pytest.approx(-0.1) + assert content[1]["logprob"] == pytest.approx(-0.2) diff --git a/tests/adapters/test_lp_deserializer.py b/tests/adapters/test_lp_deserializer.py new file mode 100644 index 00000000..52e04417 --- /dev/null +++ b/tests/adapters/test_lp_deserializer.py @@ -0,0 +1,78 @@ +"""Tests for LP/v1 binary deserializer (gateway-compatible).""" + +from __future__ import annotations + +import base64 
+import struct + +import pytest +import zstandard as zstd + +from eval_protocol.adapters.lp_deserializer import ( + ENTRY_FORMAT, + ENTRY_SIZE, + HEADER_FORMAT, + HEADER_SIZE, + MAGIC, + MISSING_TOKEN_ID, + decompress_and_parse_lp, + parse_logprobs, +) + +# Golden raw bytes: two tokens (7, -0.25) and (8, -0.5) — must match gateway serializer. +GOLDEN_RAW_HEX = ( + "4c503031010000000200000010000000000000000000000007000000000080be" + "08000000000000bf" +) + + +def _build_raw(tokens: list[tuple[int, float]]) -> bytes: + token_count = len(tokens) + body_byte_length = token_count * ENTRY_SIZE + header = struct.pack( + HEADER_FORMAT, + MAGIC, + 1, + 0, + 0, + token_count, + body_byte_length, + 0, + ) + body = b"".join(struct.pack(ENTRY_FORMAT, tid, lp) for tid, lp in tokens) + return header + body + + +def _compress_b64(raw: bytes) -> str: + return base64.b64encode(zstd.ZstdCompressor().compress(raw)).decode("ascii") + + +class TestParseLogprobs: + def test_golden_bytes_match_gateway(self): + raw = bytes.fromhex(GOLDEN_RAW_HEX) + logprobs, token_ids, meta = parse_logprobs(raw) + assert logprobs == [-0.25, -0.5] + assert token_ids == [7, 8] + assert meta["all_token_ids_valid"] is True + assert meta["token_count"] == 2 + + def test_missing_token_id_omits_token_ids_list(self): + raw = _build_raw([(MISSING_TOKEN_ID, -0.3), (42, -0.4)]) + logprobs, token_ids, meta = parse_logprobs(raw) + assert logprobs == pytest.approx([-0.3, -0.4]) + assert token_ids is None + assert meta["all_token_ids_valid"] is False + + def test_decompress_and_parse_round_trip(self): + raw = bytes.fromhex(GOLDEN_RAW_HEX) + b64 = _compress_b64(raw) + logprobs, token_ids, meta = decompress_and_parse_lp(b64) + assert logprobs == [-0.25, -0.5] + assert token_ids == [7, 8] + assert meta["scope"] == "completion_only" + + def test_rejects_bad_magic(self): + raw = _build_raw([(1, -0.1)]) + bad = b"XXXX" + raw[4:] + with pytest.raises(ValueError, match="Bad LP/v1 magic"): + parse_logprobs(bad) diff --git 
a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index c7655671..452dd13c 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -55,8 +55,12 @@ def _worker(): md = {k: v for k, v in md.items() if v is not None} messages_payload.append(md) - # Spread all completion_params (model, temperature, max_tokens, etc.) - completion_kwargs = {"messages": messages_payload, **req.completion_params} + # Spread completion_params; omit base_url (client uses req.model_base_url; gateway + # encodes inference base_url into the tracing path via build_init_request). + completion_kwargs = { + "messages": messages_payload, + **{k: v for k, v in req.completion_params.items() if k != "base_url"}, + } if req.tools: completion_kwargs["tools"] = req.tools