diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 6bdc2dd3fe..614b424799 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -472,7 +472,7 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data and not self.args.keep_old_actor and not self.args.use_opd and not self.args.use_routing_replay - and self.args.advantage_estimator != "gspo" + and self.args.is_level != "sequence" ) if ( not self.args.use_rollout_logprobs or self.args.get_mismatch_metrics diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..75c4abf335 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -659,7 +659,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) This function extracts rewards, log-probs, values, and masks from `rollout_data`, computes KL divergences, then applies the chosen advantage - estimator. Supported methods: "grpo", "gspo", "cispo", "ppo", + estimator. Supported methods: "grpo", "ppo", "reinforce_plus_plus", and "reinforce_plus_plus_baseline". When `args.normalize_advantages` is True, advantages are whitened across the data-parallel group using masked statistics. @@ -713,7 +713,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) custom_adv_fn(args, rollout_data) advantages, returns = rollout_data["advantages"], rollout_data["returns"] - elif args.advantage_estimator in ["grpo", "gspo", "cispo"]: + elif args.advantage_estimator == "grpo": rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device) returns = get_grpo_returns(rewards, kl) # TODO: is the copy necessary? @@ -928,7 +928,7 @@ def policy_loss_function( train_log_probs_for_tis = [log_prob.detach() for log_prob in log_probs] # Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering - need_full_log_probs = args.use_opsm or args.advantage_estimator == "gspo" + need_full_log_probs = args.use_opsm or args.is_level == "sequence" full_log_probs = None full_old_log_probs = None @@ -956,8 +956,8 @@ def policy_loss_function( loss_masks=batch["loss_masks"], ) - # Compute KL divergence (GSPO uses sequence-level KL, others use per-token KL) - if args.advantage_estimator == "gspo": + # Compute KL divergence (sequence-level IS granularity uses GSPO's per-sequence KL, token-level uses per-token KL) + if args.is_level == "sequence": ppo_kl = compute_gspo_kl( full_log_probs=full_log_probs, full_old_log_probs=full_old_log_probs, @@ -971,7 +971,7 @@ def policy_loss_function( log_probs = torch.cat(log_probs, dim=0) ppo_kl = old_log_probs - log_probs - if args.advantage_estimator == "cispo": + if args.policy_loss == "cispo": pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip, args.eps_clip_high) else: pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index bc70e6257c..e667ec2f41 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -689,7 +689,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): raw_rewards = [sample.get_reward_value(self.args) for sample in samples] if ( - self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"] + self.args.advantage_estimator in ["grpo", "reinforce_plus_plus_baseline"] and self.args.rewards_normalization ): # group norm @@ -702,7 +702,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): mean = rewards.mean(dim=-1, keepdim=True) rewards = rewards - mean - if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization: + if self.args.advantage_estimator == "grpo" and self.args.grpo_std_normalization: std = rewards.std(dim=-1, keepdim=True) rewards = rewards / (std + 1e-6) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d8467049bb..7273e06b08 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -917,8 +917,32 @@ def add_algo_arguments(parser): ], default="grpo", help=( - "Advantage estimator to use. Note: on-policy distillation (OPD) is now orthogonal " - "to the advantage estimator. Use --opd-kl-coef > 0 to enable OPD on top of any estimator." + "Advantage (credit-assignment) estimator. This axis is orthogonal to the surrogate " + "(--policy-loss) and the IS granularity (--is-level). Note: on-policy distillation (OPD) " + "is also orthogonal; use --opd-kl-coef > 0 to enable OPD on top of any estimator. " + "DEPRECATED values 'gspo' and 'cispo' are accepted for one release and remapped to " + "'--is-level sequence' and '--policy-loss cispo' respectively." + ), + ) + parser.add_argument( + "--policy-loss", + type=str, + choices=["ppo", "cispo"], + default="ppo", + help=( + "Policy-loss surrogate (the bounding rule that owns the importance weight), orthogonal " + "to --advantage-estimator and --is-level. 'ppo' is the clipped-ratio objective; 'cispo' " + "is the CISPO score-form objective (MiniMax-M1, https://arxiv.org/abs/2506.13585)." + ), + ) + parser.add_argument( + "--is-level", + type=str, + choices=["token", "sequence"], + default="token", + help=( + "Granularity of the importance-sampling ratio in the policy loss. 'token' is standard " + "per-token PPO; 'sequence' is the GSPO sequence-level ratio (https://arxiv.org/abs/2507.18071)." ), ) parser.add_argument( @@ -1739,6 +1763,24 @@ def _validate_update_weight_args(args) -> None: def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) + # Deprecation shim: --advantage-estimator used to also select the surrogate ("cispo") and the + # IS granularity ("gspo"). These are now the orthogonal --policy-loss / --is-level axes. Remap the + # legacy values for one release (then drop them from the --advantage-estimator choices above). + _LEGACY_ADVANTAGE_ALIASES = { + "gspo": ("is_level", "sequence"), + "cispo": ("policy_loss", "cispo"), + } + if args.advantage_estimator in _LEGACY_ADVANTAGE_ALIASES: + field, value = _LEGACY_ADVANTAGE_ALIASES[args.advantage_estimator] + warnings.warn( + f"--advantage-estimator {args.advantage_estimator} is deprecated and will be removed; it now " + f"maps to '--advantage-estimator grpo --{field.replace('_', '-')} {value}'.", + FutureWarning, + stacklevel=2, + ) + setattr(args, field, value) + args.advantage_estimator = "grpo" + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") @@ -1847,7 +1889,7 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip - if args.advantage_estimator == "cispo" and args.eps_clip < 1.0: + if args.policy_loss == "cispo" and args.eps_clip < 1.0: logger.warning( "CISPO is canonically single-sided, but --eps-clip=%s keeps the lower clip bound %s active. " "Set --eps-clip 1.0 (and tune --eps-clip-high, e.g. 4.0) for the canonical wide setting.", diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index db5c0bfd81..fd6801b314 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -248,6 +248,8 @@ def make_slime_validate_args(**overrides): save=None, kl_loss_coef=0, advantage_estimator="grpo", + policy_loss="ppo", + is_level="token", normalize_advantages=False, use_rollout_logprobs=False, use_tis=False, @@ -381,5 +383,50 @@ def test_update_weight_delta_rejects_unknown_transport(monkeypatch): module._validate_update_weight_args(args) +@pytest.mark.unit +def test_legacy_gspo_advantage_estimator_maps_to_sequence_is_level(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(advantage_estimator="gspo") + + with pytest.warns(FutureWarning, match="--advantage-estimator gspo is deprecated"): + module.slime_validate_args(args) + + assert args.advantage_estimator == "grpo" + assert args.is_level == "sequence" + assert args.policy_loss == "ppo" + + +@pytest.mark.unit +def test_legacy_cispo_advantage_estimator_maps_to_cispo_policy_loss(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + # eps_clip default 0.2 < 1.0 keeps the CISPO single-sided warning live; assert it still fires + # off the remapped --policy-loss axis (regression guard for the shim ordering). + args = make_slime_validate_args(advantage_estimator="cispo") + + with pytest.warns(FutureWarning, match="--advantage-estimator cispo is deprecated"): + module.slime_validate_args(args) + + assert args.advantage_estimator == "grpo" + assert args.policy_loss == "cispo" + assert args.is_level == "token" + + +@pytest.mark.unit +def test_non_legacy_advantage_estimator_is_untouched(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(advantage_estimator="grpo", policy_loss="cispo", is_level="sequence") + + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("error", FutureWarning) + module.slime_validate_args(args) + + # The orthogonal axes are passed through verbatim; no legacy remap. + assert args.advantage_estimator == "grpo" + assert args.policy_loss == "cispo" + assert args.is_level == "sequence" + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__]))