diff --git a/multistage_refs.json b/multistage_refs.json new file mode 100644 index 0000000000..e6e383192c --- /dev/null +++ b/multistage_refs.json @@ -0,0 +1,7 @@ +{ + "glm51": 1259, + "minimax_m27": 1165, + "nemotron3_ultra_ga": 1168, + "kimi_k25": 1000, + "qwen35_397b": 956 +} diff --git a/resources_servers/gdpval/app.py b/resources_servers/gdpval/app.py index 18463586ab..6da29d77b2 100644 --- a/resources_servers/gdpval/app.py +++ b/resources_servers/gdpval/app.py @@ -184,6 +184,12 @@ class GDPValVerifyRequest(BaseVerifyRequest): rubric_pretty: Optional[str] = None reference_file_urls: Optional[List[str]] = None deliverables_dir: Optional[str] = None + # Optional per-request filter (comparison mode): judge the eval deliverable + # only against this subset of the configured ``reference_models``. Unknown + # ids are ignored; ``None`` (default) judges against every configured + # reference. Used by the multi-stage ELO driver to select a different set of + # reference models per judgement stage without reconfiguring the server. + reference_ids: Optional[List[str]] = None class GDPValVerifyResponse(GDPValVerifyRequest, BaseVerifyResponse): @@ -369,11 +375,18 @@ async def _verify_comparison(self, body: GDPValVerifyRequest) -> GDPValVerifyRes eval_task_dir = Path(body.deliverables_dir) if body.deliverables_dir else None + # Optional per-request reference subset (multi-stage ELO). When set, only + # the named references are judged this call; unknown ids are ignored. + active_references = self._references + if body.reference_ids is not None: + requested = set(body.reference_ids) + active_references = {rid: cfg for rid, cfg in self._references.items() if rid in requested} + # Resolve, per reference model, the available (attempted) repeat dirs # for this task. A reference that has no deliverable for this task is # simply skipped — the eval model just isn't judged against it here.
ref_dirs_by_id: Dict[str, List[Path]] = {} - for ref_id, ref_cfg in self._references.items(): + for ref_id, ref_cfg in active_references.items(): ref_task_root = Path(ref_cfg.deliverables_dir) / f"task_{body.task_id}" dirs = [d for d in _iter_ref_repeat_dirs(ref_task_root) if task_attempted(str(d))] if dirs: diff --git a/resources_servers/gdpval/multistage_elo.py b/resources_servers/gdpval/multistage_elo.py new file mode 100644 index 0000000000..ceaf28b480 --- /dev/null +++ b/resources_servers/gdpval/multistage_elo.py @@ -0,0 +1,333 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-stage adaptive ELO estimation for GDPVal pairwise comparison. + +Instead of comparing the evaluated model against every reference model on all +tasks, this runs a sequence of *stages*. Each stage: + +1. fixes a set of ``T`` tasks sampled from a task-distribution JSON file (see + ``responses_api_agents.stirrup_agent.task_distribution``), +2. judges the evaluated model against a set of ``M`` reference models on those + tasks (delegated to an injected ``judge_stage`` callable), +3. fits an anchored Bradley-Terry MLE ELO from that stage's win/loss/tie + battles (reusing ``comparison.calculate_mle_elo``), and +4. uses that estimate to choose the ``M`` references for the next stage. 
+ +Across stages, ``M`` typically shrinks (zooming in on references whose known +ELO is closest to the evaluated model's current estimate) while ``T`` grows +(spending the saved judge budget on a tighter final estimate). + +This module is intentionally **pure / server-agnostic**: the actual judging +(running rollouts, calling ``/verify``, reading cached deliverables) is supplied +by the caller as a ``judge_stage`` callable, so the staging/selection/ELO logic +is unit-testable without any servers. The orchestration that wires this to the +GDPVal servers lives in the driver (see the module docstring there). +""" + +from __future__ import annotations + +import random +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Mapping, Optional, Sequence + +from resources_servers.gdpval.comparison import calculate_mle_elo + + +# A mapping ``ref_id -> {"wins": int, "losses": int, "ties": int, +# "reference_elo": float}`` as produced (per task, then pooled) by the GDPVal +# comparison verifier. This is the unit the ELO MLE is fit over. +PerReferenceTotals = Dict[str, Dict[str, float]] + +# Signature of the injected judging step. Given the stage's fixed task ids and +# the selected reference ids, return pooled per-reference win/loss/tie totals +# for the evaluated model across those tasks. +JudgeStageFn = Callable[[Sequence[str], Sequence[str]], PerReferenceTotals] + + +@dataclass +class StageSpec: + """Configuration for a single stage. + + ``num_tasks`` is ``T`` (the number of tasks judged this stage). ``num_models`` + is ``M`` (the number of reference models compared against); ``None`` means + "all available references" (used for the first, broad stage). ``seed`` makes + task sampling for this stage reproducible. 
+ """ + + num_tasks: int + num_models: Optional[int] = None + seed: Optional[int] = None + + +@dataclass +class StageResult: + """Outcome of one stage.""" + + stage_index: int + task_ids: List[str] + reference_ids: List[str] + per_reference: PerReferenceTotals + eval_elo: Optional[float] + normalized_elo: Optional[float] + # Number of reference models included in this stage's ELO fit. + num_references: int + + +@dataclass +class MultiStageEloConfig: + """End-to-end configuration for a multi-stage ELO run.""" + + stages: List[StageSpec] + # ref_id -> known/anchor ELO. Both the MLE (anchors) and reference selection + # ("closest to the eval estimate") require these. + reference_elos: Dict[str, float] + + # Task distribution source. When ``distribution_path`` is unset (or missing), + # the driver builds a distribution from ``dataset_path`` (or the default + # GDPVal dataset) grouped by ``column`` and caches it. See + # ``multistage_elo_driver.ensure_distribution``. + distribution_path: Optional[str] = None + dataset_path: Optional[str] = None + + # Eval deliverables source. When set, pre-existing cached deliverables under + # this directory (``task_<task_id>/repeat_<n>/``) are reused instead of producing + # fresh rollouts. ``produce_missing`` controls whether tasks absent from the + # cache are produced on demand (True) or dropped from the stage (False). + eval_deliverables_dir: Optional[str] = None + produce_missing: bool = True + + # Sampling behaviour across stages. ``nested_tasks=True`` makes each stage's task set + # a superset of the previous stage's, which is cheaper (reuses produced + # deliverables and judgments) but couples the stages' samples. The default + # (False) samples each stage independently: later stages draw fresh tasks, so + # the stages contribute more independent information to the ELO estimate.
+ nested_tasks: bool = False + + selection: str = "closest" + column: List[str] = field(default_factory=lambda: ["occupation"]) + + def __post_init__(self) -> None: + if not self.stages: + raise ValueError("At least one stage is required.") + if self.selection != "closest": + raise ValueError(f"Unknown selection strategy: {self.selection!r}") + + +# --------------------------------------------------------------------------- +# Reference selection +# --------------------------------------------------------------------------- + + +def select_references( + reference_elos: Mapping[str, float], + eval_elo: Optional[float], + num_models: Optional[int], +) -> List[str]: + """Choose reference ids for a stage. + + Returns all references (sorted by id) when ``num_models`` is ``None`` or the + estimate is not yet available (the first, broad stage). Otherwise returns the + ``num_models`` references whose anchor ELO is closest to ``eval_elo``, ties + broken by ``ref_id`` for determinism. + """ + all_ids = sorted(reference_elos) + if num_models is None or eval_elo is None or num_models >= len(all_ids): + return all_ids + if num_models <= 0: + return [] + ranked = sorted(all_ids, key=lambda rid: (abs(reference_elos[rid] - eval_elo), rid)) + chosen = ranked[:num_models] + # Return in stable id order rather than distance order for readable output. + return sorted(chosen) + + +# --------------------------------------------------------------------------- +# Task planning +# --------------------------------------------------------------------------- + + +def plan_stage_task_ids( + distribution: Mapping[str, Mapping[str, object]], + stages: Sequence[StageSpec], + *, + rng: Optional[random.Random] = None, + nested: bool = True, +) -> List[List[str]]: + """Pre-sample the task set for every stage from a task distribution. + + Task selection is independent of any ELO estimate, so all stages' task sets + can be planned up front. 
+ + ``nested=True`` makes each stage's set a superset of the previous one. We get + this for free in a single draw: ``sample_task_ids`` samples without + replacement one task at a time, so a prefix of a large draw is identical to a + smaller draw made with the same RNG. We therefore draw once, sized to the + largest stage, and slice each stage's prefix from it — O(max T) work and + exactly proportional per stage, with nesting guaranteed. A single shared RNG + is used (per-stage ``seed`` only applies to independent sampling). + + ``nested=False`` samples each stage independently, honoring its own ``seed``. + """ + from responses_api_agents.stirrup_agent.task_distribution import sample_task_ids + + base_rng = rng or random.Random() + + if not nested: + return [ + sample_task_ids( + distribution, + s.num_tasks, + rng=random.Random(s.seed) if s.seed is not None else base_rng, + ) + for s in stages + ] + + max_target = max(s.num_tasks for s in stages) + ordered = sample_task_ids(distribution, max_target, rng=base_rng) + return [list(ordered[: s.num_tasks]) for s in stages] + + +# --------------------------------------------------------------------------- +# ELO fitting +# --------------------------------------------------------------------------- + + +def fit_stage_elo( + per_reference: Mapping[str, Mapping[str, float]], + reference_elos: Mapping[str, float], +) -> tuple[Optional[float], Optional[float], int]: + """Fit the eval model's ELO for a stage from per-reference battle totals. + + A reference is included in the fit only if it has a known anchor ELO (from + ``reference_elos`` or a ``reference_elo`` recorded on its counts) and at + least one judged game (win + loss + tie > 0). + + Returns ``(elo, normalized_elo, num_references)``: + - ``num_references`` is how many references met both criteria above and were + passed to the MLE. 
+ - ``elo`` / ``normalized_elo`` are ``None`` when no reference qualified + (``num_references == 0``) or when the MLE itself could not produce a rating; + in the latter case ``num_references`` is still > 0. + """ + battles: List[tuple[float, float, float, float]] = [] + for ref_id, counts in per_reference.items(): + ref_elo = reference_elos.get(ref_id, counts.get("reference_elo")) + if ref_elo is None: + continue + wins = float(counts.get("wins", 0) or 0) + losses = float(counts.get("losses", 0) or 0) + ties = float(counts.get("ties", 0) or 0) + if wins + losses + ties <= 0: + continue + battles.append((float(ref_elo), wins, losses, ties)) + + if not battles: + return None, None, 0 + + mle = calculate_mle_elo(battles) + if mle is None: + return None, None, len(battles) + elo, normalized = mle + return elo, normalized, len(battles) + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + + +class MultiStageEloRunner: + """Drive the multi-stage ELO procedure. + + ``run`` first plans every stage's task set up front (task selection does not + depend on any ELO estimate), then walks the stages sequentially: for each + stage it selects the references (closest known ELO to the running estimate), + judges the stage, fits the stage ELO, and threads that estimate into the next + stage's reference selection. Matchup judging is not the runner's concern; it + is supplied as ``judge_stage(task_ids, reference_ids) -> per_reference_totals``. + + ``run`` returns one ``StageResult`` per stage; the last stage's ``eval_elo`` + is the headline estimate. 
+ """ + + def __init__( + self, + config: MultiStageEloConfig, + distribution: Mapping[str, Mapping[str, object]], + judge_stage: JudgeStageFn, + *, + rng: Optional[random.Random] = None, + on_event: Optional[Callable[[str, dict], None]] = None, + ) -> None: + self.config = config + self.distribution = distribution + self.judge_stage = judge_stage + self.rng = rng or random.Random() + # Optional progress hook. Called as ``on_event(name, data)`` for the + # events "planned", "stage_start", and "stage_end". Kept as a callback so + # this module performs no I/O itself; the driver/CLI does the printing. + self.on_event = on_event + + def _emit(self, name: str, **data: object) -> None: + if self.on_event is not None: + self.on_event(name, data) + + def run(self) -> List[StageResult]: + stage_task_sets = plan_stage_task_ids( + self.distribution, + self.config.stages, + rng=self.rng, + nested=self.config.nested_tasks, + ) + total_stages = len(self.config.stages) + self._emit("planned", stage_task_counts=[len(s) for s in stage_task_sets], total_stages=total_stages) + + results: List[StageResult] = [] + eval_elo: Optional[float] = None + for index, stage in enumerate(self.config.stages): + reference_ids = select_references(self.config.reference_elos, eval_elo, stage.num_models) + task_ids = stage_task_sets[index] + self._emit( + "stage_start", + index=index, + total_stages=total_stages, + reference_ids=list(reference_ids), + num_tasks=len(task_ids), + prior_elo=eval_elo, + ) + per_reference = self.judge_stage(task_ids, reference_ids) + stage_elo, normalized, num_references = fit_stage_elo(per_reference, self.config.reference_elos) + if stage_elo is not None: + eval_elo = stage_elo + self._emit( + "stage_end", + index=index, + total_stages=total_stages, + eval_elo=stage_elo, + normalized_elo=normalized, + num_references=num_references, + per_reference=dict(per_reference), + ) + results.append( + StageResult( + stage_index=index, + task_ids=list(task_ids), + 
+ reference_ids=list(reference_ids), + per_reference=dict(per_reference), + eval_elo=stage_elo, + normalized_elo=normalized, + num_references=num_references, + ) + ) + return results diff --git a/resources_servers/gdpval/multistage_elo_driver.py b/resources_servers/gdpval/multistage_elo_driver.py new file mode 100644 index 0000000000..6d39c20fb6 --- /dev/null +++ b/resources_servers/gdpval/multistage_elo_driver.py @@ -0,0 +1,686 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Driver that wires the multi-stage ELO logic to the GDPVal comparison server. + +This composes the pure staging logic in ``multistage_elo`` with the GDPVal +resources server's ``/verify`` (comparison mode). For each stage it: + +1. asks the runner to select the stage's references (closest known ELO to the + current estimate) and fix the stage's sampled tasks, +2. judges the evaluated model's cached deliverables against that reference + subset, one ``/verify`` call per (task, repeat) with the per-request + ``reference_ids`` filter, +3. pools the per-reference win/loss/tie votes and fits the stage ELO. + +The evaluated model's deliverables are read from a directory laid out as +``<eval_deliverables_dir>/task_<task_id>/repeat_<n>/`` (the same layout the Stirrup +agent persists). Point ``eval_deliverables_dir`` at deliverables produced by an +earlier run to score them with **zero rollouts**.
Tasks missing from the cache +are either produced on demand via an injected ``producer`` callback or reported, +controlled by ``produce_missing``. + +The judging primitive ``verify_one`` is injected so the orchestration is +testable without a running server; ``make_http_verify_one`` provides the real +implementation that POSTs to the resources server. + +CLI usage (run from the repo root, against a running comparison server):: + + python -m resources_servers.gdpval.multistage_elo_driver \\ + --server-url http://localhost:8000 \\ + --eval-deliverables-dir /path/to/eval/deliverables \\ + --reference-elos '@refs.json' \\ + --stage 5 --stage 88:4 \\ + --output elo_summary.json + +where ``refs.json`` is ``{"<ref_id>": <elo>, ...}`` with ids matching the +server's configured ``reference_models``. Each stage's task count and +reference-model count are given as ``--stage num_tasks[:num_models]``. +See ``--help`` for all flags. +""" + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence + +from resources_servers.gdpval.comparison import task_attempted +from resources_servers.gdpval.multistage_elo import ( + MultiStageEloConfig, + MultiStageEloRunner, + PerReferenceTotals, + StageResult, + StageSpec, +) + + +# verify_one(task_id, deliverables_dir, prompt, reference_ids) -> verify response dict +VerifyOneFn = Callable[[str, str, str, Sequence[str]], Dict[str, Any]] +# producer(task_ids) -> None: materialize eval deliverables for the given tasks. +ProducerFn = Callable[[Sequence[str]], None] + + +# --------------------------------------------------------------------------- +# Dataset / distribution loading +# --------------------------------------------------------------------------- + + +# Default location for distributions this driver builds on demand.
Lives under +# the resources server's data dir so it is reachable from wherever the driver +# runs and is easy to inspect/reuse across runs. +DEFAULT_DISTRIBUTION_CACHE_DIR = Path(__file__).resolve().parent / "data" / "distributions" + + +def load_distribution(path: str | Path) -> Dict[str, Dict[str, Any]]: + """Load a task-distribution JSON file produced by ``task_distribution.py``.""" + with Path(path).open("r", encoding="utf-8") as handle: + data = json.load(handle) + if not isinstance(data, dict): + raise ValueError(f"Distribution file {path} must be a JSON object.") + return data + + +def ensure_distribution( + distribution_path: Optional[str | Path] = None, + *, + dataset_path: Optional[str | Path] = None, + columns: Optional[Sequence[str]] = None, + cache_dir: Optional[str | Path] = None, +) -> tuple[Dict[str, Dict[str, Any]], Path]: + """Return ``(distribution, path)``, building the distribution if needed. + + If ``distribution_path`` exists it is loaded as-is. Otherwise a distribution + is built from ``dataset_path`` (or the default GDPVal dataset) grouped by + ``columns`` (default ``["occupation"]``) via ``task_distribution``, then saved + so subsequent runs reuse it. It is written to ``distribution_path`` when + given, else to ``/_distribution.json`` (cache_dir + defaults to ``DEFAULT_DISTRIBUTION_CACHE_DIR``). + """ + column_list = list(columns) if columns else ["occupation"] + + if distribution_path is not None and Path(distribution_path).is_file(): + return load_distribution(distribution_path), Path(distribution_path) + + from responses_api_agents.stirrup_agent.task_distribution import ( + build_distribution_from_dataset, + resolve_default_dataset, + ) + + resolved_dataset = Path(dataset_path) if dataset_path is not None else resolve_default_dataset() + if resolved_dataset is None: + raise FileNotFoundError( + "No distribution file was provided and no default GDPVal dataset could be found to " + "build one from. 
Provide distribution_path, pass dataset_path, or prepare the GDPVal " + "dataset (gym eval prepare --benchmark gdpval)." + ) + + distribution = build_distribution_from_dataset(resolved_dataset, column_list) + + if distribution_path is not None: + out_path = Path(distribution_path) + else: + base = Path(cache_dir) if cache_dir is not None else DEFAULT_DISTRIBUTION_CACHE_DIR + out_path = base / f"{'_'.join(column_list)}_distribution.json" + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8") as handle: + json.dump(distribution, handle, indent=2, ensure_ascii=False) + print( + f"[multistage-elo] built task distribution over {column_list} from {resolved_dataset} -> {out_path}", + flush=True, + ) + return distribution, out_path + + +def load_task_prompts(jsonl_path: str | Path) -> Dict[str, str]: + """Map ``task_id -> prompt`` from a benchmark JSONL. + + The prompt is needed when judging cached deliverables (the judge sees the + task description). Looks for ``prompt`` and ``task_id`` at the top level and, + failing that, under ``responses_create_params.metadata`` — covering both the + prepared benchmark layout and the metadata-nested layout. 
+ """ + prompts: Dict[str, str] = {} + with Path(jsonl_path).open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + row = json.loads(line) + meta = (row.get("responses_create_params") or {}).get("metadata") or {} + task_id = row.get("task_id") or meta.get("task_id") + prompt = row.get("prompt") or meta.get("prompt") + if task_id is not None: + prompts[str(task_id)] = prompt or "" + + return prompts + + +# --------------------------------------------------------------------------- +# Cached-deliverable discovery +# --------------------------------------------------------------------------- + + +def task_repeat_dirs(eval_deliverables_dir: str | Path, task_id: str) -> List[Path]: + """Return attempted ``repeat_`` dirs (or a flat task dir) for a task. + + Mirrors the resources server's reference-repeat resolution: prefers + ``task_/repeat_/`` subdirs, falls back to a flat ``task_/``, and + only returns dirs that look like a completed run (``finish_params.json``). 
+ """ + task_root = Path(eval_deliverables_dir) / f"task_{task_id}" + if not task_root.is_dir(): + return [] + repeats = sorted(p for p in task_root.iterdir() if p.is_dir() and p.name.startswith("repeat_")) + candidates = repeats or [task_root] + return [d for d in candidates if task_attempted(str(d))] + + +def cached_task_ids(eval_deliverables_dir: str | Path) -> set: + """All task ids that have at least one attempted deliverable in the cache.""" + root = Path(eval_deliverables_dir) + if not root.is_dir(): + return set() + found = set() + for child in root.iterdir(): + if child.is_dir() and child.name.startswith("task_"): + task_id = child.name[len("task_") :] + if task_repeat_dirs(eval_deliverables_dir, task_id): + found.add(task_id) + return found + + +def check_coverage(eval_deliverables_dir: str | Path, task_ids: Sequence[str]) -> tuple[List[str], List[str]]: + """Split ``task_ids`` into ``(present, missing)`` against the cache.""" + present, missing = [], [] + for tid in task_ids: + (present if task_repeat_dirs(eval_deliverables_dir, tid) else missing).append(tid) + + return present, missing + + +# --------------------------------------------------------------------------- +# Vote pooling +# --------------------------------------------------------------------------- + + +def pool_per_reference(verify_responses: Sequence[Mapping[str, Any]]) -> PerReferenceTotals: + """Sum ``per_reference`` win/loss/tie counts across many verify responses.""" + totals: PerReferenceTotals = {} + for vr in verify_responses: + per_ref = vr.get("per_reference") or {} + for ref_id, counts in per_ref.items(): + entry = totals.setdefault(ref_id, {"wins": 0, "losses": 0, "ties": 0, "reference_elo": None}) + entry["wins"] += int(counts.get("wins", 0) or 0) + entry["losses"] += int(counts.get("losses", 0) or 0) + entry["ties"] += int(counts.get("ties", 0) or 0) + if entry["reference_elo"] is None: + entry["reference_elo"] = counts.get("reference_elo") + + return totals + + +# 
--------------------------------------------------------------------------- +# judge_stage builder +# --------------------------------------------------------------------------- + + +def build_judge_stage( + verify_one: VerifyOneFn, + eval_deliverables_dir: str | Path, + task_prompts: Mapping[str, str], + *, + produce_missing: bool = True, + producer: Optional[ProducerFn] = None, + progress: Optional[Callable[[int, int, str], None]] = None, +): + """Build the ``judge_stage`` callable expected by ``MultiStageEloRunner``. + + For each stage's tasks, judges the cached eval deliverables against the + selected references (one ``verify_one`` call per task-repeat) and pools the + per-reference votes. Missing tasks are produced via ``producer`` when given; + otherwise ``produce_missing=True`` raises an actionable error and + ``produce_missing=False`` drops them with a warning. + + ``progress`` is an optional callback invoked as ``progress(done, total, + task_id)`` after each ``verify_one`` completes, for live status reporting. + """ + + def judge_stage(task_ids: Sequence[str], reference_ids: Sequence[str]) -> PerReferenceTotals: + present, missing = check_coverage(eval_deliverables_dir, task_ids) + if missing: + if producer is not None: + producer(missing) + present, missing = check_coverage(eval_deliverables_dir, task_ids) + if missing and produce_missing and producer is None: + raise FileNotFoundError( + f"{len(missing)} task(s) have no cached eval deliverable under " + f"{eval_deliverables_dir} (e.g. {missing[:3]}). Produce them first with an " + f"execute_only run, pass a producer, or set produce_missing=False to skip them." + ) + if missing: + print( + f"[multistage-elo] WARNING: skipping {len(missing)} task(s) with no cached " + f"deliverable (e.g. {missing[:3]})", + flush=True, + ) + + # Flatten to (task_id, repeat_dir) units up front so progress can report + # an accurate done/total across all repeats in the stage. 
+ units = [(tid, repeat_dir) for tid in present for repeat_dir in task_repeat_dirs(eval_deliverables_dir, tid)] + total = len(units) + responses: List[Dict[str, Any]] = [] + for done, (task_id, repeat_dir) in enumerate(units, start=1): + prompt = task_prompts.get(task_id, "") + responses.append(verify_one(task_id, str(repeat_dir), prompt, list(reference_ids))) + if progress is not None: + progress(done, total, task_id) + return pool_per_reference(responses) + + return judge_stage + + +# --------------------------------------------------------------------------- +# Real verify_one (HTTP) +# --------------------------------------------------------------------------- + + +def build_verify_request_body( + task_id: str, + deliverables_dir: str, + prompt: str, + reference_ids: Sequence[str], + *, + model: str = "eval", +) -> Dict[str, Any]: + """Build a minimal comparison-mode ``/verify`` request body. + + In comparison mode the judge reads deliverable files from ``deliverables_dir`` + rather than the response payload, so a placeholder response is sufficient. + """ + return { + "responses_create_params": {"input": [], "model": model}, + "response": { + "id": f"multistage-{task_id}", + "created_at": 0, + "model": model, + "object": "response", + "output": [], + "parallel_tool_calls": False, + "tool_choice": "none", + "tools": [], + }, + "task_id": task_id, + "prompt": prompt, + "deliverables_dir": deliverables_dir, + "reference_ids": list(reference_ids), + } + + +def make_http_verify_one(server_url: str, *, timeout: float = 1800.0, model: str = "eval") -> VerifyOneFn: + """Return a blocking ``verify_one`` that POSTs to a running resources server. + + ``server_url`` is the resources server base URL (e.g. ``http://host:port``); + ``/verify`` is appended. Uses stdlib ``urllib`` so the driver pulls in no + async machinery — it is a standalone orchestration script, not part of the + server hot path. 
+ """ + import urllib.request + + endpoint = server_url.rstrip("/") + "/verify" + + def verify_one(task_id: str, deliverables_dir: str, prompt: str, reference_ids: Sequence[str]) -> Dict[str, Any]: + body = build_verify_request_body(task_id, deliverables_dir, prompt, reference_ids, model=model) + data = json.dumps(body).encode("utf-8") + req = urllib.request.Request(endpoint, data=data, headers={"Content-Type": "application/json"}) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode("utf-8")) + + return verify_one + + +# --------------------------------------------------------------------------- +# Top-level run +# --------------------------------------------------------------------------- + + +def run_multistage_elo( + config: MultiStageEloConfig, + verify_one: VerifyOneFn, + task_prompts: Mapping[str, str], + *, + rng=None, + producer: Optional[ProducerFn] = None, + on_event: Optional[Callable[[str, dict], None]] = None, + progress: Optional[Callable[[int, int, str], None]] = None, +) -> List[StageResult]: + """Run the full multi-stage ELO procedure and return per-stage results. + + ``config.eval_deliverables_dir`` must be set — it is the source of the eval + model's (cached or produced) deliverables. + + ``on_event``/``progress`` are optional callbacks for live status reporting: + ``on_event`` receives stage-level events (see ``MultiStageEloRunner``) and + ``progress`` receives per-(task, repeat) judging progress. 
+ """ + if not config.eval_deliverables_dir: + raise ValueError("config.eval_deliverables_dir must be set (source of eval deliverables).") + + distribution, _ = ensure_distribution( + config.distribution_path, + dataset_path=config.dataset_path, + columns=config.column, + ) + judge_stage = build_judge_stage( + verify_one, + config.eval_deliverables_dir, + task_prompts, + produce_missing=config.produce_missing, + producer=producer, + progress=progress, + ) + runner = MultiStageEloRunner(config, distribution, judge_stage, rng=rng, on_event=on_event) + return runner.run() + + +def stage_results_to_dict(results: Sequence[StageResult]) -> Dict[str, Any]: + """Serialize stage results to a JSON-friendly summary dict.""" + final = results[-1] if results else None + return { + "final_eval_elo": final.eval_elo if final else None, + "final_normalized_elo": final.normalized_elo if final else None, + "num_stages": len(results), + "stages": [ + { + "stage_index": r.stage_index, + "num_tasks": len(r.task_ids), + "reference_ids": r.reference_ids, + "eval_elo": r.eval_elo, + "normalized_elo": r.normalized_elo, + "num_references": r.num_references, + "per_reference": r.per_reference, + "task_ids": r.task_ids, + } + for r in results + ], + } + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +DEFAULT_TASK_PROMPTS = "benchmarks/gdpval/data/gdpval_benchmark.jsonl" + + +def _parse_stage(spec: str) -> StageSpec: + """Parse a ``--stage`` value ``num_tasks[:num_models[:seed]]`` into a StageSpec. + + ``num_models`` may be ``all`` or empty for "all available references". Examples: + ``5`` (5 tasks, all refs), ``88:4`` (88 tasks, 4 closest refs), ``5:all:7`` + (5 tasks, all refs, seed 7). 
+ """ + parts = spec.split(":") + if not parts or not parts[0].strip(): + raise argparse.ArgumentTypeError(f"Invalid --stage {spec!r}: num_tasks is required.") + try: + num_tasks = int(parts[0]) + except ValueError: + raise argparse.ArgumentTypeError(f"Invalid --stage {spec!r}: num_tasks must be an integer.") + + num_models: Optional[int] = None + if len(parts) >= 2 and parts[1].strip() and parts[1].strip().lower() != "all": + try: + num_models = int(parts[1]) + except ValueError: + raise argparse.ArgumentTypeError(f"Invalid --stage {spec!r}: num_models must be an integer or 'all'.") + + seed: Optional[int] = None + if len(parts) >= 3 and parts[2].strip(): + try: + seed = int(parts[2]) + except ValueError: + raise argparse.ArgumentTypeError(f"Invalid --stage {spec!r}: seed must be an integer.") + + return StageSpec(num_tasks=num_tasks, num_models=num_models, seed=seed) + + +def _load_reference_elos(value: str) -> Dict[str, float]: + """Load reference ELOs from inline JSON or, if prefixed with ``@``, a JSON file. + + Accepts ``{"ref_id": elo, ...}``. The ids must match the running server's + ``reference_models`` ids. + """ + text = value + if value.startswith("@"): + text = Path(value[1:]).read_text(encoding="utf-8") + data = json.loads(text) + if not isinstance(data, dict) or not data: + raise argparse.ArgumentTypeError("--reference-elos must be a non-empty JSON object of {ref_id: elo}.") + return {str(k): float(v) for k, v in data.items()} + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="multistage_elo", + description=( + "Run multi-stage adaptive ELO estimation for a model's GDPVal deliverables " + "against a running GDPVal comparison server." 
+ ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Example:\n" + " python -m resources_servers.gdpval.multistage_elo_driver \\\n" + " --server-url http://localhost:8000 \\\n" + " --eval-deliverables-dir /path/to/eval/deliverables \\\n" + " --reference-elos '@refs.json' \\\n" + " --stage 5 --stage 88:4 \\\n" + " --output elo_summary.json\n" + ), + ) + parser.add_argument( + "--server-url", + required=True, + help="Base URL of the running GDPVal comparison-mode resources server (e.g. http://localhost:8000).", + ) + parser.add_argument( + "--eval-deliverables-dir", + required=True, + help="Directory of the evaluated model's deliverables (task_/repeat_/ layout).", + ) + parser.add_argument( + "--reference-elos", + required=True, + type=_load_reference_elos, + metavar="JSON", + help=( + "Reference anchor ELOs as inline JSON ('{\"ref\": 1500, ...}') or '@path.json'. " + "Keys must match the server's reference_models ids." + ), + ) + parser.add_argument( + "--stage", + dest="stages", + action="append", + required=True, + type=_parse_stage, + metavar="N[:M[:SEED]]", + help=( + "A stage as num_tasks[:num_models[:seed]] (num_models 'all' or omitted = all references). " + "Repeat for multiple stages, e.g. --stage 5 --stage 88:4." + ), + ) + parser.add_argument( + "--task-prompts", + default=DEFAULT_TASK_PROMPTS, + help=f"Benchmark JSONL mapping task_id -> prompt (default: {DEFAULT_TASK_PROMPTS}).", + ) + parser.add_argument( + "--distribution", + default=None, + help="Existing task-distribution JSON to sample tasks from. If omitted, one is built and cached.", + ) + parser.add_argument( + "--dataset", + default=None, + help="Dataset JSONL to build the distribution from when --distribution is not given (default: GDPVal).", + ) + parser.add_argument( + "--column", + dest="columns", + action="append", + default=None, + metavar="COLUMN", + help="Column(s) to group the distribution by when building one (default: occupation). 
Repeatable.", + ) + parser.add_argument( + "--nested-tasks", + action="store_true", + help="Make each stage's task set a superset of the previous (default: independent per-stage sampling).", + ) + parser.add_argument( + "--skip-missing", + action="store_true", + help="Drop tasks with no cached eval deliverable instead of erroring (sets produce_missing=False).", + ) + parser.add_argument( + "--model", + default="eval", + help="Label for the evaluated model in verify requests (default: eval).", + ) + parser.add_argument( + "--timeout", + type=float, + default=1800.0, + help="Per-request /verify timeout in seconds (default: 1800).", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Top-level RNG seed for reproducible task sampling and reference selection.", + ) + parser.add_argument( + "--quiet", + "-q", + action="store_true", + help="Suppress live per-stage / per-task progress output on stderr.", + ) + parser.add_argument( + "--output", + "-o", + default=None, + help="Path to write the JSON ELO summary. Defaults to stdout.", + ) + return parser + + +def _make_progress_printers(): + """Return ``(on_event, progress)`` callbacks that print human-readable status to stderr. + + ``on_event`` prints a banner at the start/end of each stage (selected + references, task count, fitted ELO); ``progress`` prints a per-(task, repeat) + counter as each ``/verify`` completes. 
+ """ + + def on_event(name: str, data: dict) -> None: + if name == "planned": + counts = data.get("stage_task_counts", []) + print( + f"[multistage-elo] planned {data.get('total_stages')} stage(s); tasks per stage: {counts}", + file=sys.stderr, + flush=True, + ) + elif name == "stage_start": + idx = int(data["index"]) + 1 + total = data["total_stages"] + refs = data.get("reference_ids", []) + prior = data.get("prior_elo") + prior_str = f"{prior:.1f}" if isinstance(prior, (int, float)) else "n/a" + print( + f"[multistage-elo] stage {idx}/{total}: {data.get('num_tasks')} task(s) " + f"vs {len(refs)} ref(s) {refs} (prior ELO: {prior_str})", + file=sys.stderr, + flush=True, + ) + elif name == "stage_end": + idx = int(data["index"]) + 1 + total = data["total_stages"] + elo = data.get("eval_elo") + elo_str = f"{elo:.1f}" if isinstance(elo, (int, float)) else "unset (no games)" + print( + f"[multistage-elo] stage {idx}/{total} done: eval ELO = {elo_str} " + f"(fit over {data.get('num_references')} ref(s))", + file=sys.stderr, + flush=True, + ) + + def progress(done: int, total: int, task_id: str) -> None: + short = task_id[:18] + "…" if len(task_id) > 19 else task_id + end = "\n" if done == total else "\r" + print(f"[multistage-elo] judged {done}/{total} (task {short}) ", end=end, file=sys.stderr, flush=True) + + return on_event, progress + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = _build_arg_parser() + args = parser.parse_args(argv) + + eval_dir = Path(args.eval_deliverables_dir) + if not eval_dir.is_dir(): + print(f"Eval deliverables dir not found: {eval_dir}", file=sys.stderr) + return 2 + + prompts_path = Path(args.task_prompts) + if not prompts_path.is_file(): + print(f"Task prompts JSONL not found: {prompts_path}", file=sys.stderr) + return 2 + + config = MultiStageEloConfig( + stages=list(args.stages), + reference_elos=args.reference_elos, + distribution_path=args.distribution, + dataset_path=args.dataset, + 
eval_deliverables_dir=str(eval_dir), + produce_missing=not args.skip_missing, + nested_tasks=args.nested_tasks, + column=list(args.columns) if args.columns else ["occupation"], + ) + + verify_one = make_http_verify_one(args.server_url, timeout=args.timeout, model=args.model) + task_prompts = load_task_prompts(prompts_path) + rng = random.Random(args.seed) if args.seed is not None else None + + on_event, progress = (None, None) if args.quiet else _make_progress_printers() + results = run_multistage_elo(config, verify_one, task_prompts, rng=rng, on_event=on_event, progress=progress) + payload = json.dumps(stage_results_to_dict(results), indent=2, ensure_ascii=False) + + if args.output: + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(payload + "\n", encoding="utf-8") + final = results[-1] if results else None + final_elo = final.eval_elo if final else None + print(f"Wrote ELO summary ({len(results)} stages, final_eval_elo={final_elo}) to {out_path}", file=sys.stderr) + else: + print(payload) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/resources_servers/gdpval/tests/test_app.py b/resources_servers/gdpval/tests/test_app.py index b2585b9507..a61652f0e5 100644 --- a/resources_servers/gdpval/tests/test_app.py +++ b/resources_servers/gdpval/tests/test_app.py @@ -595,6 +595,74 @@ def fake_run_trials(**_kwargs): assert resp.total_losses == 2 assert resp.judge_response["reference_count"] == 2 + @pytest.mark.asyncio + async def test_reference_ids_filter_judges_subset(self, tmp_path) -> None: + """``reference_ids`` on the verify request restricts judging to the named + references; unknown ids are ignored.""" + eval_dir = tmp_path / "eval" / "task_task-1" / "repeat_0" + eval_dir.mkdir(parents=True) + (eval_dir / "finish_params.json").write_text("{}") + + ref_roots = {} + for ref_id in ("kimi", "gpt5"): + root = tmp_path / ref_id + td = root / "task_task-1" + td.mkdir(parents=True) + 
(td / "finish_params.json").write_text("{}") + ref_roots[ref_id] = root + + server = _server( + reward_mode="comparison", + reference_models={ + "kimi": {"deliverables_dir": str(ref_roots["kimi"]), "elo": 1290.0}, + "gpt5": {"deliverables_dir": str(ref_roots["gpt5"]), "elo": 1320.0}, + }, + preconvert_office_to_pdf=False, + num_comparison_trials=4, + ) + + def fake_run_trials(**_kwargs): + return {"winner": "[[B]]", "win_count_a": 1, "win_count_b": 3, "tie_count": 0, "task_count": 4} + + # Only judge against gpt5 (and an unknown id, which is ignored). + body = _verify_request(deliverables_dir=str(eval_dir), reference_ids=["gpt5", "nonexistent"]) + + with ( + patch("resources_servers.gdpval.comparison.run_trials", side_effect=fake_run_trials), + patch("resources_servers.gdpval.app.get_server_url", return_value="http://localhost:9999"), + patch("resources_servers.gdpval.comparison.build_file_section", return_value=[]), + patch("openai.OpenAI", return_value=MagicMock()), + ): + resp = await server.verify(body) + + assert set(resp.per_reference) == {"gpt5"} + assert resp.total_wins == 3 + assert resp.total_losses == 1 + assert resp.judge_response["reference_count"] == 1 + + @pytest.mark.asyncio + async def test_reference_ids_empty_yields_no_references(self, tmp_path) -> None: + """An empty ``reference_ids`` list judges against nothing → reference_missing.""" + eval_dir = tmp_path / "eval" / "task_task-1" / "repeat_0" + eval_dir.mkdir(parents=True) + (eval_dir / "finish_params.json").write_text("{}") + root = tmp_path / "kimi" + (root / "task_task-1").mkdir(parents=True) + (root / "task_task-1" / "finish_params.json").write_text("{}") + + server = _server( + reward_mode="comparison", + reference_models={"kimi": {"deliverables_dir": str(root), "elo": 1290.0}}, + preconvert_office_to_pdf=False, + ) + body = _verify_request(deliverables_dir=str(eval_dir), reference_ids=[]) + + with patch("resources_servers.gdpval.app.get_server_url", return_value="http://localhost:9999"): 
+ resp = await server.verify(body) + + assert resp.reward == 0.0 + assert resp.judge_response == {"error": "reference_missing"} + @staticmethod def _two_ref_server_and_body(tmp_path): eval_dir = tmp_path / "eval" / "task_task-1" / "repeat_0" diff --git a/resources_servers/gdpval/tests/test_multistage_elo.py b/resources_servers/gdpval/tests/test_multistage_elo.py new file mode 100644 index 0000000000..d0b0913217 --- /dev/null +++ b/resources_servers/gdpval/tests/test_multistage_elo.py @@ -0,0 +1,206 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import random + +import pytest + +from resources_servers.gdpval.multistage_elo import ( + MultiStageEloConfig, + MultiStageEloRunner, + StageSpec, + fit_stage_elo, + plan_stage_task_ids, + select_references, +) + + +def _dist(groups): + """groups: {key: [task_ids]} -> distribution dict with proportional pct.""" + total = sum(len(v) for v in groups.values()) or 1 + return {k: {"percentage": len(v) / total, "task_ids": list(v)} for k, v in groups.items()} + + +class TestSelectReferences: + ELOS = {"a": 1000.0, "b": 1200.0, "c": 1300.0, "d": 1500.0} + + def test_all_when_num_models_none(self) -> None: + assert select_references(self.ELOS, 1234.0, None) == ["a", "b", "c", "d"] + + def test_all_when_eval_elo_none(self) -> None: + assert select_references(self.ELOS, None, 2) == ["a", "b", "c", "d"] + + def test_all_when_num_models_exceeds_available(self) -> None: + assert select_references(self.ELOS, 1234.0, 10) == ["a", "b", "c", "d"] + + def test_closest_subset(self) -> None: + # eval 1250 -> closest are c(1300,50) and b(1200,50); tie broken by id. 
+ assert select_references(self.ELOS, 1250.0, 2) == ["b", "c"] + + def test_closest_single(self) -> None: + assert select_references(self.ELOS, 1490.0, 1) == ["d"] + + def test_zero_models_returns_empty(self) -> None: + assert select_references(self.ELOS, 1250.0, 0) == [] + + def test_result_sorted_by_id(self) -> None: + chosen = select_references(self.ELOS, 1100.0, 3) + assert chosen == sorted(chosen) + + +class TestPlanStageTaskIds: + def test_nested_is_superset(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(10)], "y": [f"y{i}" for i in range(10)]}) + stages = [StageSpec(num_tasks=3), StageSpec(num_tasks=8)] + planned = plan_stage_task_ids(dist, stages, rng=random.Random(0), nested=True) + assert len(planned[0]) == 3 + assert len(planned[1]) == 8 + assert set(planned[0]).issubset(set(planned[1])) + + def test_nested_no_duplicates(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(20)]}) + stages = [StageSpec(num_tasks=5), StageSpec(num_tasks=12)] + planned = plan_stage_task_ids(dist, stages, rng=random.Random(1), nested=True) + assert len(planned[1]) == len(set(planned[1])) + + def test_nested_capped_at_available(self) -> None: + dist = _dist({"x": ["a", "b", "c"]}) + stages = [StageSpec(num_tasks=2), StageSpec(num_tasks=100)] + planned = plan_stage_task_ids(dist, stages, rng=random.Random(2), nested=True) + assert sorted(planned[1]) == ["a", "b", "c"] + + def test_non_increasing_stage_reuses_prefix(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(10)]}) + stages = [StageSpec(num_tasks=5), StageSpec(num_tasks=3)] + planned = plan_stage_task_ids(dist, stages, rng=random.Random(3), nested=True) + assert planned[1] == planned[0][:3] + + def test_independent_sampling(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(50)]}) + stages = [StageSpec(num_tasks=5, seed=1), StageSpec(num_tasks=5, seed=2)] + planned = plan_stage_task_ids(dist, stages, nested=False) + assert len(planned[0]) == 5 and len(planned[1]) == 5 + + def 
test_seed_reproducible(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(50)]}) + stages = [StageSpec(num_tasks=7, seed=42)] + a = plan_stage_task_ids(dist, stages, nested=False) + b = plan_stage_task_ids(dist, stages, nested=False) + assert a == b + + +class TestFitStageElo: + ELOS = {"a": 1000.0, "b": 1400.0} + + def test_no_battles_returns_none(self) -> None: + assert fit_stage_elo({}, self.ELOS) == (None, None, 0) + + def test_zero_games_skipped(self) -> None: + per_ref = {"a": {"wins": 0, "losses": 0, "ties": 0}} + assert fit_stage_elo(per_ref, self.ELOS) == (None, None, 0) + + def test_fits_elo_uses_config_anchor(self) -> None: + per_ref = {"a": {"wins": 5, "losses": 5, "ties": 0}} + elo, norm, n = fit_stage_elo(per_ref, self.ELOS) + # 50% win rate vs a single anchor -> eval elo ~= anchor elo. + assert n == 1 + assert elo == pytest.approx(1000.0, abs=1.0) + assert norm == pytest.approx((elo - 500.0) / 2000.0) + + def test_falls_back_to_recorded_reference_elo(self) -> None: + per_ref = {"z": {"wins": 5, "losses": 5, "ties": 0, "reference_elo": 1100.0}} + elo, _norm, n = fit_stage_elo(per_ref, {}) + assert n == 1 + assert elo == pytest.approx(1100.0, abs=1.0) + + def test_multi_reference_battles(self) -> None: + per_ref = { + "a": {"wins": 8, "losses": 2, "ties": 0}, + "b": {"wins": 2, "losses": 8, "ties": 0}, + } + elo, _norm, n = fit_stage_elo(per_ref, self.ELOS) + assert n == 2 + assert 1000.0 < elo < 1400.0 + + +class TestMultiStageEloRunner: + def _config(self, **overrides): + base = dict( + distribution_path="unused.json", + stages=[StageSpec(num_tasks=3, num_models=None), StageSpec(num_tasks=6, num_models=2)], + reference_elos={"a": 1000.0, "b": 1200.0, "c": 1300.0, "d": 1500.0}, + ) + base.update(overrides) + return MultiStageEloConfig(**base) + + def test_requires_stages(self) -> None: + with pytest.raises(ValueError): + MultiStageEloConfig(distribution_path="x", stages=[], reference_elos={}) + + def test_unknown_selection_rejected(self) -> 
None: + with pytest.raises(ValueError): + MultiStageEloConfig(distribution_path="x", stages=[StageSpec(1)], reference_elos={}, selection="zzz") + + def test_two_stage_flow_threads_elo_and_shrinks_refs(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(20)]}) + seen_stage_refs = [] + + def judge_stage(task_ids, reference_ids): + seen_stage_refs.append(list(reference_ids)) + # Eval beats everyone 7-3 -> high elo estimate. + return {rid: {"wins": 7, "losses": 3, "ties": 0} for rid in reference_ids} + + runner = MultiStageEloRunner(self._config(nested_tasks=True), dist, judge_stage, rng=random.Random(0)) + results = runner.run() + + assert len(results) == 2 + # Stage 1 uses all references. + assert seen_stage_refs[0] == ["a", "b", "c", "d"] + # Stage 2 narrows to 2 references (closest to the stage-1 estimate). + assert len(seen_stage_refs[1]) == 2 + assert set(seen_stage_refs[1]).issubset({"a", "b", "c", "d"}) + # Nested task sets (nested_tasks=True): stage 2 superset of stage 1. + assert set(results[0].task_ids).issubset(set(results[1].task_ids)) + assert results[1].eval_elo is not None + + def test_stage_with_no_games_leaves_elo_unset(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(10)]}) + + def judge_stage(task_ids, reference_ids): + return {} + + cfg = self._config(stages=[StageSpec(num_tasks=2, num_models=None)]) + results = MultiStageEloRunner(cfg, dist, judge_stage, rng=random.Random(0)).run() + assert results[0].eval_elo is None + assert results[0].num_references == 0 + + def test_on_event_emits_lifecycle_events(self) -> None: + dist = _dist({"x": [f"x{i}" for i in range(10)]}) + + def judge_stage(task_ids, reference_ids): + return {rid: {"wins": 6, "losses": 4, "ties": 0} for rid in reference_ids} + + events = [] + cfg = self._config(stages=[StageSpec(num_tasks=2, num_models=None), StageSpec(num_tasks=3, num_models=2)]) + MultiStageEloRunner( + cfg, dist, judge_stage, rng=random.Random(0), on_event=lambda name, data: events.append((name, 
data)) + ).run() + + names = [n for n, _ in events] + assert names[0] == "planned" + assert names.count("stage_start") == 2 + assert names.count("stage_end") == 2 + # stage_start carries the selected references and task count. + first_start = next(d for n, d in events if n == "stage_start") + assert first_start["num_tasks"] == 2 + assert first_start["reference_ids"] == ["a", "b", "c", "d"] diff --git a/resources_servers/gdpval/tests/test_multistage_elo_driver.py b/resources_servers/gdpval/tests/test_multistage_elo_driver.py new file mode 100644 index 0000000000..09d5d80e32 --- /dev/null +++ b/resources_servers/gdpval/tests/test_multistage_elo_driver.py @@ -0,0 +1,444 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import random +from pathlib import Path + +import pytest + +import resources_servers.gdpval.multistage_elo_driver as driver +from resources_servers.gdpval.multistage_elo import MultiStageEloConfig, StageResult, StageSpec +from resources_servers.gdpval.multistage_elo_driver import ( + _load_reference_elos, + _parse_stage, + build_judge_stage, + build_verify_request_body, + cached_task_ids, + check_coverage, + ensure_distribution, + load_distribution, + load_task_prompts, + main, + pool_per_reference, + run_multistage_elo, + stage_results_to_dict, + task_repeat_dirs, +) + + +def _make_cache(root: Path, task_ids, repeats=("repeat_0",)): + for tid in task_ids: + for rep in repeats: + d = root / f"task_{tid}" / rep + d.mkdir(parents=True) + (d / "finish_params.json").write_text("{}") + + +def _dist(groups): + total = sum(len(v) for v in groups.values()) or 1 + return {k: {"percentage": len(v) / total, "task_ids": list(v)} for k, v in groups.items()} + + +class TestCacheDiscovery: + def test_task_repeat_dirs_lists_attempted_repeats(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a"], repeats=("repeat_0", "repeat_1")) + dirs = task_repeat_dirs(tmp_path, "a") + assert [d.name for d in dirs] == ["repeat_0", "repeat_1"] + + def test_task_repeat_dirs_skips_unattempted(self, tmp_path: Path) -> None: + (tmp_path / "task_a" / "repeat_0").mkdir(parents=True) # no finish_params.json + assert task_repeat_dirs(tmp_path, "a") == [] + + def test_task_repeat_dirs_flat_layout(self, tmp_path: Path) -> None: + d = tmp_path / "task_a" + d.mkdir(parents=True) + (d / "finish_params.json").write_text("{}") + assert [p.name for p in task_repeat_dirs(tmp_path, "a")] == ["task_a"] + + def test_missing_task_returns_empty(self, tmp_path: Path) -> None: + assert task_repeat_dirs(tmp_path, "ghost") == [] + + def test_cached_task_ids(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a", "b"]) + assert cached_task_ids(tmp_path) == {"a", "b"} + + def 
test_cached_task_ids_missing_dir(self, tmp_path: Path) -> None: + assert cached_task_ids(tmp_path / "nope") == set() + + def test_check_coverage(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a", "c"]) + present, missing = check_coverage(tmp_path, ["a", "b", "c"]) + assert present == ["a", "c"] + assert missing == ["b"] + + +class TestPoolPerReference: + def test_sums_counts_and_keeps_elo(self) -> None: + responses = [ + {"per_reference": {"a": {"wins": 2, "losses": 1, "ties": 0, "reference_elo": 1000.0}}}, + {"per_reference": {"a": {"wins": 1, "losses": 0, "ties": 1, "reference_elo": 1000.0}}}, + ] + pooled = pool_per_reference(responses) + assert pooled["a"]["wins"] == 3 + assert pooled["a"]["losses"] == 1 + assert pooled["a"]["ties"] == 1 + assert pooled["a"]["reference_elo"] == 1000.0 + + def test_handles_missing_per_reference(self) -> None: + assert pool_per_reference([{}, {"per_reference": None}]) == {} + + +class TestLoaders: + def test_load_distribution(self, tmp_path: Path) -> None: + p = tmp_path / "d.json" + p.write_text(json.dumps(_dist({"x": ["a"]}))) + assert load_distribution(p)["x"]["task_ids"] == ["a"] + + def test_load_distribution_rejects_non_object(self, tmp_path: Path) -> None: + p = tmp_path / "d.json" + p.write_text("[1,2,3]") + with pytest.raises(ValueError): + load_distribution(p) + + def test_load_task_prompts_top_level(self, tmp_path: Path) -> None: + p = tmp_path / "b.jsonl" + p.write_text(json.dumps({"task_id": "a", "prompt": "do x"}) + "\n") + assert load_task_prompts(p) == {"a": "do x"} + + def test_load_task_prompts_metadata_nested(self, tmp_path: Path) -> None: + p = tmp_path / "b.jsonl" + p.write_text(json.dumps({"responses_create_params": {"metadata": {"task_id": "a", "prompt": "y"}}}) + "\n") + assert load_task_prompts(p) == {"a": "y"} + + +class TestEnsureDistribution: + def test_loads_existing_file(self, tmp_path: Path) -> None: + p = tmp_path / "d.json" + p.write_text(json.dumps(_dist({"x": ["a"]}))) + dist, path = 
ensure_distribution(str(p)) + assert path == p + assert dist["x"]["task_ids"] == ["a"] + + def test_builds_from_dataset_when_missing(self, tmp_path: Path) -> None: + dataset = tmp_path / "tasks.jsonl" + rows = [ + {"task_id": "t1", "occupation": "Lawyer"}, + {"task_id": "t2", "occupation": "Lawyer"}, + {"task_id": "t3", "occupation": "Nurse"}, + ] + dataset.write_text("\n".join(json.dumps(r) for r in rows) + "\n") + cache = tmp_path / "cache" + + dist, path = ensure_distribution(None, dataset_path=str(dataset), cache_dir=str(cache)) + + assert path == cache / "occupation_distribution.json" + assert path.is_file() + assert dist["Lawyer"]["task_ids"] == ["t1", "t2"] + assert dist["Nurse"]["task_ids"] == ["t3"] + + def test_writes_to_distribution_path_when_given(self, tmp_path: Path) -> None: + dataset = tmp_path / "tasks.jsonl" + dataset.write_text(json.dumps({"task_id": "t1", "occupation": "Lawyer"}) + "\n") + out = tmp_path / "sub" / "mydist.json" + + _dist_, path = ensure_distribution(str(out), dataset_path=str(dataset)) + + assert path == out + assert out.is_file() + + def test_custom_columns_in_filename(self, tmp_path: Path) -> None: + dataset = tmp_path / "tasks.jsonl" + dataset.write_text(json.dumps({"task_id": "t1", "sector": "Legal", "occupation": "Lawyer"}) + "\n") + cache = tmp_path / "cache" + _dist_, path = ensure_distribution( + None, dataset_path=str(dataset), columns=["sector", "occupation"], cache_dir=str(cache) + ) + assert path == cache / "sector_occupation_distribution.json" + + def test_raises_when_no_dataset_available(self, tmp_path: Path, monkeypatch) -> None: + import responses_api_agents.stirrup_agent.task_distribution as td + + monkeypatch.setattr(td, "DEFAULT_DATASET_CANDIDATES", (tmp_path / "missing.jsonl",)) + with pytest.raises(FileNotFoundError): + ensure_distribution(None, cache_dir=str(tmp_path / "cache")) + + +class TestBuildVerifyRequestBody: + def test_includes_reference_ids_and_deliverables(self) -> None: + body = 
build_verify_request_body("t1", "/cache/task_t1/repeat_0", "prompt", ["a", "b"]) + assert body["task_id"] == "t1" + assert body["deliverables_dir"] == "/cache/task_t1/repeat_0" + assert body["reference_ids"] == ["a", "b"] + assert body["prompt"] == "prompt" + + +class TestBuildJudgeStage: + def test_judges_present_tasks_and_pools(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a", "b"], repeats=("repeat_0", "repeat_1")) + calls = [] + + def fake_verify_one(task_id, deliverables_dir, prompt, reference_ids): + calls.append((task_id, Path(deliverables_dir).name, tuple(reference_ids))) + return {"per_reference": {reference_ids[0]: {"wins": 1, "losses": 0, "ties": 0, "reference_elo": 1000.0}}} + + judge = build_judge_stage(fake_verify_one, tmp_path, {"a": "pa", "b": "pb"}) + pooled = judge(["a", "b"], ["ref1"]) + # 2 tasks x 2 repeats = 4 verify calls. + assert len(calls) == 4 + assert pooled["ref1"]["wins"] == 4 + + def test_missing_raises_when_no_producer(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a"]) + judge = build_judge_stage(lambda *a: {}, tmp_path, {}) + with pytest.raises(FileNotFoundError): + judge(["a", "missing"], ["ref1"]) + + def test_missing_skipped_when_produce_missing_false(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a"]) + + def fake_verify_one(task_id, deliverables_dir, prompt, reference_ids): + return {"per_reference": {"ref1": {"wins": 1, "losses": 0, "ties": 0, "reference_elo": 1000.0}}} + + judge = build_judge_stage(fake_verify_one, tmp_path, {"a": ""}, produce_missing=False) + pooled = judge(["a", "missing"], ["ref1"]) + assert pooled["ref1"]["wins"] == 1 + + def test_producer_materializes_then_judges(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a"]) + + def producer(task_ids): + _make_cache(tmp_path, list(task_ids)) + + def fake_verify_one(task_id, deliverables_dir, prompt, reference_ids): + return {"per_reference": {"ref1": {"wins": 1, "losses": 0, "ties": 0, "reference_elo": 1000.0}}} + + judge = 
build_judge_stage(fake_verify_one, tmp_path, {}, producer=producer) + pooled = judge(["a", "b"], ["ref1"]) + assert pooled["ref1"]["wins"] == 2 # both tasks judged after production + + def test_progress_callback_reports_each_unit(self, tmp_path: Path) -> None: + _make_cache(tmp_path, ["a", "b"], repeats=("repeat_0", "repeat_1")) + + def fake_verify_one(task_id, deliverables_dir, prompt, reference_ids): + return {"per_reference": {"ref1": {"wins": 1, "losses": 0, "ties": 0, "reference_elo": 1000.0}}} + + seen = [] + judge = build_judge_stage( + fake_verify_one, tmp_path, {}, progress=lambda done, total, tid: seen.append((done, total, tid)) + ) + judge(["a", "b"], ["ref1"]) + # 2 tasks x 2 repeats = 4 units; progress reports running done/total. + assert [s[0] for s in seen] == [1, 2, 3, 4] + assert all(s[1] == 4 for s in seen) + + +class TestRunMultistageElo: + def test_requires_eval_dir(self, tmp_path: Path) -> None: + cfg = MultiStageEloConfig(distribution_path="x.json", stages=[StageSpec(1)], reference_elos={"a": 1000.0}) + with pytest.raises(ValueError): + run_multistage_elo(cfg, lambda *a: {}, {}) + + def test_end_to_end_with_fakes(self, tmp_path: Path) -> None: + # 30 cached tasks, 2-stage adaptive run with a fake judge. 
+ task_ids = [f"t{i}" for i in range(30)] + _make_cache(tmp_path, task_ids) + dist_path = tmp_path / "dist.json" + dist_path.write_text(json.dumps(_dist({"x": task_ids}))) + + def fake_verify_one(task_id, deliverables_dir, prompt, reference_ids): + return { + "per_reference": { + rid: {"wins": 7, "losses": 3, "ties": 0, "reference_elo": elo} + for rid, elo in {"a": 1000.0, "b": 1200.0, "c": 1300.0, "d": 1500.0}.items() + if rid in reference_ids + } + } + + cfg = MultiStageEloConfig( + distribution_path=str(dist_path), + stages=[StageSpec(num_tasks=5, num_models=None), StageSpec(num_tasks=12, num_models=2)], + reference_elos={"a": 1000.0, "b": 1200.0, "c": 1300.0, "d": 1500.0}, + eval_deliverables_dir=str(tmp_path), + ) + results = run_multistage_elo(cfg, fake_verify_one, {t: "" for t in task_ids}, rng=random.Random(0)) + + assert len(results) == 2 + assert results[0].reference_ids == ["a", "b", "c", "d"] + assert len(results[1].reference_ids) == 2 + assert results[1].eval_elo is not None + + summary = stage_results_to_dict(results) + assert summary["num_stages"] == 2 + assert summary["final_eval_elo"] == results[1].eval_elo + + def test_stage_results_to_dict_empty(self) -> None: + assert stage_results_to_dict([])["final_eval_elo"] is None + + +class TestParseStage: + def test_tasks_only(self) -> None: + s = _parse_stage("5") + assert (s.num_tasks, s.num_models, s.seed) == (5, None, None) + + def test_tasks_and_models(self) -> None: + s = _parse_stage("88:4") + assert (s.num_tasks, s.num_models, s.seed) == (88, 4, None) + + def test_all_models_keyword_and_seed(self) -> None: + s = _parse_stage("5:all:7") + assert (s.num_tasks, s.num_models, s.seed) == (5, None, 7) + + @pytest.mark.parametrize("bad", ["", "x", "5:y", "5:4:z"]) + def test_invalid(self, bad: str) -> None: + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + _parse_stage(bad) + + +class TestLoadReferenceElos: + def test_inline_json(self) -> None: + assert _load_reference_elos('{"a": 
1500, "b": 1200}') == {"a": 1500.0, "b": 1200.0} + + def test_from_file(self, tmp_path: Path) -> None: + f = tmp_path / "refs.json" + f.write_text(json.dumps({"a": 1000})) + assert _load_reference_elos(f"@{f}") == {"a": 1000.0} + + @pytest.mark.parametrize("bad", ["[]", "{}", '"x"']) + def test_invalid(self, bad: str) -> None: + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + _load_reference_elos(bad) + + +class TestCliMain: + def _setup(self, tmp_path: Path): + _make_cache(tmp_path, ["a", "b"]) + prompts = tmp_path / "bench.jsonl" + prompts.write_text(json.dumps({"task_id": "a", "prompt": "p"}) + "\n") + refs = tmp_path / "refs.json" + refs.write_text(json.dumps({"a": 1000.0, "b": 1200.0})) + return prompts, refs + + def test_main_writes_summary(self, tmp_path: Path, monkeypatch, capsys) -> None: + prompts, refs = self._setup(tmp_path) + captured = {} + + def fake_run(config, verify_one, task_prompts, *, rng=None, producer=None, on_event=None, progress=None): + captured["config"] = config + captured["rng"] = rng + return [ + StageResult( + stage_index=0, + task_ids=["a"], + reference_ids=["a", "b"], + per_reference={}, + eval_elo=1234.0, + normalized_elo=0.5, + num_references=2, + ) + ] + + monkeypatch.setattr(driver, "run_multistage_elo", fake_run) + out = tmp_path / "summary.json" + rc = main( + [ + "--server-url", + "http://localhost:9999", + "--eval-deliverables-dir", + str(tmp_path), + "--reference-elos", + f"@{refs}", + "--stage", + "5", + "--stage", + "12:1", + "--task-prompts", + str(prompts), + "--nested-tasks", + "--skip-missing", + "--seed", + "3", + "--output", + str(out), + ] + ) + assert rc == 0 + summary = json.loads(out.read_text()) + assert summary["final_eval_elo"] == 1234.0 + cfg = captured["config"] + assert [s.num_tasks for s in cfg.stages] == [5, 12] + assert cfg.stages[1].num_models == 1 + assert cfg.nested_tasks is True + assert cfg.produce_missing is False + assert cfg.reference_elos == {"a": 1000.0, "b": 1200.0} + 
assert isinstance(captured["rng"], random.Random) + + def test_main_to_stdout(self, tmp_path: Path, monkeypatch, capsys) -> None: + prompts, refs = self._setup(tmp_path) + monkeypatch.setattr(driver, "run_multistage_elo", lambda *a, **k: []) + rc = main( + [ + "--server-url", + "http://localhost:9999", + "--eval-deliverables-dir", + str(tmp_path), + "--reference-elos", + f"@{refs}", + "--stage", + "5", + "--task-prompts", + str(prompts), + ] + ) + assert rc == 0 + assert json.loads(capsys.readouterr().out)["num_stages"] == 0 + + def test_main_missing_eval_dir(self, tmp_path: Path, capsys) -> None: + _, refs = self._setup(tmp_path) + rc = main( + [ + "--server-url", + "http://x", + "--eval-deliverables-dir", + str(tmp_path / "nope"), + "--reference-elos", + f"@{refs}", + "--stage", + "5", + ] + ) + assert rc == 2 + assert "not found" in capsys.readouterr().err.lower() + + def test_main_missing_prompts(self, tmp_path: Path, capsys) -> None: + _, refs = self._setup(tmp_path) + rc = main( + [ + "--server-url", + "http://x", + "--eval-deliverables-dir", + str(tmp_path), + "--reference-elos", + f"@{refs}", + "--stage", + "5", + "--task-prompts", + str(tmp_path / "nope.jsonl"), + ] + ) + assert rc == 2 + assert "not found" in capsys.readouterr().err.lower() diff --git a/responses_api_agents/stirrup_agent/task_distribution.py b/responses_api_agents/stirrup_agent/task_distribution.py new file mode 100644 index 0000000000..4f661cc96b --- /dev/null +++ b/responses_api_agents/stirrup_agent/task_distribution.py @@ -0,0 +1,426 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Build a task distribution over one or more dataset columns. + +A *distribution* groups every task in a dataset by the value(s) of one or +more metadata columns (e.g. ``sector``, or ``sector`` + ``occupation``) and +records, for each group, the fraction of the dataset it covers and the list +of ``task_id``s that fall into it:: + + { + "Business, Finance & Operations": {"percentage": 0.05, "task_ids": ["a", "b"]}, + "Legal": {"percentage": 0.50, "task_ids": [...]}, + "Healthcare": {"percentage": 0.45, "task_ids": [...]} + } + +Datasets are the NeMo Gym Responses-API JSONL format: one task per line, with +the groupable columns living under ``responses_create_params.metadata``. + +The grouping logic is intentionally separated from the CLI so the resulting +distribution can later be reused to *sample* ``task_id``s (see +``sample_task_ids``). + +Usage:: + + # Full defaults: the prepared GDPVal dataset (220 tasks) + # (benchmarks/gdpval/data/gdpval_benchmark.jsonl) grouped by ``occupation``. + # Without --output the distribution is printed to stdout. + python -m responses_api_agents.stirrup_agent.task_distribution \ + --output occupation_distribution.json + + # --dataset defaults to the prepared GDPVal dataset when omitted. 
+ python -m responses_api_agents.stirrup_agent.task_distribution \ + --column sector \ + --output sector_distribution.json + + # Composite key over multiple columns, explicit dataset: + python -m responses_api_agents.stirrup_agent.task_distribution \ + --dataset data/gdpval.jsonl --column sector --column occupation \ + --output sector_occupation_distribution.json +""" + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence + + +# Sentinel used when a row is missing one of the requested columns. +MISSING_VALUE = "" + +# Separator joining multiple column values into a single composite key. +DEFAULT_KEY_SEPARATOR = " | " + +# Column grouped on when ``--column`` is not specified. +DEFAULT_COLUMN = "occupation" + +# Repo root: this file is responses_api_agents/stirrup_agent/task_distribution.py. +_REPO_ROOT = Path(__file__).resolve().parents[2] + +# Candidate GDPVal dataset locations, in priority order. The first that exists +# is used when ``--dataset`` is not given. The prepared benchmark JSONL (written +# by ``gym eval prepare --benchmark gdpval``) is preferred; the agent-local +# ``data/gdpval.jsonl`` (written by setup_scripts/gdpval.sh) is a fallback. +# The synthetic ``example.jsonl`` is intentionally *not* a default so the +# command never silently computes a distribution over a single fake task. +DEFAULT_DATASET_CANDIDATES = ( + _REPO_ROOT / "benchmarks" / "gdpval" / "data" / "gdpval_benchmark.jsonl", + Path(__file__).resolve().parent / "data" / "gdpval.jsonl", +) + + +def resolve_default_dataset( + candidates: Optional[Sequence[Path]] = None, +) -> Optional[Path]: + """Return the first existing default GDPVal dataset, or ``None``. + + Used when the caller does not pass an explicit ``--dataset``; prefers the + prepared benchmark JSONL and falls back to agent-local datasets. 
+ """ + if candidates is None: + candidates = DEFAULT_DATASET_CANDIDATES + for candidate in candidates: + if candidate.is_file(): + return candidate + return None + + +def _no_dataset_message() -> str: + """Actionable error shown when no dataset is specified and no default exists.""" + searched = "".join(f" - {c}\n" for c in DEFAULT_DATASET_CANDIDATES) + return ( + "No dataset specified and no default GDPVal dataset was found.\n" + f"\nSearched these default locations:\n{searched}" + "\nTo fix this, do one of the following:\n" + "\n 1. Prepare the GDPVal benchmark dataset (recommended). This downloads\n" + " the openai/gdpval dataset from HuggingFace and writes\n" + " benchmarks/gdpval/data/gdpval_benchmark.jsonl (220 tasks).\n" + "\n First activate the project virtualenv so the Gym CLI is on PATH\n" + " (the `gym`/`ng_*` commands live in .venv, not on your global PATH):\n" + "\n source .venv/bin/activate\n" + " export HF_TOKEN=\n" + "\n Then run the setup script (works on all installs):\n" + "\n bash responses_api_agents/stirrup_agent/setup_scripts/gdpval.sh\n" + "\n Or call a prepare CLI directly:\n" + "\n gym eval prepare --benchmark gdpval # newer installs\n" + " ng_prepare_benchmark '+config_paths=[benchmarks/gdpval/config.yaml]' # any install\n" + "\n 2. Pass an explicit dataset path with --dataset .\n" + "\nNote: the GDPVal dataset is gated on HuggingFace, so HF_TOKEN must be set\n" + "and your account must have access to https://huggingface.co/datasets/openai/gdpval.\n" + ) + + +def iter_dataset_rows(dataset_path: str | Path) -> Iterator[Dict[str, Any]]: + """Yield parsed JSON objects from a Responses-API JSONL dataset. + + Blank lines are skipped; malformed lines raise ``ValueError`` with the + 1-based line number so the offending row is easy to find. 
+ """ + path = Path(dataset_path) + with path.open("r", encoding="utf-8") as handle: + for line_no, line in enumerate(handle, start=1): + stripped = line.strip() + if not stripped: + continue + try: + yield json.loads(stripped) + except json.JSONDecodeError as exc: + raise ValueError(f"{path}:{line_no}: invalid JSON line: {exc}") from exc + + +def extract_metadata(row: Mapping[str, Any]) -> Dict[str, Any]: + """Return the ``responses_create_params.metadata`` dict for a row. + + Falls back to a top-level ``metadata`` key (and finally the row itself) + so the function also works on flatter dataset variants. + """ + params = row.get("responses_create_params") + if isinstance(params, Mapping): + metadata = params.get("metadata") + if isinstance(metadata, Mapping): + return dict(metadata) + metadata = row.get("metadata") + if isinstance(metadata, Mapping): + return dict(metadata) + return dict(row) + + +def compose_key( + metadata: Mapping[str, Any], + columns: Sequence[str], + *, + separator: str = DEFAULT_KEY_SEPARATOR, + missing_value: str = MISSING_VALUE, +) -> str: + """Build the distribution key for a row from one or more columns. + + Each column value is stringified; missing values become ``missing_value``. + Multiple columns are joined with ``separator`` into a composite key. + """ + parts: List[str] = [] + for column in columns: + value = metadata.get(column, None) + if value is None: + parts.append(missing_value) + else: + parts.append(str(value)) + return separator.join(parts) + + +def build_distribution( + rows: Iterable[Mapping[str, Any]], + columns: Sequence[str], + *, + task_id_column: str = "task_id", + separator: str = DEFAULT_KEY_SEPARATOR, + missing_value: str = MISSING_VALUE, + precision: Optional[int] = 6, +) -> Dict[str, Dict[str, Any]]: + """Compute the task distribution across ``columns``. 
+ + Returns a mapping ``key -> {"percentage": float, "task_ids": [...]}`` where + ``percentage`` is the fraction (0..1) of all tasks that share that key and + ``task_ids`` lists every matching task in first-seen order. The mapping is + ordered by descending ``percentage`` (ties broken by key) for readability. + + ``percentage`` values are rounded to ``precision`` decimal places when + ``precision`` is not ``None``. Note that rounding can make the percentages + sum to slightly more or less than 1.0; the unrounded fractions always sum + to 1.0. + """ + if not columns: + raise ValueError("At least one column is required to build a distribution.") + + grouped: Dict[str, List[str]] = {} + total = 0 + for index, row in enumerate(rows): + metadata = extract_metadata(row) + key = compose_key(metadata, columns, separator=separator, missing_value=missing_value) + task_id = metadata.get(task_id_column) + # Fall back to a positional id so every task is still counted/listed + # even when the dataset lacks an explicit task-id column. + task_id_str = str(task_id) if task_id is not None else f"{task_id_column}_index_{index}" + grouped.setdefault(key, []).append(task_id_str) + total += 1 + + distribution: Dict[str, Dict[str, Any]] = {} + for key, task_ids in grouped.items(): + fraction = (len(task_ids) / total) if total else 0.0 + percentage = round(fraction, precision) if precision is not None else fraction + distribution[key] = {"percentage": percentage, "task_ids": task_ids} + + # Sort by descending share, then by key for stable, readable output. 
def sample_task_ids(
    distribution: Mapping[str, Mapping[str, Any]],
    n: int,
    *,
    rng: Optional[random.Random] = None,
    replace: bool = False,
) -> List[str]:
    """Sample ``n`` ``task_id``s in proportion to a distribution's percentages.

    Each task id is drawn by first choosing a group weighted by its
    ``percentage`` and then choosing a task id uniformly within that group.
    Groups that have no task ids are never chosen, so they cannot consume a
    draw.

    Args:
        distribution: Mapping ``key -> {"percentage": float, "task_ids": [...]}``.
            A missing ``percentage`` counts as ``0.0`` and missing ``task_ids``
            as an empty list.
        n: Number of task ids requested. Values ``<= 0`` return ``[]``.
        rng: Random source; a fresh ``random.Random()`` is used when ``None``.
        replace: When ``True`` ids may repeat and exactly ``n`` ids are
            returned as long as some positively weighted group has task ids.
            When ``False`` (default) each listed id is returned at most once
            and the result is capped at the number of available ids. Ids are
            only unique in the result if no id is listed in more than one
            group.

    Returns:
        The sampled task ids in draw order. Empty when the distribution is
        empty, all weights sum to ``<= 0``, or no positively weighted group
        has task ids.
    """
    if n <= 0:
        return []
    if rng is None:
        rng = random.Random()

    keys = list(distribution.keys())
    weights = [float(distribution[key].get("percentage", 0.0)) for key in keys]
    if not keys or sum(weights) <= 0:
        return []

    weight_of = dict(zip(keys, weights))
    # Private copies: the without-replacement path pops from these lists and
    # must not mutate the caller's distribution.
    remaining: Dict[str, List[str]] = {key: list(distribution[key].get("task_ids", [])) for key in keys}

    if replace:
        # Drop empty groups up front. Otherwise each draw that lands on one is
        # wasted and the caller gets fewer than ``n`` ids.
        live_keys = [key for key in keys if remaining[key]]
        live_weights = [weight_of[key] for key in live_keys]
        if not live_keys or sum(live_weights) <= 0:
            return []
        sampled: List[str] = []
        for _ in range(n):
            (chosen_key,) = rng.choices(live_keys, weights=live_weights, k=1)
            sampled.append(rng.choice(remaining[chosen_key]))
        return sampled

    # Without replacement: a group drops out once its ids are exhausted.
    # ``rng.choices`` accepts unnormalised weights, so no explicit
    # renormalisation is needed.
    target = min(n, sum(len(ids) for ids in remaining.values()))
    sampled = []
    while len(sampled) < target:
        live_keys = [key for key in keys if remaining[key]]
        live_weights = [weight_of[key] for key in live_keys]
        if not live_keys or sum(live_weights) <= 0:
            # Only zero-weight groups still hold ids; they are never sampled.
            break
        (chosen_key,) = rng.choices(live_keys, weights=live_weights, k=1)
        bucket = remaining[chosen_key]
        sampled.append(bucket.pop(rng.randrange(len(bucket))))
    return sampled


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the ``task_distribution`` command-line parser.

    Every option is optional. ``--dataset`` and ``--column`` default to
    ``None`` so the caller can tell "not given" apart from an explicit value
    and apply its own defaults.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        prog="task_distribution",
        description=(
            "Build a JSON distribution of tasks across one or more dataset "
            "columns (e.g. sector, occupation) from a Responses-API JSONL dataset."
        ),
    )
    parser.add_argument(
        "--dataset",
        default=None,
        # The synthetic example.jsonl is deliberately not a default location,
        # so the help text must not advertise it as a fallback.
        help=(
            "Path to the input JSONL dataset (one task per line). If omitted, "
            "defaults to the prepared GDPVal dataset "
            "(benchmarks/gdpval/data/gdpval_benchmark.jsonl), falling back to "
            "the agent-local data/gdpval.jsonl."
        ),
    )
    parser.add_argument(
        "--column",
        dest="columns",
        action="append",
        default=None,
        metavar="COLUMN",
        help=(
            "Metadata column to group by. Repeat to group by a composite key "
            "(e.g. --column sector --column occupation). "
            f"Defaults to {DEFAULT_COLUMN!r} if not specified."
        ),
    )
    parser.add_argument(
        "--output",
        "-o",
        default=None,
        help="Path to write the distribution JSON. Defaults to stdout.",
    )
    parser.add_argument(
        "--task-id-column",
        default="task_id",
        help="Metadata column holding the task id (default: task_id).",
    )
    parser.add_argument(
        "--separator",
        default=DEFAULT_KEY_SEPARATOR,
        help=f"Separator joining multiple column values into one key (default: {DEFAULT_KEY_SEPARATOR!r}).",
    )
    parser.add_argument(
        "--missing-value",
        default=MISSING_VALUE,
        help=f"Placeholder for rows missing a column (default: {MISSING_VALUE!r}).",
    )
    parser.add_argument(
        "--precision",
        type=int,
        default=6,
        help="Decimal places to round percentages to; use -1 for no rounding (default: 6).",
    )
    parser.add_argument(
        "--indent",
        type=int,
        default=2,
        help="Indentation for the output JSON; use -1 for compact output (default: 2).",
    )
    return parser
out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(payload + "\n", encoding="utf-8") + total_tasks = sum(len(entry["task_ids"]) for entry in distribution.values()) + print( + f"Wrote distribution over {columns} ({len(distribution)} groups, {total_tasks} tasks) to {out_path}", + file=sys.stderr, + ) + else: + print(payload) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/responses_api_agents/stirrup_agent/tests/test_task_distribution.py b/responses_api_agents/stirrup_agent/tests/test_task_distribution.py new file mode 100644 index 0000000000..6250e5194e --- /dev/null +++ b/responses_api_agents/stirrup_agent/tests/test_task_distribution.py @@ -0,0 +1,282 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import random +from pathlib import Path + +import pytest + +from responses_api_agents.stirrup_agent import task_distribution as td +from responses_api_agents.stirrup_agent.task_distribution import ( + MISSING_VALUE, + build_distribution, + build_distribution_from_dataset, + compose_key, + extract_metadata, + iter_dataset_rows, + main, + resolve_default_dataset, + sample_task_ids, +) + + +def _row(task_id: str, **metadata) -> dict: + return {"responses_create_params": {"input": "", "metadata": {"task_id": task_id, **metadata}}} + + +def _write_jsonl(path: Path, rows) -> Path: + path.write_text("\n".join(json.dumps(r) for r in rows) + "\n", encoding="utf-8") + return path + + +class TestExtractMetadata: + def test_responses_create_params_metadata(self) -> None: + row = _row("t1", sector="Legal") + assert extract_metadata(row) == {"task_id": "t1", "sector": "Legal"} + + def test_top_level_metadata_fallback(self) -> None: + row = {"metadata": {"task_id": "t1", "sector": "Legal"}} + assert extract_metadata(row) == {"task_id": "t1", "sector": "Legal"} + + def test_row_itself_fallback(self) -> None: + row = {"task_id": "t1", "sector": "Legal"} + assert extract_metadata(row) == {"task_id": "t1", "sector": "Legal"} + + def test_non_mapping_params_falls_through(self) -> None: + row = {"responses_create_params": "oops", "metadata": {"task_id": "t1"}} + assert extract_metadata(row) == {"task_id": "t1"} + + +class TestComposeKey: + def test_single_column(self) -> None: + assert compose_key({"sector": "Legal"}, ["sector"]) == "Legal" + + def test_composite_key(self) -> None: + meta = {"sector": "Legal", "occupation": "Lawyer"} + assert compose_key(meta, ["sector", "occupation"]) == "Legal | Lawyer" + + def test_missing_value_placeholder(self) -> None: + assert compose_key({}, ["sector"]) == MISSING_VALUE + + def test_custom_separator(self) -> None: + meta = {"a": "x", "b": "y"} + assert compose_key(meta, ["a", "b"], separator="::") == "x::y" + + def 
test_non_string_value_is_stringified(self) -> None: + assert compose_key({"n": 5}, ["n"]) == "5" + + +class TestBuildDistribution: + def test_percentages_and_task_ids(self) -> None: + rows = [ + _row("a", sector="Legal"), + _row("b", sector="Legal"), + _row("c", sector="Healthcare"), + _row("d", sector="Finance"), + ] + dist = build_distribution(rows, ["sector"]) + assert dist["Legal"]["percentage"] == 0.5 + assert dist["Legal"]["task_ids"] == ["a", "b"] + assert dist["Healthcare"]["percentage"] == 0.25 + assert dist["Finance"]["task_ids"] == ["d"] + + def test_ordering_is_descending_by_share(self) -> None: + rows = [ + _row("a", sector="Legal"), + _row("b", sector="Legal"), + _row("c", sector="Healthcare"), + ] + assert list(build_distribution(rows, ["sector"]).keys()) == ["Legal", "Healthcare"] + + def test_percentages_sum_to_one_unrounded(self) -> None: + rows = [_row(str(i), sector=s) for i, s in enumerate(["a", "a", "b", "c", "c", "c", "d"])] + dist = build_distribution(rows, ["sector"], precision=None) + assert pytest.approx(sum(e["percentage"] for e in dist.values())) == 1.0 + + def test_composite_columns(self) -> None: + rows = [ + _row("a", sector="Legal", occupation="Lawyer"), + _row("b", sector="Legal", occupation="Paralegal"), + ] + dist = build_distribution(rows, ["sector", "occupation"]) + assert set(dist.keys()) == {"Legal | Lawyer", "Legal | Paralegal"} + + def test_empty_rows_yields_empty(self) -> None: + assert build_distribution([], ["sector"]) == {} + + def test_missing_column_grouped_under_placeholder(self) -> None: + rows = [_row("a"), _row("b", sector="Legal")] + dist = build_distribution(rows, ["sector"]) + assert MISSING_VALUE in dist + assert dist[MISSING_VALUE]["task_ids"] == ["a"] + + def test_missing_task_id_uses_positional_fallback(self) -> None: + rows = [{"responses_create_params": {"metadata": {"sector": "Legal"}}}] + dist = build_distribution(rows, ["sector"]) + assert dist["Legal"]["task_ids"] == ["task_id_index_0"] + + def 
test_requires_columns(self) -> None: + with pytest.raises(ValueError): + build_distribution([_row("a", sector="Legal")], []) + + def test_precision_rounding(self) -> None: + rows = [_row(str(i), sector="a" if i == 0 else "b") for i in range(3)] + dist = build_distribution(rows, ["sector"], precision=2) + assert dist["b"]["percentage"] == 0.67 + + +class TestIterAndDatasetWrapper: + def test_iter_skips_blank_lines(self, tmp_path: Path) -> None: + path = tmp_path / "d.jsonl" + path.write_text(json.dumps(_row("a", sector="Legal")) + "\n\n", encoding="utf-8") + assert len(list(iter_dataset_rows(path))) == 1 + + def test_iter_raises_on_bad_json(self, tmp_path: Path) -> None: + path = tmp_path / "d.jsonl" + path.write_text("{not json}\n", encoding="utf-8") + with pytest.raises(ValueError, match="invalid JSON"): + list(iter_dataset_rows(path)) + + def test_build_from_dataset(self, tmp_path: Path) -> None: + path = _write_jsonl(tmp_path / "d.jsonl", [_row("a", sector="Legal"), _row("b", sector="Legal")]) + dist = build_distribution_from_dataset(path, ["sector"]) + assert dist["Legal"]["percentage"] == 1.0 + + +class TestSampleTaskIds: + def _dist(self): + return { + "Legal": {"percentage": 0.5, "task_ids": ["a", "b"]}, + "Healthcare": {"percentage": 0.5, "task_ids": ["c", "d"]}, + } + + def test_zero_or_negative_returns_empty(self) -> None: + assert sample_task_ids(self._dist(), 0) == [] + assert sample_task_ids(self._dist(), -3) == [] + + def test_without_replacement_no_duplicates(self) -> None: + rng = random.Random(0) + sampled = sample_task_ids(self._dist(), 3, rng=rng) + assert len(sampled) == 3 + assert len(set(sampled)) == 3 + + def test_without_replacement_capped_at_total(self) -> None: + sampled = sample_task_ids(self._dist(), 100, rng=random.Random(1)) + assert sorted(sampled) == ["a", "b", "c", "d"] + + def test_with_replacement_allows_more_than_total(self) -> None: + sampled = sample_task_ids(self._dist(), 10, rng=random.Random(2), replace=True) + assert 
len(sampled) == 10 + + def test_empty_distribution_returns_empty(self) -> None: + assert sample_task_ids({}, 5) == [] + + def test_zero_weight_distribution_returns_empty(self) -> None: + dist = {"x": {"percentage": 0.0, "task_ids": ["a"]}} + assert sample_task_ids(dist, 5) == [] + assert sample_task_ids(dist, 5, replace=True) == [] + + def test_with_replacement_skips_empty_groups(self) -> None: + dist = {"x": {"percentage": 1.0, "task_ids": []}} + assert sample_task_ids(dist, 3, rng=random.Random(3), replace=True) == [] + + +class TestResolveDefaultDataset: + def test_returns_first_existing(self, tmp_path: Path) -> None: + missing = tmp_path / "missing.jsonl" + present = _write_jsonl(tmp_path / "present.jsonl", [_row("a", sector="Legal")]) + assert resolve_default_dataset([missing, present]) == present + + def test_priority_order(self, tmp_path: Path) -> None: + first = _write_jsonl(tmp_path / "first.jsonl", [_row("a", sector="Legal")]) + second = _write_jsonl(tmp_path / "second.jsonl", [_row("b", sector="Legal")]) + assert resolve_default_dataset([first, second]) == first + + def test_returns_none_when_nothing_exists(self, tmp_path: Path) -> None: + assert resolve_default_dataset([tmp_path / "a.jsonl", tmp_path / "b.jsonl"]) is None + + +class TestMain: + def test_uses_default_dataset_when_omitted(self, tmp_path: Path, capsys, monkeypatch) -> None: + default_ds = _write_jsonl(tmp_path / "gdpval.jsonl", [_row("a", sector="Legal")]) + monkeypatch.setattr(td, "DEFAULT_DATASET_CANDIDATES", (tmp_path / "missing.jsonl", default_ds)) + rc = main(["--column", "sector"]) + assert rc == 0 + captured = capsys.readouterr() + assert str(default_ds) in captured.err + assert json.loads(captured.out)["Legal"]["percentage"] == 1.0 + + def test_errors_when_no_default_and_none_specified(self, tmp_path: Path, capsys, monkeypatch) -> None: + monkeypatch.setattr(td, "DEFAULT_DATASET_CANDIDATES", (tmp_path / "missing.jsonl",)) + rc = main(["--column", "sector"]) + assert rc == 2 + err = 
capsys.readouterr().err + assert "no default gdpval dataset was found" in err.lower() + assert "gym eval prepare --benchmark gdpval" in err + assert "--dataset" in err + + def test_defaults_to_occupation_column(self, tmp_path: Path, capsys) -> None: + dataset = _write_jsonl( + tmp_path / "d.jsonl", + [_row("a", occupation="Lawyer"), _row("b", occupation="Lawyer"), _row("c", occupation="Nurse")], + ) + rc = main(["--dataset", str(dataset)]) + assert rc == 0 + captured = capsys.readouterr() + assert "defaulting to 'occupation'" in captured.err + data = json.loads(captured.out) + assert data["Lawyer"]["task_ids"] == ["a", "b"] + assert data["Nurse"]["percentage"] == pytest.approx(1 / 3) + + def test_errors_when_specified_dataset_missing(self, tmp_path: Path, capsys) -> None: + rc = main(["--dataset", str(tmp_path / "nope.jsonl"), "--column", "sector"]) + assert rc == 2 + assert "Dataset not found" in capsys.readouterr().err + + def test_writes_output_file(self, tmp_path: Path, capsys) -> None: + dataset = _write_jsonl( + tmp_path / "d.jsonl", + [_row("a", sector="Legal"), _row("b", sector="Legal"), _row("c", sector="Healthcare")], + ) + out = tmp_path / "dist.json" + rc = main(["--dataset", str(dataset), "--column", "sector", "--output", str(out)]) + assert rc == 0 + data = json.loads(out.read_text()) + assert data["Legal"]["task_ids"] == ["a", "b"] + assert "3 tasks" in capsys.readouterr().err + + def test_stdout_when_no_output(self, tmp_path: Path, capsys) -> None: + dataset = _write_jsonl(tmp_path / "d.jsonl", [_row("a", sector="Legal")]) + rc = main(["--dataset", str(dataset), "--column", "sector"]) + assert rc == 0 + assert json.loads(capsys.readouterr().out)["Legal"]["percentage"] == 1.0 + + def test_no_rounding_and_compact(self, tmp_path: Path, capsys) -> None: + dataset = _write_jsonl( + tmp_path / "d.jsonl", [_row("a", sector="x"), _row("b", sector="y"), _row("c", sector="y")] + ) + rc = main(["--dataset", str(dataset), "--column", "sector", "--precision", 
"-1", "--indent", "-1"]) + assert rc == 0 + out = capsys.readouterr().out + assert "\n " not in out # compact (no indentation) + assert json.loads(out)["y"]["percentage"] == pytest.approx(2 / 3) + + def test_composite_columns_cli(self, tmp_path: Path, capsys) -> None: + dataset = _write_jsonl( + tmp_path / "d.jsonl", + [_row("a", sector="Legal", occupation="Lawyer")], + ) + rc = main(["--dataset", str(dataset), "--column", "sector", "--column", "occupation"]) + assert rc == 0 + assert "Legal | Lawyer" in json.loads(capsys.readouterr().out)