RFC: factor the policy loss into orthogonal axes
Line numbers below are against main@3fd7927. Cross-repo PR references point at THUDM/slime.
Summary
policy_loss_function (loss.py:877) currently selects behaviour through one overloaded flag, --advantage-estimator, which decides three independent things at once: how credit is assigned, which surrogate (bounding rule) is used, and at what granularity the importance ratio is taken. As we add algorithms (CISPO THUDM#2067, REINFORCE THUDM#2083, TIS THUDM#2084, aggregation modes THUDM#2090, OPD THUDM#2085), each lands as a new branch in that one if/elif. Combinations that are mathematically valid (e.g. CISPO bounding on a GAE advantage, or a sequence-level ratio with a decoupled behaviour policy) are simply not expressible without a custom loss.
This RFC proposes the target decomposition the loss already half-implements, so that future algorithm work is configuration, not new branches. It does not ask for a big-bang rewrite: each axis below maps to an increment, several of which are already merged or in review. The keystone increment (#1 below) is implemented in the PR that accompanies this issue.
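The decomposition
Everything in policy_loss_function is one of four orthogonal axes plus one separable correction:

| Axis | Factor | Where it lives today | Flag(s) |
|---|---|---|---|
| Credit assignment | A_t | compute_advantages_and_returns (loss.py:657) | --advantage-estimator {grpo, ppo(=GAE), reinforce_plus_plus, reinforce_plus_plus_baseline} |
| Bounding (surrogate) | w_t (PPO gating-clip / CISPO detached-truncation / …), in ratio or score form | compute_policy_loss / compute_cispo_loss, dispatched by advantage_estimator (loss.py:974-977) | --policy-loss {ppo, cispo} |
| Ratio granularity | | advantage_estimator == "gspo" → compute_gspo_kl (loss.py:960) | --is-level {token, sequence} |
| Aggregation | | | --loss-aggregation (THUDM#2090), --pg-loss-divisor (THUDM#2060) |
| Regularizers | | compute_approx_kl, OPD via advantage shaping (THUDM#2085) | |
| Correction (separable) | ρ_t = π_prox/μ | vanilla_tis_function, applied outside the surrogate at loss.py:847 (THUDM#2084) | --use-tis / --custom-tis-function-path |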
Two facts make this the correct split, not just a tidier one:
The surrogate owns the IS weight. w_t = π_θ/μ is the only correctness-mandated factor; PPO/CISPO/GSPO differ only in how they bound it (and "ratio vs score form" is not a free knob: gating-clip needs ratio form to zero clipped-token gradients, while truncation needs score form to keep them). So bounding is one axis; the ratio is not a separate one.
The correction is genuinely separable. ρ_t = π_prox/μ is θ-independent and detached, so multiplying it onto the per-token loss is gradient-identical whether done inside or outside the surrogate (∇(ρ·L) = ρ·∇L). slime already does it outside (loss.py:847), which is the clean choice and means a new surrogate gets TIS for free.
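Both facts can be checked on a single token with a minimal pure-Python sketch (hypothetical helper names, not slime code). It differentiates each per-token loss with respect to the token's log π_θ by central finite differences, and models "detached" by computing the weight from a frozen copy of the log-prob. The symmetric clip range EPS = 0.2 and the CISPO-style form -sg(clip(w))·A·log π_θ are illustrative assumptions.

```python
import math

EPS = 0.2  # illustrative symmetric clip range, not a slime default


def clip(x, lo, hi):
    return max(lo, min(hi, x))


def ppo_loss(logp, logp_mu, adv, frozen_logp):
    """Ratio form with a gating clip: -min(w*A, clip(w)*A), w = exp(logp - logp_mu).

    frozen_logp is unused; the signature is shared with cispo_loss.
    """
    w = math.exp(logp - logp_mu)
    return -min(w * adv, clip(w, 1 - EPS, 1 + EPS) * adv)


def cispo_loss(logp, logp_mu, adv, frozen_logp):
    """Score form with a truncated, detached weight: -sg(clip(w)) * A * logp.

    The weight is built from frozen_logp, so perturbing logp does not move it.
    """
    w_detached = math.exp(frozen_logp - logp_mu)
    return -clip(w_detached, 1 - EPS, 1 + EPS) * adv * logp


def with_tis(loss_fn, rho):
    """Multiply a detached correction weight onto the loss, outside the surrogate."""
    def wrapped(logp, logp_mu, adv, frozen_logp):
        return rho * loss_fn(logp, logp_mu, adv, frozen_logp)
    return wrapped


def grad_wrt_logp(loss_fn, logp, logp_mu, adv, h=1e-6):
    """Central finite difference in logp; detached terms stay at their frozen value."""
    up = loss_fn(logp + h, logp_mu, adv, logp)
    down = loss_fn(logp - h, logp_mu, adv, logp)
    return (up - down) / (2 * h)


logp_mu = math.log(0.4)
clipped = math.log(0.6)  # w = 1.5, above the 1 + EPS bound
inside = math.log(0.44)  # w = 1.1, inside [1 - EPS, 1 + EPS]

print(grad_wrt_logp(ppo_loss, clipped, logp_mu, 1.0))    # PPO: clipped token gets no gradient
print(grad_wrt_logp(cispo_loss, clipped, logp_mu, 1.0))  # CISPO: gradient kept, scaled by 1 + EPS
print(grad_wrt_logp(with_tis(ppo_loss, 0.7), inside, logp_mu, 1.0))  # 0.7 times the PPO gradient
```

For the clipped token the PPO gradient is exactly zero because the clipped branch is constant in log π_θ, while the CISPO form still passes a gradient scaled by the truncated weight 1.2. Wrapping either loss in with_tis scales its gradient by ρ and nothing else, which is the ∇(ρ·L) = ρ·∇L argument.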
Evidence that --advantage-estimator is overloaded
compute_advantages_and_returns (loss.py:716) maps grpo, gspo, and cispo to the same get_grpo_returns, i.e. gspo and cispo make no difference to the advantage; they only change the surrogate or the granularity. They are misfiled.
The surrogate is then re-selected from the same flag at loss.py:974-977 (cispo/else-PPO), and granularity again at loss.py:960 (gspo). One string, three decisions.
Result: --advantage-estimator cispo silently means "GRPO advantage + CISPO surrogate + token granularity", and there is no way to say "GAE advantage + CISPO surrogate".
Target state
Three independent flags (--advantage-estimator, --policy-loss, --is-level) replace the one overloaded one; every valid combination becomes expressible, and adding a surrogate is one function + one registry line, composing for free with every advantage, granularity, correction, and regularizer.
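The target state can be sketched as lookups that never consult each other. In this pure-Python sketch all names are hypothetical, and the sequence-level weight is assumed to be GSPO's length-normalised ratio. The advantage estimator only decides where adv comes from, --policy-loss and --is-level pick entries from two dictionaries, and the TIS weight and a plain token-mean aggregation are applied outside the surrogate.

```python
import itertools
import math

EPS = 0.2  # illustrative clip range


def clip(x, lo, hi):
    return max(lo, min(hi, x))


# --is-level: how per-token importance weights are formed from log-probs.
def token_weights(logp, logp_old):
    return [math.exp(a - b) for a, b in zip(logp, logp_old)]


def sequence_weights(logp, logp_old):
    # Length-normalised sequence ratio (geometric mean of the token ratios),
    # broadcast to every token of the sequence.
    mean_log_ratio = sum(a - b for a, b in zip(logp, logp_old)) / len(logp)
    return [math.exp(mean_log_ratio)] * len(logp)


# --policy-loss: per-token loss value from (w_t, A_t, log pi_t).
def ppo_surrogate(w, adv, logp):
    return -min(w * adv, clip(w, 1 - EPS, 1 + EPS) * adv)


def cispo_surrogate(w, adv, logp):
    # With autodiff the clipped weight would be detached; only logp carries gradient.
    return -clip(w, 1 - EPS, 1 + EPS) * adv * logp


IS_LEVELS = {"token": token_weights, "sequence": sequence_weights}
POLICY_LOSSES = {"ppo": ppo_surrogate, "cispo": cispo_surrogate}


def compose_loss(logp, logp_old, adv, *, policy_loss, is_level, tis=None):
    """One sequence's loss, built from independently chosen pieces.

    `adv` comes from whichever --advantage-estimator ran upstream (GRPO, GAE, ...);
    nothing below depends on that choice.
    """
    weights = IS_LEVELS[is_level](logp, logp_old)
    surrogate = POLICY_LOSSES[policy_loss]
    per_token = [surrogate(w, a, lp) for w, a, lp in zip(weights, adv, logp)]
    if tis is not None:
        # Separable correction: detached per-token weights applied outside the surrogate.
        per_token = [rho * loss for rho, loss in zip(tis, per_token)]
    return sum(per_token) / len(per_token)  # aggregation: plain token mean


logp = [math.log(0.5), math.log(0.3)]
logp_old = [math.log(0.4), math.log(0.3)]
gae_adv = [0.8, -0.2]  # stand-in values for GAE advantages

# Every (surrogate, granularity) pair runs on the same GAE advantages,
# including "GAE advantage + CISPO surrogate".
for pl, level in itertools.product(POLICY_LOSSES, IS_LEVELS):
    print(pl, level, compose_loss(logp, logp_old, gae_adv, policy_loss=pl, is_level=level))
```

Adding a surrogate here means one function plus one dictionary entry, and it immediately works with both granularities, any advantage, and the TIS weight.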
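How existing work maps onto the axes
CISPO (THUDM#2067): selected today through --advantage-estimator; the keystone PR re-homes it to --policy-loss.
REINFORCE (THUDM#2083): lands on --policy-loss when rebased onto the keystone.
The RFC asks for nothing to be reverted: it gives these increments a shared frame and stops the next one from deepening the overload.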
Proposed increments (each standalone and behaviour-preserving)
1. Split the surrogate and granularity axes out of --advantage-estimator (the keystone, implemented in the accompanying PR). Add --policy-loss {ppo,cispo} and --is-level {token,sequence}; route the surrogate/granularity dispatch off these instead of --advantage-estimator; keep a one-release deprecation shim mapping the legacy values (gspo → --is-level sequence, cispo → --policy-loss cispo). This is dispatch-only: the loss math is untouched and legacy configs reproduce identically (unit tests cover the shim). REINFORCE (feat(rl): add REINFORCE advantage estimator, THUDM/slime#2083) becomes a --policy-loss reinforce value when it rebases onto this.
2. (optional) A register_policy_loss registry: make the surrogate pluggable (decorator + load_function fallthrough, argparse choices from the registry), matching the existing --custom-*-function-path style.
3. (optional) Decoupled reference structure: add --reference-structure {coupled,decoupled} so the surrogate denominator can be the behaviour policy μ while a proximal anchor π_prox centres the trust region. This composes with the existing TIS correction. (slime is coupled-only today: loss.py:908.)
4. (optional) Additive regularizer assembly: collect KL-to-ref, entropy, and OPD as additive terms rather than at scattered sites; keep OPD's current advantage-shaping form selectable (see the OPD placement fork below).
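Increments 1 and 2 together, as a hedged sketch rather than the accompanying PR: a register_policy_loss decorator whose registry also supplies the --policy-loss argparse choices, a fall-through to a dotted import path (importlib stands in here for slime's load_function), and a one-release shim that rewrites the legacy gspo/cispo values with a DeprecationWarning. The surrogate bodies are placeholders, and rejecting a legacy value that contradicts an explicit new flag is a design choice of this sketch.

```python
import argparse
import importlib
import warnings

# Hypothetical sketch of increments 1 and 2; not slime's actual code.
POLICY_LOSS_REGISTRY = {}


def register_policy_loss(name):
    """Decorator: make a surrogate selectable as --policy-loss <name>."""
    def decorator(fn):
        if name in POLICY_LOSS_REGISTRY:
            raise ValueError(f"policy loss {name!r} is already registered")
        POLICY_LOSS_REGISTRY[name] = fn
        return fn
    return decorator


@register_policy_loss("ppo")
def ppo_surrogate(*args, **kwargs):
    raise NotImplementedError("placeholder body for the sketch")


@register_policy_loss("cispo")
def cispo_surrogate(*args, **kwargs):
    raise NotImplementedError("placeholder body for the sketch")


def resolve_policy_loss(name):
    """Registry first, then fall through to a 'module.attr' import path."""
    if name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[name]
    module_name, _, attr = name.rpartition(".")
    if not module_name:
        raise ValueError(f"unknown policy loss {name!r}")
    return getattr(importlib.import_module(module_name), attr)


# Legacy --advantage-estimator values that really selected a surrogate or a
# granularity; both use the GRPO advantage.
LEGACY_ESTIMATORS = {
    "gspo": ("grpo", {"is_level": "sequence"}),
    "cispo": ("grpo", {"policy_loss": "cispo"}),
}


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(  # "ppo" here is the GAE estimator, as in the table
        "--advantage-estimator",
        default="grpo",
        choices=["grpo", "ppo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", *LEGACY_ESTIMATORS],
    )
    # Choices come from the registry, so registering a surrogate exposes it on the CLI.
    parser.add_argument("--policy-loss", default=None, choices=sorted(POLICY_LOSS_REGISTRY))
    parser.add_argument("--is-level", default=None, choices=["token", "sequence"])
    return parser


def apply_deprecation_shim(args):
    """Map legacy values onto the new flags (warn + auto-map), then fill defaults."""
    legacy = args.advantage_estimator
    if legacy in LEGACY_ESTIMATORS:
        estimator, overrides = LEGACY_ESTIMATORS[legacy]
        for field, value in overrides.items():
            current = getattr(args, field)
            if current not in (None, value):
                flag = "--" + field.replace("_", "-")
                raise ValueError(f"--advantage-estimator {legacy} conflicts with {flag} {current}")
            setattr(args, field, value)
        hint = " ".join(f"--{k.replace('_', '-')} {v}" for k, v in overrides.items())
        warnings.warn(
            f"--advantage-estimator {legacy} is deprecated; use --advantage-estimator {estimator} {hint}",
            DeprecationWarning,
            stacklevel=2,
        )
        args.advantage_estimator = estimator
    if args.policy_loss is None:
        args.policy_loss = "ppo"
    if args.is_level is None:
        args.is_level = "token"
    return args


def resolved_flags(argv):
    """Parse argv, apply the shim, and return the three axis choices."""
    args = apply_deprecation_shim(build_parser().parse_args(argv))
    return args.advantage_estimator, args.policy_loss, args.is_level


print(resolved_flags(["--advantage-estimator", "cispo"]))
print(resolved_flags(["--advantage-estimator", "ppo", "--policy-loss", "cispo"]))  # GAE + CISPO
```

Leaving --policy-loss and --is-level defaulted to None until the shim has run is what lets the sketch tell an explicit user choice apart from a default, so a contradictory combination fails loudly instead of being silently overwritten.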
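One reading of increment 3, sketched under an assumption: the decoupled structure is taken to be the decoupled-PPO objective, where the full weight π_θ/μ multiplies the advantage but the clip acts on π_θ/π_prox. Because w = ρ·r with ρ = π_prox/μ > 0, ρ factors out of the min, so the decoupled loss equals the coupled surrogate (with π_prox as its reference) times the detached ρ. That identity is why this increment can reuse the existing TIS path. The function names are hypothetical, and the cases below check the identity on three tokens.

```python
import math

EPS = 0.2  # illustrative clip range


def clip(x, lo, hi):
    return max(lo, min(hi, x))


def coupled_ppo(logp, logp_ref, adv):
    """Coupled: one reference policy is both the IS denominator and the clip centre."""
    r = math.exp(logp - logp_ref)
    return -min(r * adv, clip(r, 1 - EPS, 1 + EPS) * adv)


def decoupled_ppo(logp, logp_prox, logp_mu, adv):
    """Decoupled: IS weight pi_theta/mu, trust region centred on pi_prox."""
    w = math.exp(logp - logp_mu)         # full importance weight pi_theta / mu
    rho = math.exp(logp_prox - logp_mu)  # theta-independent correction pi_prox / mu
    r = math.exp(logp - logp_prox)       # ratio that the clip acts on
    return -min(w * adv, rho * clip(r, 1 - EPS, 1 + EPS) * adv)


# Because w = rho * r and rho > 0, rho factors out of the min: the decoupled
# loss is the coupled surrogate (reference pi_prox) times the detached rho.
cases = [
    (math.log(0.6), math.log(0.4), math.log(0.5), 1.0),    # r = 1.5, clipped
    (math.log(0.3), math.log(0.4), math.log(0.35), -0.5),  # r = 0.75, clipped
    (math.log(0.42), math.log(0.4), math.log(0.2), 0.3),   # r = 1.05, unclipped
]
for logp, logp_prox, logp_mu, adv in cases:
    rho = math.exp(logp_prox - logp_mu)
    print(decoupled_ppo(logp, logp_prox, logp_mu, adv), rho * coupled_ppo(logp, logp_prox, adv))
```

When π_prox equals μ, ρ is 1 and the decoupled form collapses to today's coupled surrogate, which is consistent with keeping coupled as the default value of the flag.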
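A minimal sketch of increment 4, with hypothetical names: regularizers are (name, coefficient, value) triples summed onto the policy-gradient loss in one place, with a per-term breakdown for logging. The KL term uses the k3 estimator exp(log r) - 1 - log r with r = π_ref/π_θ, which is one common choice; which estimator compute_approx_kl uses is not assumed here. The coefficients are illustrative.

```python
import math


def k3_kl(logp, logp_ref):
    """Per-token k3 estimate of KL(pi_theta || pi_ref) for a token sampled from pi_theta."""
    log_r = logp_ref - logp
    return math.exp(log_r) - 1 - log_r


def entropy(probs):
    """Entropy of one token's categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def assemble_loss(pg_loss, regularizers):
    """Policy-gradient loss plus a list of (name, coef, value) additive terms.

    Returns the total and a per-term breakdown for logging. OPD is deliberately
    not an entry here: per the OPD placement fork, its advantage-shaping form
    stays a separate, selectable mode because moving it changes the gradient path.
    """
    terms = {name: coef * value for name, coef, value in regularizers}
    return pg_loss + sum(terms.values()), terms


pg_loss = -0.38  # stand-in for the surrogate axis output on one batch
kl = k3_kl(math.log(0.5), math.log(0.4))
ent = entropy([0.5, 0.3, 0.2])
total, terms = assemble_loss(
    pg_loss,
    [
        ("kl_to_ref", 0.01, kl),    # penalise drift from the reference model
        ("entropy", -0.001, ent),   # negative coefficient: entropy acts as a bonus
    ],
)
print(total, terms)
```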
Open questions / forks
Deprecation vs hard break. Re-homing gspo/cispo out of --advantage-estimator changes existing configs. The keystone PR uses a one-release deprecation shim (warn + auto-map); the alternative is a clean break with a migration note. Recommend the shim (implemented).
OPD placement. OPD today shapes the advantage (gradient flows via the detached old-logprob ratio); moving it to an additive regularizer changes the gradient path and thus the objective. Keep advantage-shaping OPD as a selectable mode rather than silently changing the recipe.
Granularity ceiling. A closed {token,sequence} enum can't express GMPO-style geometric ratios without a third value or an escape hatch. Accept the closed enum now, or add --custom-is-level?
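The two options in the granularity fork, sketched with hypothetical names: a closed token/sequence table (sequence assumed to be the length-normalised geometric mean, as in GSPO) and an optional custom rule that, when supplied, overrides --is-level. The third rule shown only illustrates the escape hatch (the raw, unnormalised sequence likelihood ratio); it is not a GMPO implementation.

```python
import math
from typing import Callable, Dict, List, Optional

# A ratio rule maps one sequence's per-token log-ratios to per-token IS weights.
RatioRule = Callable[[List[float]], List[float]]


def token_level(log_ratios: List[float]) -> List[float]:
    return [math.exp(x) for x in log_ratios]


def sequence_level(log_ratios: List[float]) -> List[float]:
    # Geometric mean of the token ratios, shared by every token.
    return [math.exp(sum(log_ratios) / len(log_ratios))] * len(log_ratios)


# The closed enum from increment 1.
IS_LEVELS: Dict[str, RatioRule] = {"token": token_level, "sequence": sequence_level}


def resolve_is_level(name: str, custom: Optional[RatioRule] = None) -> RatioRule:
    """Closed enum, plus an optional escape hatch that overrides --is-level."""
    if custom is not None:
        return custom
    if name not in IS_LEVELS:
        raise ValueError(f"unknown --is-level {name!r}; expected one of {sorted(IS_LEVELS)}")
    return IS_LEVELS[name]


def unnormalised_sequence(log_ratios: List[float]) -> List[float]:
    """Hypothetical third rule: the raw sequence likelihood ratio (no 1/|y| power)."""
    return [math.exp(sum(log_ratios))] * len(log_ratios)


log_ratios = [math.log(2.0), 0.0]
for rule in (resolve_is_level("token"), resolve_is_level("sequence"),
             resolve_is_level("sequence", custom=unnormalised_sequence)):
    print(rule(log_ratios))
```

The escape hatch trades validation for reach: a custom rule bypasses the enum's checks, but a new rule needs no change to the flag's choices.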
Non-goals
No change to the Megatron/CP machinery, the reducer internals, value/SFT losses, or any numerics of existing presets (the keystone is dispatch-only; legacy presets reproduce identically, with OPD the one called-out exception if axis 4 is ever re-homed).
No new algorithm in this RFC — only the structure that makes adding them composition rather than branching.