From 04fcc892dbe01307d7d2ffbc9b5e9bdb21ac7c7b Mon Sep 17 00:00:00 2001
From: Jiang020609 <190608333+Jiang020609@users.noreply.github.com>
Date: Sat, 30 May 2026 20:40:20 +0800
Subject: [PATCH] fix(logging): partition raw rewards for correct samples

---
Review notes (placed after the "---" separator, so not part of the commit
message):

  * slime/utils/data.py: process_rollout_data now replaces "raw_reward" with
    the entries selected by `partition`, the same way the neighbouring
    context line already slices "total_lengths". When
    getattr(args, "log_passrate", False) is truthy, the unsliced list is
    also stored under "global_raw_reward".
  * slime/backends/megatron_utils/data.py: log_passrate drops the loop over
    rollout_data.items() and instead reads "global_raw_reward", falling back
    to "raw_reward"; compute_pass_rate is only called when that lookup is
    not None. log_rollout_data gains "global_raw_reward" in the key list
    shown in the first hunk (the purpose of that list is outside this view).
  * tests/test_process_rollout_data.py: new file with two tests covering the
    log_passrate=False and log_passrate=True paths of process_rollout_data,
    with ray.get monkeypatched to return its argument.
  * NOTE(review): this patch had lost its line breaks and indentation; the
    leading whitespace below was reconstructed from syntax and hunk line
    counts. Verify context-line indentation against the target tree before
    applying.

 slime/backends/megatron_utils/data.py |  9 ++--
 slime/utils/data.py                   |  5 +++
 tests/test_process_rollout_data.py    | 62 +++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_process_rollout_data.py

diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py
index 00f319928f..2b04472139 100644
--- a/slime/backends/megatron_utils/data.py
+++ b/slime/backends/megatron_utils/data.py
@@ -287,6 +287,7 @@ def log_rollout_data(
                 "rollout_top_p_token_ids",
                 "rollout_top_p_token_offsets",
                 "rollout_routed_experts",
+                "global_raw_reward",
                 "global_batch_sizes",
                 "num_microbatches",
                 "micro_batch_indices",
@@ -480,12 +481,10 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -
     """
     if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():
         log_dict = {}
-        for key, val in rollout_data.items():
-            if key != "raw_reward":
-                continue
-
+        raw_rewards = rollout_data.get("global_raw_reward", rollout_data.get("raw_reward"))
+        if raw_rewards is not None:
             log_dict |= compute_pass_rate(
-                flat_rewards=val,
+                flat_rewards=raw_rewards,
                 group_size=args.n_samples_per_prompt,
                 num_groups=args.rollout_batch_size,
             )
diff --git a/slime/utils/data.py b/slime/utils/data.py
index 0d26b6dda5..a35fd96697 100644
--- a/slime/utils/data.py
+++ b/slime/utils/data.py
@@ -299,5 +299,10 @@ def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
     # save the seqlen of the whole rollout batch
     Timer().seq_lens = total_lengths
     rollout_data["total_lengths"] = [total_lengths[i] for i in partition]
+    if "raw_reward" in rollout_data:
+        raw_reward = rollout_data["raw_reward"]
+        if getattr(args, "log_passrate", False):
+            rollout_data["global_raw_reward"] = raw_reward
+        rollout_data["raw_reward"] = [raw_reward[i] for i in partition]
 
     return rollout_data
diff --git a/tests/test_process_rollout_data.py b/tests/test_process_rollout_data.py
new file mode 100644
index 0000000000..eeb61ce1d7
--- /dev/null
+++ b/tests/test_process_rollout_data.py
@@ -0,0 +1,62 @@
+from types import SimpleNamespace
+
+import pytest
+
+from slime.utils import data as data_utils
+
+NUM_GPUS = 0
+
+
+def test_process_rollout_data_partitions_raw_reward(monkeypatch):
+    monkeypatch.setattr(data_utils.ray, "get", lambda value: value)
+    rollout_data_ref = [
+        SimpleNamespace(
+            inner={
+                "partition": [0, 2],
+                "total_lengths": [8, 16, 24],
+                "raw_reward": [0.0, 1.0, 0.5],
+            }
+        ),
+        SimpleNamespace(inner={}),
+    ]
+
+    rollout_data = data_utils.process_rollout_data(
+        SimpleNamespace(log_passrate=False),
+        rollout_data_ref,
+        dp_rank=0,
+        dp_size=2,
+    )
+
+    assert rollout_data["total_lengths"] == [8, 24]
+    assert rollout_data["raw_reward"] == [0.0, 0.5]
+    assert "global_raw_reward" not in rollout_data
+    assert "partition" not in rollout_data
+
+
+def test_process_rollout_data_keeps_global_raw_reward_for_passrate(monkeypatch):
+    monkeypatch.setattr(data_utils.ray, "get", lambda value: value)
+    global_raw_reward = [0.0, 1.0, 0.0, 1.0]
+    rollout_data_ref = [
+        SimpleNamespace(
+            inner={
+                "partition": [1, 3],
+                "total_lengths": [8, 16, 24, 32],
+                "raw_reward": global_raw_reward,
+            }
+        ),
+        SimpleNamespace(inner={}),
+    ]
+
+    rollout_data = data_utils.process_rollout_data(
+        SimpleNamespace(log_passrate=True),
+        rollout_data_ref,
+        dp_rank=0,
+        dp_size=2,
+    )
+
+    assert rollout_data["raw_reward"] == [1.0, 1.0]
+    assert rollout_data["global_raw_reward"] is global_raw_reward
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__]))