From 0bd055d77f15339d8306a93779a6eba8f294e21a Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Fri, 29 May 2026 15:32:03 -0700 Subject: [PATCH 1/2] Add per-turn payload merge for remote traces Co-authored-by: Cursor --- eval_protocol/pytest/tracing_utils.py | 50 ++- tests/manual/test_logprobs_e2e.py | 320 ++++++++++++++++++ tests/pytest/test_tracing_utils.py | 57 ++++ .../remote_server_two_turn_logprobs.py | 101 ++++++ 4 files changed, 527 insertions(+), 1 deletion(-) create mode 100644 tests/manual/test_logprobs_e2e.py create mode 100644 tests/pytest/test_tracing_utils.py create mode 100644 tests/remote_server/remote_server_two_turn_logprobs.py diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index f2a727db..279d1055 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -28,7 +28,55 @@ def fetch_traces() -> List[EvaluationRow]: include_payloads=config.include_payloads, ) - return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) + def preprocess_traces(rows: List[EvaluationRow]) -> List[EvaluationRow]: + filtered_rows = filter_longest_conversation(rows) + if config.include_payloads and filtered_rows: + _merge_payloads_into_longest_row(filtered_rows[0], rows) + return filtered_rows + + return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=preprocess_traces) + + +def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[EvaluationRow]) -> None: + """ + Preserve per-turn payload-derived metadata after selecting the longest trace row. + + Each trace row carries payloads for its final assistant turn. The longest row + keeps the full conversation, while its top-level execution metadata remains + the payload metadata for the final completion for backward compatibility. + """ + target_assistants = longest_row.get_assistant_messages() + assistant_turn_payloads = [] + + for row in sorted(rows, key=lambda item: len(item.messages)): + source = row.last_assistant_message() + source_turn_index = len(row.get_assistant_messages()) - 1 + if source_turn_index < 0 or source_turn_index >= len(target_assistants): + continue + + if source and source.logprobs and not target_assistants[source_turn_index].logprobs: + target_assistants[source_turn_index].logprobs = source.logprobs + + extra = row.execution_metadata.extra or {} + turn_payload = { + key: extra[key] + for key in ( + "completion_logprobs", + "completion_token_ids", + "logprobs_metadata", + "routing_matrices", + "routing_metadata", + ) + if key in extra + } + if turn_payload: + turn_payload["assistant_turn_index"] = source_turn_index + assistant_turn_payloads.append(turn_payload) + + if assistant_turn_payloads: + if longest_row.execution_metadata.extra is None: + longest_row.execution_metadata.extra = {} + longest_row.execution_metadata.extra["assistant_turn_payloads"] = assistant_turn_payloads def build_fireworks_tracing_url( diff --git a/tests/manual/test_logprobs_e2e.py b/tests/manual/test_logprobs_e2e.py new file mode 100644 index 00000000..0bd49197 --- /dev/null +++ b/tests/manual/test_logprobs_e2e.py @@ -0,0 +1,320 @@ +"""Minimal e2e test for logprobs trace payloads via RemoteRolloutProcessor. + +Spins up the reference remote server locally, which makes the LLM call +through litellm-gateway-dev. RemoteRolloutProcessor polls the dev gateway +and fetches traces with include_payloads=True. + +Run with: + cd eval-protocol-python-sdk + FIREWORKS_API_KEY="$FIREWORKS_DEV_API_KEY" \\ + pytest tests/manual/test_logprobs_e2e.py -v -s + +Requires gateway+consumer dev deploy with logprobs payload support and deployment: + accounts/pyroworks-dev/deployments/malaysia2-careful-paprika +""" + +import os +import socket +import subprocess +import sys +import time +from typing import List + +import pytest +import requests + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, EvaluateResult, Message, MetricResult +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor + +DEPLOYMENT = "accounts/pyroworks-dev/deployments/malaysia2-careful-paprika" +GATEWAY_DEV_URL = "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app" +FIREWORKS_DEV_INFERENCE_BASE = "https://dev.api.fireworks.ai/inference/v1" + + +def _find_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +SERVER_PORT = _find_available_port() + + +def _wait_for_server(port: int, timeout: int = 30): + start = time.time() + while time.time() - start < timeout: + try: + requests.get(f"http://127.0.0.1:{port}") + return + except requests.exceptions.ConnectionError: + time.sleep(0.5) + raise TimeoutError(f"Remote server did not start within {timeout}s") + + +@pytest.fixture +def remote_server_module(request): + return getattr(request, "param", "tests.remote_server.remote_server") + + +@pytest.fixture(autouse=True) +def _remote_server(remote_server_module): + env = os.environ.copy() + env["FW_TRACING_GATEWAY_BASE_URL"] = GATEWAY_DEV_URL + api_key = os.environ.get("FIREWORKS_API_KEY") or os.environ.get("FIREWORKS_DEV_API_KEY") + if api_key: + env["FIREWORKS_API_KEY"] = api_key + proc = subprocess.Popen( + [ + sys.executable, + "-m", + remote_server_module, + "--host", + "127.0.0.1", + "--port", + str(SERVER_PORT), + ], + env=env, + ) + _wait_for_server(SERVER_PORT) + yield + proc.terminate() + proc.wait() + + +def input_rows() -> List[EvaluationRow]: + return [ + EvaluationRow(messages=[Message(role="user", content="What is 2+2?")]), + ] + + +def two_turn_input_rows() -> List[EvaluationRow]: + return [ + EvaluationRow(messages=[Message(role="user", content="What is 2+2?")]), + ] + + +def _logprobs_content(message: Message) -> list: + if not message.logprobs: + return [] + return message.logprobs.get("content") or [] + + +@pytest.mark.parametrize( + "completion_params", + [ + { + "model": DEPLOYMENT, + "logprobs": True, + "base_url": FIREWORKS_DEV_INFERENCE_BASE, + } + ], +) +@evaluation_test( + data_loaders=DynamicDataLoader(generators=[input_rows]), + rollout_processor=RemoteRolloutProcessor( + remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", + model_base_url=GATEWAY_DEV_URL, + include_payloads=True, + timeout_seconds=180, + ), +) +async def test_logprobs_present(row: EvaluationRow) -> EvaluationRow: + """Verify completion logprobs and Message.logprobs after remote rollout.""" + + has_response = len(row.messages) > 1 + assistant_msg = row.messages[-1] if has_response else None + + extra = row.execution_metadata.extra or {} + completion_logprobs = extra.get("completion_logprobs") or [] + has_completion_logprobs = len(completion_logprobs) > 0 + + message_content = None + if assistant_msg and assistant_msg.logprobs: + message_content = assistant_msg.logprobs.get("content") or [] + + has_message_logprobs = message_content is not None and len(message_content) > 0 + lengths_match = ( + has_completion_logprobs + and has_message_logprobs + and len(message_content) == len(completion_logprobs) + ) + + if has_completion_logprobs: + print( + f"\n Logprobs OK: {len(completion_logprobs)} completion tokens" + f" | message.content len={len(message_content) if message_content else 0}" + ) + else: + print(f"\n No logprobs in extra={extra}") + + score = 1.0 if (has_response and has_completion_logprobs and lengths_match) else 0.0 + reason_parts = [] + if not has_response: + reason_parts.append("no assistant response") + if not has_completion_logprobs: + reason_parts.append("no completion_logprobs in execution_metadata.extra") + if not lengths_match: + reason_parts.append( + f"message.logprobs content length ({len(message_content or [])}) " + f"!= completion_logprobs ({len(completion_logprobs)})" + ) + + reason = "All checks passed" if score == 1.0 else "; ".join(reason_parts) + + row.evaluation_result = EvaluateResult( + score=score, + reason=reason, + metrics={ + "has_response": MetricResult( + score=float(has_response), + is_score_valid=True, + reason="got response" if has_response else "no response", + ), + "has_completion_logprobs": MetricResult( + score=float(has_completion_logprobs), + is_score_valid=True, + reason="present" if has_completion_logprobs else "missing", + ), + "logprobs_lengths_match": MetricResult( + score=float(lengths_match), + is_score_valid=True, + reason="match" if lengths_match else "mismatch", + ), + }, + ) + + assert has_response, f"Expected assistant response. Messages: {row.messages}" + assert has_completion_logprobs, ( + f"Expected completion_logprobs in extra but got: {row.execution_metadata.extra}" + ) + assert lengths_match, ( + "Expected len(message.logprobs['content']) == len(completion_logprobs); " + f"got {len(message_content or [])} vs {len(completion_logprobs)}" + ) + + return row + + +@pytest.mark.parametrize( + "remote_server_module", + ["tests.remote_server.remote_server_two_turn_logprobs"], + indirect=True, +) +@pytest.mark.parametrize( + "completion_params", + [ + { + "model": DEPLOYMENT, + "logprobs": True, + "base_url": FIREWORKS_DEV_INFERENCE_BASE, + } + ], +) +@evaluation_test( + data_loaders=DynamicDataLoader(generators=[two_turn_input_rows]), + rollout_processor=RemoteRolloutProcessor( + remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", + model_base_url=GATEWAY_DEV_URL, + include_payloads=True, + timeout_seconds=180, + ), +) +async def test_two_turn_logprobs_present(row: EvaluationRow) -> EvaluationRow: + """Verify each assistant turn in a two-turn remote rollout has logprobs.""" + + roles = [message.role for message in row.messages] + assistant_messages = row.get_assistant_messages() + logprob_lengths = [len(_logprobs_content(message)) for message in assistant_messages] + + has_two_turn_shape = roles == ["user", "assistant", "user", "assistant"] + has_two_assistant_turns = len(assistant_messages) == 2 + all_turns_have_logprobs = has_two_assistant_turns and all(length > 0 for length in logprob_lengths) + + extra = row.execution_metadata.extra or {} + final_completion_logprobs = extra.get("completion_logprobs") or [] + assistant_turn_payloads = extra.get("assistant_turn_payloads") or [] + final_lengths_match = ( + has_two_assistant_turns + and len(final_completion_logprobs) > 0 + and len(final_completion_logprobs) == logprob_lengths[-1] + ) + has_payloads_for_each_turn = len(assistant_turn_payloads) == len(assistant_messages) + turn_payload_lengths_match = has_payloads_for_each_turn and all( + payload.get("assistant_turn_index") == idx + and len(payload.get("completion_logprobs") or []) == logprob_lengths[idx] + for idx, payload in enumerate(assistant_turn_payloads) + ) + + if all_turns_have_logprobs: + print(f"\n Two-turn logprobs OK: assistant token counts={logprob_lengths}") + else: + print(f"\n Missing two-turn logprobs: roles={roles} token_counts={logprob_lengths}") + + all_ok = ( + has_two_turn_shape + and all_turns_have_logprobs + and final_lengths_match + and turn_payload_lengths_match + ) + reason_parts = [] + if not has_two_turn_shape: + reason_parts.append(f"expected user/assistant/user/assistant roles but got {roles}") + if not has_two_assistant_turns: + reason_parts.append(f"expected 2 assistant turns but got {len(assistant_messages)}") + if has_two_assistant_turns and not all_turns_have_logprobs: + reason_parts.append(f"missing assistant logprobs; token_counts={logprob_lengths}") + if not final_lengths_match: + reason_parts.append( + "final assistant message logprobs length " + f"({logprob_lengths[-1] if logprob_lengths else 0}) " + f"!= completion_logprobs ({len(final_completion_logprobs)})" + ) + if not has_payloads_for_each_turn: + reason_parts.append(f"expected per-turn payloads for each assistant turn but got {assistant_turn_payloads}") + if has_payloads_for_each_turn and not turn_payload_lengths_match: + reason_parts.append(f"per-turn payload lengths do not match message logprobs: {assistant_turn_payloads}") + + row.evaluation_result = EvaluateResult( + score=1.0 if all_ok else 0.0, + reason="All checks passed" if all_ok else "; ".join(reason_parts), + metrics={ + "has_two_turn_shape": MetricResult( + score=float(has_two_turn_shape), + is_score_valid=True, + reason="match" if has_two_turn_shape else "unexpected roles", + ), + "all_turns_have_logprobs": MetricResult( + score=float(all_turns_have_logprobs), + is_score_valid=True, + reason="present" if all_turns_have_logprobs else "missing", + ), + "final_logprobs_lengths_match": MetricResult( + score=float(final_lengths_match), + is_score_valid=True, + reason="match" if final_lengths_match else "mismatch", + ), + "turn_payload_lengths_match": MetricResult( + score=float(turn_payload_lengths_match), + is_score_valid=True, + reason="match" if turn_payload_lengths_match else "mismatch", + ), + }, + ) + + assert has_two_turn_shape, f"Expected two-turn conversation but got roles: {roles}" + assert all_turns_have_logprobs, ( + "Expected logprobs on both assistant turns; " + f"token_counts={logprob_lengths}, messages={row.messages}" + ) + assert final_lengths_match, ( + "Expected final assistant logprobs to match completion_logprobs; " + f"got {logprob_lengths[-1] if logprob_lengths else 0} vs {len(final_completion_logprobs)}" + ) + assert turn_payload_lengths_match, ( + "Expected assistant_turn_payloads to match each assistant turn's logprobs; " + f"payloads={assistant_turn_payloads}, token_counts={logprob_lengths}" + ) + + return row diff --git a/tests/pytest/test_tracing_utils.py b/tests/pytest/test_tracing_utils.py new file mode 100644 index 00000000..58ec55c1 --- /dev/null +++ b/tests/pytest/test_tracing_utils.py @@ -0,0 +1,57 @@ +from eval_protocol.models import EvaluationRow, ExecutionMetadata, Message +from eval_protocol.pytest.tracing_utils import _merge_payloads_into_longest_row + + +def test_merge_payloads_into_longest_row_preserves_each_assistant_turn(): + first_turn_logprobs = {"content": [{"logprob": -0.1}, {"logprob": -0.2}]} + second_turn_logprobs = {"content": [{"logprob": -0.3}]} + first_turn = EvaluationRow( + messages=[ + Message(role="user", content="What is 2+2?"), + Message(role="assistant", content="4", logprobs=first_turn_logprobs), + ], + execution_metadata=ExecutionMetadata( + extra={ + "completion_logprobs": [-0.1, -0.2], + "routing_matrices": ["first-matrix"], + "routing_metadata": {"total_token_count": 1}, + }, + ), + ) + second_turn = EvaluationRow( + messages=[ + Message(role="user", content="What is 2+2?"), + Message(role="assistant", content="4"), + Message(role="user", content="Use that in a sentence."), + Message(role="assistant", content="4", logprobs=second_turn_logprobs), + ], + execution_metadata=ExecutionMetadata( + extra={ + "completion_logprobs": [-0.3], + "routing_matrices": ["second-matrix"], + "routing_metadata": {"total_token_count": 1}, + }, + ), + ) + + _merge_payloads_into_longest_row(second_turn, [first_turn, second_turn]) + + assistant_messages = second_turn.get_assistant_messages() + assert assistant_messages[0].logprobs == first_turn_logprobs + assert assistant_messages[1].logprobs == second_turn_logprobs + assert second_turn.execution_metadata.extra is not None + assert second_turn.execution_metadata.extra["routing_matrices"] == ["second-matrix"] + assert second_turn.execution_metadata.extra["assistant_turn_payloads"] == [ + { + "assistant_turn_index": 0, + "completion_logprobs": [-0.1, -0.2], + "routing_matrices": ["first-matrix"], + "routing_metadata": {"total_token_count": 1}, + }, + { + "assistant_turn_index": 1, + "completion_logprobs": [-0.3], + "routing_matrices": ["second-matrix"], + "routing_metadata": {"total_token_count": 1}, + }, + ] diff --git a/tests/remote_server/remote_server_two_turn_logprobs.py b/tests/remote_server/remote_server_two_turn_logprobs.py new file mode 100644 index 00000000..56b4541b --- /dev/null +++ b/tests/remote_server/remote_server_two_turn_logprobs.py @@ -0,0 +1,101 @@ +import argparse +import logging +import os +import threading + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI + +from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status + + +app = FastAPI() + +logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s") + +fireworks_handler = FireworksTracingHttpHandler() +logging.getLogger().addHandler(fireworks_handler) + + +def _clean_messages(messages): + clean_messages = [] + for message in messages: + if hasattr(message, "dump_mdoel_for_chat_completion_request"): + message_dict = message.dump_mdoel_for_chat_completion_request() + elif hasattr(message, "model_dump"): + message_dict = message.model_dump(exclude_none=True) + elif isinstance(message, dict): + message_dict = {key: value for key, value in message.items() if value is not None} + else: + message_dict = { + "role": getattr(message, "role", None), + "content": getattr(message, "content", None), + } + message_dict = {key: value for key, value in message_dict.items() if value is not None} + clean_messages.append(message_dict) + return clean_messages + + +@app.post("/init") +def init(req: InitRequest): + logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + + def _worker(): + try: + if not req.messages: + raise ValueError("messages is required") + + model = req.completion_params.get("model") + if not model: + raise ValueError("model is required in completion_params") + + completion_params = {key: value for key, value in req.completion_params.items() if key != "base_url"} + client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY")) + + conversation_history = _clean_messages(req.messages) + logger.info("Turn 1: sending completion request to model %s", model) + completion = client.chat.completions.create( + messages=conversation_history, + **completion_params, + ) + assistant_content = completion.choices[0].message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info("Turn 1 response: %s", assistant_content[:100]) + + follow_up = "Use that answer in one short sentence." + conversation_history.append({"role": "user", "content": follow_up}) + logger.info("Turn 2: user asks: %s", follow_up) + completion = client.chat.completions.create( + messages=conversation_history, + **completion_params, + ) + assistant_content = completion.choices[0].message.content or "" + logger.info("Turn 2 response: %s", assistant_content[:100]) + + except Exception as e: + logger.error("Error in rollout %s: %s", req.metadata.rollout_id, e) + pass + finally: + logger.info( + "Rollout %s completed", + req.metadata.rollout_id, + extra={"status": Status.rollout_finished()}, + ) + + thread = threading.Thread(target=_worker, daemon=True) + thread.start() + + +def main(): + parser = argparse.ArgumentParser(description="Run the two-turn logprobs remote server") + parser.add_argument("--host", default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")) + parser.add_argument("--port", type=int, default=int(os.getenv("REMOTE_SERVER_PORT", "3000"))) + args = parser.parse_args() + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() From c3690caad5a076eb00d9f168243cf18fa7381e81 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Fri, 29 May 2026 15:33:16 -0700 Subject: [PATCH 2/2] Remove manual logprobs e2e additions Co-authored-by: Cursor --- tests/manual/test_logprobs_e2e.py | 320 ------------------ .../remote_server_two_turn_logprobs.py | 101 ------ 2 files changed, 421 deletions(-) delete mode 100644 tests/manual/test_logprobs_e2e.py delete mode 100644 tests/remote_server/remote_server_two_turn_logprobs.py diff --git a/tests/manual/test_logprobs_e2e.py b/tests/manual/test_logprobs_e2e.py deleted file mode 100644 index 0bd49197..00000000 --- a/tests/manual/test_logprobs_e2e.py +++ /dev/null @@ -1,320 +0,0 @@ -"""Minimal e2e test for logprobs trace payloads via RemoteRolloutProcessor. - -Spins up the reference remote server locally, which makes the LLM call -through litellm-gateway-dev. RemoteRolloutProcessor polls the dev gateway -and fetches traces with include_payloads=True. - -Run with: - cd eval-protocol-python-sdk - FIREWORKS_API_KEY="$FIREWORKS_DEV_API_KEY" \\ - pytest tests/manual/test_logprobs_e2e.py -v -s - -Requires gateway+consumer dev deploy with logprobs payload support and deployment: - accounts/pyroworks-dev/deployments/malaysia2-careful-paprika -""" - -import os -import socket -import subprocess -import sys -import time -from typing import List - -import pytest -import requests - -from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.models import EvaluationRow, EvaluateResult, Message, MetricResult -from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor - -DEPLOYMENT = "accounts/pyroworks-dev/deployments/malaysia2-careful-paprika" -GATEWAY_DEV_URL = "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app" -FIREWORKS_DEV_INFERENCE_BASE = "https://dev.api.fireworks.ai/inference/v1" - - -def _find_available_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -SERVER_PORT = _find_available_port() - - -def _wait_for_server(port: int, timeout: int = 30): - start = time.time() - while time.time() - start < timeout: - try: - requests.get(f"http://127.0.0.1:{port}") - return - except requests.exceptions.ConnectionError: - time.sleep(0.5) - raise TimeoutError(f"Remote server did not start within {timeout}s") - - -@pytest.fixture -def remote_server_module(request): - return getattr(request, "param", "tests.remote_server.remote_server") - - -@pytest.fixture(autouse=True) -def _remote_server(remote_server_module): - env = os.environ.copy() - env["FW_TRACING_GATEWAY_BASE_URL"] = GATEWAY_DEV_URL - api_key = os.environ.get("FIREWORKS_API_KEY") or os.environ.get("FIREWORKS_DEV_API_KEY") - if api_key: - env["FIREWORKS_API_KEY"] = api_key - proc = subprocess.Popen( - [ - sys.executable, - "-m", - remote_server_module, - "--host", - "127.0.0.1", - "--port", - str(SERVER_PORT), - ], - env=env, - ) - _wait_for_server(SERVER_PORT) - yield - proc.terminate() - proc.wait() - - -def input_rows() -> List[EvaluationRow]: - return [ - EvaluationRow(messages=[Message(role="user", content="What is 2+2?")]), - ] - - -def two_turn_input_rows() -> List[EvaluationRow]: - return [ - EvaluationRow(messages=[Message(role="user", content="What is 2+2?")]), - ] - - -def _logprobs_content(message: Message) -> list: - if not message.logprobs: - return [] - return message.logprobs.get("content") or [] - - -@pytest.mark.parametrize( - "completion_params", - [ - { - "model": DEPLOYMENT, - "logprobs": True, - "base_url": FIREWORKS_DEV_INFERENCE_BASE, - } - ], -) -@evaluation_test( - data_loaders=DynamicDataLoader(generators=[input_rows]), - rollout_processor=RemoteRolloutProcessor( - remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", - model_base_url=GATEWAY_DEV_URL, - include_payloads=True, - timeout_seconds=180, - ), -) -async def test_logprobs_present(row: EvaluationRow) -> EvaluationRow: - """Verify completion logprobs and Message.logprobs after remote rollout.""" - - has_response = len(row.messages) > 1 - assistant_msg = row.messages[-1] if has_response else None - - extra = row.execution_metadata.extra or {} - completion_logprobs = extra.get("completion_logprobs") or [] - has_completion_logprobs = len(completion_logprobs) > 0 - - message_content = None - if assistant_msg and assistant_msg.logprobs: - message_content = assistant_msg.logprobs.get("content") or [] - - has_message_logprobs = message_content is not None and len(message_content) > 0 - lengths_match = ( - has_completion_logprobs - and has_message_logprobs - and len(message_content) == len(completion_logprobs) - ) - - if has_completion_logprobs: - print( - f"\n Logprobs OK: {len(completion_logprobs)} completion tokens" - f" | message.content len={len(message_content) if message_content else 0}" - ) - else: - print(f"\n No logprobs in extra={extra}") - - score = 1.0 if (has_response and has_completion_logprobs and lengths_match) else 0.0 - reason_parts = [] - if not has_response: - reason_parts.append("no assistant response") - if not has_completion_logprobs: - reason_parts.append("no completion_logprobs in execution_metadata.extra") - if not lengths_match: - reason_parts.append( - f"message.logprobs content length ({len(message_content or [])}) " - f"!= completion_logprobs ({len(completion_logprobs)})" - ) - - reason = "All checks passed" if score == 1.0 else "; ".join(reason_parts) - - row.evaluation_result = EvaluateResult( - score=score, - reason=reason, - metrics={ - "has_response": MetricResult( - score=float(has_response), - is_score_valid=True, - reason="got response" if has_response else "no response", - ), - "has_completion_logprobs": MetricResult( - score=float(has_completion_logprobs), - is_score_valid=True, - reason="present" if has_completion_logprobs else "missing", - ), - "logprobs_lengths_match": MetricResult( - score=float(lengths_match), - is_score_valid=True, - reason="match" if lengths_match else "mismatch", - ), - }, - ) - - assert has_response, f"Expected assistant response. Messages: {row.messages}" - assert has_completion_logprobs, ( - f"Expected completion_logprobs in extra but got: {row.execution_metadata.extra}" - ) - assert lengths_match, ( - "Expected len(message.logprobs['content']) == len(completion_logprobs); " - f"got {len(message_content or [])} vs {len(completion_logprobs)}" - ) - - return row - - -@pytest.mark.parametrize( - "remote_server_module", - ["tests.remote_server.remote_server_two_turn_logprobs"], - indirect=True, -) -@pytest.mark.parametrize( - "completion_params", - [ - { - "model": DEPLOYMENT, - "logprobs": True, - "base_url": FIREWORKS_DEV_INFERENCE_BASE, - } - ], -) -@evaluation_test( - data_loaders=DynamicDataLoader(generators=[two_turn_input_rows]), - rollout_processor=RemoteRolloutProcessor( - remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", - model_base_url=GATEWAY_DEV_URL, - include_payloads=True, - timeout_seconds=180, - ), -) -async def test_two_turn_logprobs_present(row: EvaluationRow) -> EvaluationRow: - """Verify each assistant turn in a two-turn remote rollout has logprobs.""" - - roles = [message.role for message in row.messages] - assistant_messages = row.get_assistant_messages() - logprob_lengths = [len(_logprobs_content(message)) for message in assistant_messages] - - has_two_turn_shape = roles == ["user", "assistant", "user", "assistant"] - has_two_assistant_turns = len(assistant_messages) == 2 - all_turns_have_logprobs = has_two_assistant_turns and all(length > 0 for length in logprob_lengths) - - extra = row.execution_metadata.extra or {} - final_completion_logprobs = extra.get("completion_logprobs") or [] - assistant_turn_payloads = extra.get("assistant_turn_payloads") or [] - final_lengths_match = ( - has_two_assistant_turns - and len(final_completion_logprobs) > 0 - and len(final_completion_logprobs) == logprob_lengths[-1] - ) - has_payloads_for_each_turn = len(assistant_turn_payloads) == len(assistant_messages) - turn_payload_lengths_match = has_payloads_for_each_turn and all( - payload.get("assistant_turn_index") == idx - and len(payload.get("completion_logprobs") or []) == logprob_lengths[idx] - for idx, payload in enumerate(assistant_turn_payloads) - ) - - if all_turns_have_logprobs: - print(f"\n Two-turn logprobs OK: assistant token counts={logprob_lengths}") - else: - print(f"\n Missing two-turn logprobs: roles={roles} token_counts={logprob_lengths}") - - all_ok = ( - has_two_turn_shape - and all_turns_have_logprobs - and final_lengths_match - and turn_payload_lengths_match - ) - reason_parts = [] - if not has_two_turn_shape: - reason_parts.append(f"expected user/assistant/user/assistant roles but got {roles}") - if not has_two_assistant_turns: - reason_parts.append(f"expected 2 assistant turns but got {len(assistant_messages)}") - if has_two_assistant_turns and not all_turns_have_logprobs: - reason_parts.append(f"missing assistant logprobs; token_counts={logprob_lengths}") - if not final_lengths_match: - reason_parts.append( - "final assistant message logprobs length " - f"({logprob_lengths[-1] if logprob_lengths else 0}) " - f"!= completion_logprobs ({len(final_completion_logprobs)})" - ) - if not has_payloads_for_each_turn: - reason_parts.append(f"expected per-turn payloads for each assistant turn but got {assistant_turn_payloads}") - if has_payloads_for_each_turn and not turn_payload_lengths_match: - reason_parts.append(f"per-turn payload lengths do not match message logprobs: {assistant_turn_payloads}") - - row.evaluation_result = EvaluateResult( - score=1.0 if all_ok else 0.0, - reason="All checks passed" if all_ok else "; ".join(reason_parts), - metrics={ - "has_two_turn_shape": MetricResult( - score=float(has_two_turn_shape), - is_score_valid=True, - reason="match" if has_two_turn_shape else "unexpected roles", - ), - "all_turns_have_logprobs": MetricResult( - score=float(all_turns_have_logprobs), - is_score_valid=True, - reason="present" if all_turns_have_logprobs else "missing", - ), - "final_logprobs_lengths_match": MetricResult( - score=float(final_lengths_match), - is_score_valid=True, - reason="match" if final_lengths_match else "mismatch", - ), - "turn_payload_lengths_match": MetricResult( - score=float(turn_payload_lengths_match), - is_score_valid=True, - reason="match" if turn_payload_lengths_match else "mismatch", - ), - }, - ) - - assert has_two_turn_shape, f"Expected two-turn conversation but got roles: {roles}" - assert all_turns_have_logprobs, ( - "Expected logprobs on both assistant turns; " - f"token_counts={logprob_lengths}, messages={row.messages}" - ) - assert final_lengths_match, ( - "Expected final assistant logprobs to match completion_logprobs; " - f"got {logprob_lengths[-1] if logprob_lengths else 0} vs {len(final_completion_logprobs)}" - ) - assert turn_payload_lengths_match, ( - "Expected assistant_turn_payloads to match each assistant turn's logprobs; " - f"payloads={assistant_turn_payloads}, token_counts={logprob_lengths}" - ) - - return row diff --git a/tests/remote_server/remote_server_two_turn_logprobs.py b/tests/remote_server/remote_server_two_turn_logprobs.py deleted file mode 100644 index 56b4541b..00000000 --- a/tests/remote_server/remote_server_two_turn_logprobs.py +++ /dev/null @@ -1,101 +0,0 @@ -import argparse -import logging -import os -import threading - -import uvicorn -from fastapi import FastAPI -from openai import OpenAI - -from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status - - -app = FastAPI() - -logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s") - -fireworks_handler = FireworksTracingHttpHandler() -logging.getLogger().addHandler(fireworks_handler) - - -def _clean_messages(messages): - clean_messages = [] - for message in messages: - if hasattr(message, "dump_mdoel_for_chat_completion_request"): - message_dict = message.dump_mdoel_for_chat_completion_request() - elif hasattr(message, "model_dump"): - message_dict = message.model_dump(exclude_none=True) - elif isinstance(message, dict): - message_dict = {key: value for key, value in message.items() if value is not None} - else: - message_dict = { - "role": getattr(message, "role", None), - "content": getattr(message, "content", None), - } - message_dict = {key: value for key, value in message_dict.items() if value is not None} - clean_messages.append(message_dict) - return clean_messages - - -@app.post("/init") -def init(req: InitRequest): - logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") - logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) - - def _worker(): - try: - if not req.messages: - raise ValueError("messages is required") - - model = req.completion_params.get("model") - if not model: - raise ValueError("model is required in completion_params") - - completion_params = {key: value for key, value in req.completion_params.items() if key != "base_url"} - client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY")) - - conversation_history = _clean_messages(req.messages) - logger.info("Turn 1: sending completion request to model %s", model) - completion = client.chat.completions.create( - messages=conversation_history, - **completion_params, - ) - assistant_content = completion.choices[0].message.content or "" - conversation_history.append({"role": "assistant", "content": assistant_content}) - logger.info("Turn 1 response: %s", assistant_content[:100]) - - follow_up = "Use that answer in one short sentence." - conversation_history.append({"role": "user", "content": follow_up}) - logger.info("Turn 2: user asks: %s", follow_up) - completion = client.chat.completions.create( - messages=conversation_history, - **completion_params, - ) - assistant_content = completion.choices[0].message.content or "" - logger.info("Turn 2 response: %s", assistant_content[:100]) - - except Exception as e: - logger.error("Error in rollout %s: %s", req.metadata.rollout_id, e) - pass - finally: - logger.info( - "Rollout %s completed", - req.metadata.rollout_id, - extra={"status": Status.rollout_finished()}, - ) - - thread = threading.Thread(target=_worker, daemon=True) - thread.start() - - -def main(): - parser = argparse.ArgumentParser(description="Run the two-turn logprobs remote server") - parser.add_argument("--host", default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")) - parser.add_argument("--port", type=int, default=int(os.getenv("REMOTE_SERVER_PORT", "3000"))) - args = parser.parse_args() - - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main()