Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions nemo_gym/rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D
results,
result_strs,
) = self._load_from_cache(config)
persisted_rows = list(rows)
persisted_results = list(results)
else:
if config.resume_from_cache:
if not output_fpath.exists():
Expand All @@ -461,6 +463,8 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D
rows: List[Dict] = []
results: List[Dict] = []
result_strs: List[List[str]] = []
persisted_rows: List[Dict] = []
persisted_results: List[Dict] = []

input_rows = self._preprocess_rows_from_config(config)
# Returned rows are sorted by (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME])
Expand Down Expand Up @@ -513,6 +517,8 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D
# Success → main jsonl.
results_file.write(serialized + b"\n")
results_file.flush()
persisted_rows.append(row)
persisted_results.append(result)

counts_left[row[AGENT_REF_KEY_NAME]["name"]] -= 1
if counts_left[row[AGENT_REF_KEY_NAME]["name"]] <= 0:
Expand Down Expand Up @@ -540,8 +546,12 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D
print("Sorting results to ensure consistent ordering")
rows.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]))
results.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]))
persisted_rows.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]))
persisted_results.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]))

# Compute and write aggregate metrics via /aggregate_metrics on each agent server
# Compute and write aggregate metrics via /aggregate_metrics using only the
# rows written to the main rollouts jsonl so runtime aggregation matches
# `gym eval aggregate`.
if config.disable_aggregation:
print(
"Skipping aggregate-metrics computation because disable_aggregation=True. "
Expand All @@ -550,7 +560,9 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D
aggregate_metrics_fpath = None
else:
print("Computing aggregate metrics")
aggregate_metrics_fpath = await self._call_aggregate_metrics(results, rows, output_fpath)
aggregate_metrics_fpath = await self._call_aggregate_metrics(
persisted_results, persisted_rows, output_fpath
)

print(f"""Finished rollout collection! View results at:
Fully materialized inputs: {config.materialized_jsonl_fpath}
Expand Down
66 changes: 66 additions & 0 deletions tests/unit_tests/test_rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
from nemo_gym.reward_profile import compute_aggregate_metrics
from nemo_gym.rollout_collection import (
_DEFAULT_MAX_ROLLOUT_ATTEMPTS,
NG_FAILURE_CLASS_KEY,
NG_NO_PERSIST_KEY,
RolloutAggregationConfig,
RolloutAggregationHelper,
RolloutCollectionConfig,
RolloutCollectionHelper,
_expand_input_glob,
_failures_path_for,
_get_max_rollout_attempts,
_rollout_request_debug_summary,
)
Expand Down Expand Up @@ -693,6 +696,69 @@ async def _call_aggregate_metrics(self, results, rows, output_fpath):

assert expected_results == actual_returned_results

async def test_run_from_config_aggregate_metrics_excludes_non_persisted_rows(self, tmp_path: Path) -> None:
    """Check that aggregation only receives rollouts written to the main jsonl.

    Three inputs are pushed through a stubbed helper: ``x == 0`` produces a
    plain result, ``x == 1`` a result tagged with ``NG_FAILURE_CLASS_KEY`` and
    ``x == 2`` a result tagged with ``NG_NO_PERSIST_KEY``. The test expects all
    three back from ``run_from_config``, only the plain one in the arguments of
    ``_call_aggregate_metrics`` and in the main output file, and the tagged
    failure in the failures file.
    """

    def read_jsonl(fpath: Path) -> list[dict]:
        # One JSON document per line.
        with fpath.open() as f:
            return [json.loads(line) for line in f]

    def cases_of(results) -> list:
        return [result["case"] for result in results]

    input_jsonl_fpath = tmp_path / "input.jsonl"
    output_jsonl_fpath = tmp_path / "output.jsonl"

    input_lines = []
    for i in range(3):
        sample = {"responses_create_params": {"input": []}, "agent_ref": {"name": "my agent name"}, "x": i}
        input_lines.append(json.dumps(sample) + "\n")
    input_jsonl_fpath.write_text("".join(input_lines))

    config = RolloutCollectionConfig(
        input_jsonl_fpath=str(input_jsonl_fpath),
        output_jsonl_fpath=str(output_jsonl_fpath),
        limit=3,
        num_repeats=1,
    )

    # Extra marker attached to the stubbed result for a given example["x"];
    # x == 0 gets no marker at all.
    markers_by_x = {
        1: {NG_FAILURE_CLASS_KEY: "verify_failed"},
        2: {NG_NO_PERSIST_KEY: True},
    }

    # Filled in by the stubbed _call_aggregate_metrics below.
    seen_by_aggregation: dict[str, list[dict]] = {}

    class TestRolloutCollectionHelper(RolloutCollectionHelper):
        def run_examples(
            self,
            examples: list[dict],
            *args,
            **kwargs,
        ):
            # Return one already-resolved future per example instead of
            # running anything.
            def resolved(example: dict):
                x = example["x"]
                done = Future()
                result = {
                    "response": {"usage": {"abc usage": x + 1}},
                    "case": f"case-{x}",
                    **markers_by_x.get(x, {}),
                }
                done.set_result((example, result))
                return done

            return [resolved(example) for example in examples]

        async def _call_aggregate_metrics(self, results, rows, output_fpath):
            # Record what aggregation was handed, then write a placeholder
            # metrics file next to the output file.
            seen_by_aggregation["results"] = results
            seen_by_aggregation["rows"] = rows
            metrics_stem = output_fpath.stem + "_aggregate_metrics"
            metrics_fpath = output_fpath.with_stem(metrics_stem).with_suffix(".json")
            metrics_fpath.write_text("[]")
            return metrics_fpath

    helper = TestRolloutCollectionHelper()
    actual_returned_results = await helper.run_from_config(config)

    # Every rollout is returned to the caller ...
    assert cases_of(actual_returned_results) == ["case-0", "case-1", "case-2"]
    # ... but aggregation only sees the one without a marker.
    assert cases_of(seen_by_aggregation["results"]) == ["case-0"]
    assert [row["x"] for row in seen_by_aggregation["rows"]] == [0]

    actual_written_results = read_jsonl(output_jsonl_fpath)
    assert cases_of(actual_written_results) == ["case-0"]

    actual_failure_results = read_jsonl(_failures_path_for(output_jsonl_fpath))
    assert cases_of(actual_failure_results) == ["case-1"]
    assert actual_failure_results[0][NG_FAILURE_CLASS_KEY] == "verify_failed"

def test_load_from_cache(self, tmp_path: Path) -> None:
input_jsonl_fpath = tmp_path / "input.jsonl"
materialized_inputs_jsonl_fpath = tmp_path / "output_materialized_inputs.jsonl"
Expand Down