diff --git a/nemo_gym/rollout_collection.py b/nemo_gym/rollout_collection.py index 617e50a54..c24d9aa18 100644 --- a/nemo_gym/rollout_collection.py +++ b/nemo_gym/rollout_collection.py @@ -447,6 +447,8 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D results, result_strs, ) = self._load_from_cache(config) + persisted_rows = list(rows) + persisted_results = list(results) else: if config.resume_from_cache: if not output_fpath.exists(): @@ -461,6 +463,8 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D rows: List[Dict] = [] results: List[Dict] = [] result_strs: List[List[str]] = [] + persisted_rows: List[Dict] = [] + persisted_results: List[Dict] = [] input_rows = self._preprocess_rows_from_config(config) # Returned rows are sorted by (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME]) @@ -513,6 +517,8 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D # Success → main jsonl. results_file.write(serialized + b"\n") results_file.flush() + persisted_rows.append(row) + persisted_results.append(result) counts_left[row[AGENT_REF_KEY_NAME]["name"]] -= 1 if counts_left[row[AGENT_REF_KEY_NAME]["name"]] <= 0: @@ -540,8 +546,12 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D print("Sorting results to ensure consistent ordering") rows.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME])) results.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME])) + persisted_rows.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME])) + persisted_results.sort(key=lambda r: (r[TASK_INDEX_KEY_NAME], r[ROLLOUT_INDEX_KEY_NAME])) - # Compute and write aggregate metrics via /aggregate_metrics on each agent server + # Compute and write aggregate metrics via /aggregate_metrics using only the + # rows written to the main rollouts jsonl so runtime aggregation matches + # `gym eval aggregate`. if config.disable_aggregation: print( "Skipping aggregate-metrics computation because disable_aggregation=True. " @@ -550,7 +560,9 @@ async def run_from_config(self, config: RolloutCollectionConfig) -> Tuple[List[D aggregate_metrics_fpath = None else: print("Computing aggregate metrics") - aggregate_metrics_fpath = await self._call_aggregate_metrics(results, rows, output_fpath) + aggregate_metrics_fpath = await self._call_aggregate_metrics( + persisted_results, persisted_rows, output_fpath + ) print(f"""Finished rollout collection! View results at: Fully materialized inputs: {config.materialized_jsonl_fpath} diff --git a/tests/unit_tests/test_rollout_collection.py b/tests/unit_tests/test_rollout_collection.py index c6db6da6a..78d83c5e2 100644 --- a/tests/unit_tests/test_rollout_collection.py +++ b/tests/unit_tests/test_rollout_collection.py @@ -29,11 +29,14 @@ from nemo_gym.reward_profile import compute_aggregate_metrics from nemo_gym.rollout_collection import ( _DEFAULT_MAX_ROLLOUT_ATTEMPTS, + NG_FAILURE_CLASS_KEY, + NG_NO_PERSIST_KEY, RolloutAggregationConfig, RolloutAggregationHelper, RolloutCollectionConfig, RolloutCollectionHelper, _expand_input_glob, + _failures_path_for, _get_max_rollout_attempts, _rollout_request_debug_summary, ) @@ -693,6 +696,69 @@ async def _call_aggregate_metrics(self, results, rows, output_fpath): assert expected_results == actual_returned_results + async def test_run_from_config_aggregate_metrics_excludes_non_persisted_rows(self, tmp_path: Path) -> None: + input_jsonl_fpath = tmp_path / "input.jsonl" + samples = [ + json.dumps({"responses_create_params": {"input": []}, "agent_ref": {"name": "my agent name"}, "x": i}) + for i in range(3) + ] + input_jsonl_fpath.write_text("\n".join(samples) + "\n") + output_jsonl_fpath = tmp_path / "output.jsonl" + + config = RolloutCollectionConfig( + input_jsonl_fpath=str(input_jsonl_fpath), + output_jsonl_fpath=str(output_jsonl_fpath), + limit=3, + num_repeats=1, + ) + + captured: dict[str, list[dict]] = {} + + class TestRolloutCollectionHelper(RolloutCollectionHelper): + def run_examples( + self, + examples: list[dict], + *args, + **kwargs, + ): + futures = [] + for example in examples: + future = Future() + result = { + "response": {"usage": {"abc usage": example["x"] + 1}}, + "case": f"case-{example['x']}", + } + if example["x"] == 1: + result[NG_FAILURE_CLASS_KEY] = "verify_failed" + elif example["x"] == 2: + result[NG_NO_PERSIST_KEY] = True + future.set_result((example, result)) + futures.append(future) + return futures + + async def _call_aggregate_metrics(self, results, rows, output_fpath): + captured["results"] = results + captured["rows"] = rows + metrics_fpath = output_fpath.with_stem(output_fpath.stem + "_aggregate_metrics").with_suffix(".json") + metrics_fpath.write_text("[]") + return metrics_fpath + + actual_returned_results = await TestRolloutCollectionHelper().run_from_config(config) + + assert [result["case"] for result in actual_returned_results] == ["case-0", "case-1", "case-2"] + assert [result["case"] for result in captured["results"]] == ["case-0"] + assert [row["x"] for row in captured["rows"]] == [0] + + with output_jsonl_fpath.open() as f: + actual_written_results = [json.loads(line) for line in f] + assert [result["case"] for result in actual_written_results] == ["case-0"] + + failures_fpath = _failures_path_for(output_jsonl_fpath) + with failures_fpath.open() as f: + actual_failure_results = [json.loads(line) for line in f] + assert [result["case"] for result in actual_failure_results] == ["case-1"] + assert actual_failure_results[0][NG_FAILURE_CLASS_KEY] == "verify_failed" + def test_load_from_cache(self, tmp_path: Path) -> None: input_jsonl_fpath = tmp_path / "input.jsonl" materialized_inputs_jsonl_fpath = tmp_path / "output_materialized_inputs.jsonl"