Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data
and not self.args.keep_old_actor
and not self.args.use_opd
and not self.args.use_routing_replay
and self.args.advantage_estimator != "gspo"
and self.args.is_level != "sequence"
)
if (
not self.args.use_rollout_logprobs or self.args.get_mismatch_metrics
Expand Down
12 changes: 6 additions & 6 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)

This function extracts rewards, log-probs, values, and masks from
`rollout_data`, computes KL divergences, then applies the chosen advantage
estimator. Supported methods: "grpo", "gspo", "cispo", "ppo",
estimator. Supported methods: "grpo", "ppo",
"reinforce_plus_plus", and "reinforce_plus_plus_baseline". When
`args.normalize_advantages` is True, advantages are whitened across the
data-parallel group using masked statistics.
Expand Down Expand Up @@ -713,7 +713,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
custom_adv_fn(args, rollout_data)
advantages, returns = rollout_data["advantages"], rollout_data["returns"]

elif args.advantage_estimator in ["grpo", "gspo", "cispo"]:
elif args.advantage_estimator == "grpo":
rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device)
returns = get_grpo_returns(rewards, kl)
# TODO: is the copy necessary?
Expand Down Expand Up @@ -928,7 +928,7 @@ def policy_loss_function(
train_log_probs_for_tis = [log_prob.detach() for log_prob in log_probs]

# Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering
need_full_log_probs = args.use_opsm or args.advantage_estimator == "gspo"
need_full_log_probs = args.use_opsm or args.is_level == "sequence"

full_log_probs = None
full_old_log_probs = None
Expand Down Expand Up @@ -956,8 +956,8 @@ def policy_loss_function(
loss_masks=batch["loss_masks"],
)

# Compute KL divergence (GSPO uses sequence-level KL, others use per-token KL)
if args.advantage_estimator == "gspo":
# Compute KL divergence (sequence-level IS granularity uses GSPO's per-sequence KL, token-level uses per-token KL)
if args.is_level == "sequence":
ppo_kl = compute_gspo_kl(
full_log_probs=full_log_probs,
full_old_log_probs=full_old_log_probs,
Expand All @@ -971,7 +971,7 @@ def policy_loss_function(
log_probs = torch.cat(log_probs, dim=0)
ppo_kl = old_log_probs - log_probs

if args.advantage_estimator == "cispo":
if args.policy_loss == "cispo":
pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip, args.eps_clip_high)
else:
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
Expand Down
4 changes: 2 additions & 2 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):

raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
if (
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"]
self.args.advantage_estimator in ["grpo", "reinforce_plus_plus_baseline"]
and self.args.rewards_normalization
):
# group norm
Expand All @@ -702,7 +702,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
mean = rewards.mean(dim=-1, keepdim=True)
rewards = rewards - mean

if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization:
if self.args.advantage_estimator == "grpo" and self.args.grpo_std_normalization:
std = rewards.std(dim=-1, keepdim=True)
rewards = rewards / (std + 1e-6)

Expand Down
48 changes: 45 additions & 3 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,8 +917,32 @@ def add_algo_arguments(parser):
],
default="grpo",
help=(
"Advantage estimator to use. Note: on-policy distillation (OPD) is now orthogonal "
"to the advantage estimator. Use --opd-kl-coef > 0 to enable OPD on top of any estimator."
"Advantage (credit-assignment) estimator. This axis is orthogonal to the surrogate "
"(--policy-loss) and the IS granularity (--is-level). Note: on-policy distillation (OPD) "
"is also orthogonal; use --opd-kl-coef > 0 to enable OPD on top of any estimator. "
"DEPRECATED values 'gspo' and 'cispo' are accepted for one release and remapped to "
"'--is-level sequence' and '--policy-loss cispo' respectively."
),
)
parser.add_argument(
"--policy-loss",
type=str,
choices=["ppo", "cispo"],
default="ppo",
help=(
"Policy-loss surrogate (the bounding rule that owns the importance weight), orthogonal "
"to --advantage-estimator and --is-level. 'ppo' is the clipped-ratio objective; 'cispo' "
"is the CISPO score-form objective (MiniMax-M1, https://arxiv.org/abs/2506.13585)."
),
)
parser.add_argument(
"--is-level",
type=str,
choices=["token", "sequence"],
default="token",
help=(
"Granularity of the importance-sampling ratio in the policy loss. 'token' is standard "
"per-token PPO; 'sequence' is the GSPO sequence-level ratio (https://arxiv.org/abs/2507.18071)."
),
)
parser.add_argument(
Expand Down Expand Up @@ -1739,6 +1763,24 @@ def _validate_update_weight_args(args) -> None:
def slime_validate_args(args):
args.eval_datasets = _resolve_eval_datasets(args)

# Deprecation shim: --advantage-estimator used to also select the surrogate ("cispo") and the
# IS granularity ("gspo"). These are now the orthogonal --policy-loss / --is-level axes. Remap the
# legacy values for one release (then drop them from the --advantage-estimator choices above).
_LEGACY_ADVANTAGE_ALIASES = {
"gspo": ("is_level", "sequence"),
"cispo": ("policy_loss", "cispo"),
}
if args.advantage_estimator in _LEGACY_ADVANTAGE_ALIASES:
field, value = _LEGACY_ADVANTAGE_ALIASES[args.advantage_estimator]
warnings.warn(
f"--advantage-estimator {args.advantage_estimator} is deprecated and will be removed; it now "
f"maps to '--advantage-estimator grpo --{field.replace('_', '-')} {value}'.",
FutureWarning,
stacklevel=2,
)
setattr(args, field, value)
args.advantage_estimator = "grpo"

if args.kl_coef != 0 or args.use_kl_loss:
if not os.path.exists(args.ref_load):
raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.")
Expand Down Expand Up @@ -1847,7 +1889,7 @@ def slime_validate_args(args):
if args.eps_clip_high is None:
args.eps_clip_high = args.eps_clip

if args.advantage_estimator == "cispo" and args.eps_clip < 1.0:
if args.policy_loss == "cispo" and args.eps_clip < 1.0:
logger.warning(
"CISPO is canonically single-sided, but --eps-clip=%s keeps the lower clip bound %s active. "
"Set --eps-clip 1.0 (and tune --eps-clip-high, e.g. 4.0) for the canonical wide setting.",
Expand Down
47 changes: 47 additions & 0 deletions tests/test_megatron_argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def make_slime_validate_args(**overrides):
save=None,
kl_loss_coef=0,
advantage_estimator="grpo",
policy_loss="ppo",
is_level="token",
normalize_advantages=False,
use_rollout_logprobs=False,
use_tis=False,
Expand Down Expand Up @@ -381,5 +383,50 @@ def test_update_weight_delta_rejects_unknown_transport(monkeypatch):
module._validate_update_weight_args(args)


@pytest.mark.unit
def test_legacy_gspo_advantage_estimator_maps_to_sequence_is_level(monkeypatch):
    """Check that the legacy ``advantage_estimator="gspo"`` value is remapped.

    Validation is expected to emit a ``FutureWarning`` naming the deprecated
    value, reset ``advantage_estimator`` to ``"grpo"``, set ``is_level`` to
    ``"sequence"`` and leave ``policy_loss`` at ``"ppo"``.
    """
    arguments_module = load_slime_arguments_module(monkeypatch)
    legacy_args = make_slime_validate_args(advantage_estimator="gspo")

    deprecation_pattern = "--advantage-estimator gspo is deprecated"
    with pytest.warns(FutureWarning, match=deprecation_pattern):
        arguments_module.slime_validate_args(legacy_args)

    # Compare all three axes at once so a failure shows the full remapped state.
    observed = (
        legacy_args.advantage_estimator,
        legacy_args.is_level,
        legacy_args.policy_loss,
    )
    assert observed == ("grpo", "sequence", "ppo")


@pytest.mark.unit
def test_legacy_cispo_advantage_estimator_maps_to_cispo_policy_loss(monkeypatch):
    """Check that the legacy ``advantage_estimator="cispo"`` value is remapped.

    Validation is expected to emit a ``FutureWarning`` naming the deprecated
    value, reset ``advantage_estimator`` to ``"grpo"``, set ``policy_loss`` to
    ``"cispo"`` and leave ``is_level`` at ``"token"``.
    """
    arguments_module = load_slime_arguments_module(monkeypatch)
    # NOTE(review): the stated intent of this test was to also guard the CISPO
    # single-sided warning (eps_clip < 1.0) still firing off the remapped
    # --policy-loss axis, as a regression guard for the shim ordering. Only the
    # FutureWarning and the remapped fields are asserted below; the logged
    # warning itself is not checked -- confirm whether that coverage is wanted.
    legacy_args = make_slime_validate_args(advantage_estimator="cispo")

    with pytest.warns(FutureWarning, match="--advantage-estimator cispo is deprecated"):
        arguments_module.slime_validate_args(legacy_args)

    expected_fields = {
        "advantage_estimator": "grpo",
        "policy_loss": "cispo",
        "is_level": "token",
    }
    for field_name, expected_value in expected_fields.items():
        assert getattr(legacy_args, field_name) == expected_value


@pytest.mark.unit
def test_non_legacy_advantage_estimator_is_untouched(monkeypatch):
    """Check that non-legacy values pass through validation without a remap.

    With ``advantage_estimator="grpo"`` and explicit ``policy_loss`` /
    ``is_level`` values, validation must not raise a ``FutureWarning`` and must
    leave all three fields exactly as supplied.
    """
    import warnings as warnings_lib

    arguments_module = load_slime_arguments_module(monkeypatch)
    explicit_args = make_slime_validate_args(
        advantage_estimator="grpo",
        policy_loss="cispo",
        is_level="sequence",
    )

    # Promote FutureWarning to an exception so an unexpected deprecation
    # warning fails the test instead of passing silently.
    with warnings_lib.catch_warnings():
        warnings_lib.simplefilter("error", FutureWarning)
        arguments_module.slime_validate_args(explicit_args)

    # The orthogonal axes are passed through verbatim; no legacy remap.
    observed = (
        explicit_args.advantage_estimator,
        explicit_args.policy_loss,
        explicit_args.is_level,
    )
    assert observed == ("grpo", "cispo", "sequence")


if __name__ == "__main__":
    # Allow running this test file directly: hand it to pytest and exit with
    # the status pytest returns.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)
Loading