Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def log_rollout_data(
"rollout_top_p_token_ids",
"rollout_top_p_token_offsets",
"rollout_routed_experts",
"global_raw_reward",
"global_batch_sizes",
"num_microbatches",
"micro_batch_indices",
Expand Down Expand Up @@ -480,12 +481,10 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -
"""
if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():
log_dict = {}
for key, val in rollout_data.items():
if key != "raw_reward":
continue

raw_rewards = rollout_data.get("global_raw_reward", rollout_data.get("raw_reward"))
if raw_rewards is not None:
log_dict |= compute_pass_rate(
flat_rewards=val,
flat_rewards=raw_rewards,
group_size=args.n_samples_per_prompt,
num_groups=args.rollout_batch_size,
)
Expand Down
5 changes: 5 additions & 0 deletions slime/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,5 +299,10 @@ def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
# save the seqlen of the whole rollout batch
Timer().seq_lens = total_lengths
rollout_data["total_lengths"] = [total_lengths[i] for i in partition]
if "raw_reward" in rollout_data:
raw_reward = rollout_data["raw_reward"]
if getattr(args, "log_passrate", False):
rollout_data["global_raw_reward"] = raw_reward
rollout_data["raw_reward"] = [raw_reward[i] for i in partition]

return rollout_data
62 changes: 62 additions & 0 deletions tests/test_process_rollout_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from types import SimpleNamespace

import pytest

from slime.utils import data as data_utils

# NOTE(review): nothing in this file reads NUM_GPUS; presumably the test harness/CI
# inspects it to learn how many GPUs this module needs (none) — confirm against the runner.
NUM_GPUS = 0


def test_process_rollout_data_partitions_raw_reward(monkeypatch):
    """Check that ``raw_reward`` is sliced by ``partition`` when pass-rate logging is off.

    With ``log_passrate=False`` the returned batch is expected to hold only the
    ``total_lengths`` and ``raw_reward`` entries at the ``partition`` indices, and to
    contain neither a ``global_raw_reward`` key nor the ``partition`` key itself.
    """
    # Replace ray.get with an identity function so the fake references below are
    # handed back unchanged instead of being resolved through Ray.
    monkeypatch.setattr(data_utils.ray, "get", lambda value: value)

    populated_payload = {
        "partition": [0, 2],
        "total_lengths": [8, 16, 24],
        "raw_reward": [0.0, 1.0, 0.5],
    }
    # Second entry carries an empty payload — presumably the slot for the other
    # data-parallel rank (dp_size=2); verify against process_rollout_data.
    rollout_data_ref = [
        SimpleNamespace(inner=populated_payload),
        SimpleNamespace(inner={}),
    ]
    args = SimpleNamespace(log_passrate=False)

    result = data_utils.process_rollout_data(args, rollout_data_ref, dp_rank=0, dp_size=2)

    # Indices 0 and 2 of each per-sample list survive the partitioning.
    assert result["total_lengths"] == [8, 24]
    assert result["raw_reward"] == [0.0, 0.5]
    assert "global_raw_reward" not in result
    assert "partition" not in result


def test_process_rollout_data_keeps_global_raw_reward_for_passrate(monkeypatch):
    """Check that the unpartitioned rewards are kept when ``log_passrate`` is enabled.

    With ``log_passrate=True`` the returned batch is expected to expose the original
    reward list under ``global_raw_reward`` (the very same object, not a copy), while
    ``raw_reward`` holds only the entries at the ``partition`` indices.
    """
    # Replace ray.get with an identity function so the fake references below are
    # handed back unchanged instead of being resolved through Ray.
    monkeypatch.setattr(data_utils.ray, "get", lambda value: value)

    global_raw_reward = [0.0, 1.0, 0.0, 1.0]
    populated_payload = {
        "partition": [1, 3],
        "total_lengths": [8, 16, 24, 32],
        "raw_reward": global_raw_reward,
    }
    # Second entry carries an empty payload — presumably the slot for the other
    # data-parallel rank (dp_size=2); verify against process_rollout_data.
    rollout_data_ref = [
        SimpleNamespace(inner=populated_payload),
        SimpleNamespace(inner={}),
    ]
    args = SimpleNamespace(log_passrate=True)

    result = data_utils.process_rollout_data(args, rollout_data_ref, dp_rank=0, dp_size=2)

    # Indices 1 and 3 of the reward list are both 1.0.
    assert result["raw_reward"] == [1.0, 1.0]
    # Identity (not equality) check: the full list must be passed through as-is.
    assert result["global_raw_reward"] is global_raw_reward


if __name__ == "__main__":
    # Allow running this file directly: execute pytest on this module only and
    # propagate its return value as the process exit status.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)
Loading