
RFC: factor the policy loss into orthogonal axes (advantage × policy-loss × is-level × correction × regularizer) #1

Description

@EazyReal

RFC: factor the policy loss into orthogonal axes

Line numbers below are against main@3fd7927. Cross-repo PR references point at THUDM/slime.

Summary

policy_loss_function (loss.py:877) selects its behaviour through one overloaded flag, --advantage-estimator, which decides three independent things at once: how credit is assigned, which surrogate (bounding rule) is used, and at what granularity the importance ratio is taken. As we add algorithms (CISPO THUDM#2067, REINFORCE THUDM#2083, TIS THUDM#2084, aggregation modes THUDM#2090, OPD THUDM#2085), each lands as a new branch in that one if/elif, and mathematically valid combinations, such as CISPO bounding on a GAE advantage or a sequence-level ratio with a decoupled behaviour policy, cannot be expressed without a custom loss.

This RFC proposes the target decomposition the loss already half-implements, so that future algorithm work is configuration, not new branches. It does not ask for a big-bang rewrite: each axis below maps to an increment, several of which are already merged or in review. The keystone increment (#1 below) is implemented in the PR that accompanies this issue.

The decomposition

The per-token policy-gradient estimate is

g_t = w_t · A_t · ∇ log π_θ(a_t),     w_t = π_θ / μ     (μ = behaviour policy that generated the data)
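For reference, a short derivation shows that the ratio form and the score form both reproduce g_t before any bounding is applied. It is standard calculus, with A_t treated as a detached constant and sg(·) used here only as notation for stop-gradient (a `.detach()`):

```latex
\begin{aligned}
\nabla_\theta \bigl[ w_t A_t \bigr]
  &= \frac{A_t}{\mu(a_t)} \nabla_\theta \pi_\theta(a_t)
   = A_t \frac{\pi_\theta(a_t)}{\mu(a_t)} \nabla_\theta \log \pi_\theta(a_t)
   = w_t A_t \nabla_\theta \log \pi_\theta(a_t) = g_t
  && \text{(ratio form)} \\
\nabla_\theta \bigl[ \operatorname{sg}(w_t)\, A_t \log \pi_\theta(a_t) \bigr]
  &= w_t A_t \nabla_\theta \log \pi_\theta(a_t) = g_t
  && \text{(score form)}
\end{aligned}
```

The two forms only part ways once a bound is applied: clipping inside the ratio form can make a token's surrogate constant in θ, while truncating the detached weight in the score form cannot.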

Everything in policy_loss_function is one of four orthogonal axes plus one separable correction:

| Axis | What it decides | Today | Flag (proposed) |
|---|---|---|---|
| 1. Advantage A_t | credit assignment | compute_advantages_and_returns (loss.py:657) | --advantage-estimator {grpo, ppo(=GAE), reinforce_plus_plus, reinforce_plus_plus_baseline} |
| 2. Surrogate | the bounding rule that owns w_t (PPO gating-clip / CISPO detached-truncation / …), in ratio or score form | compute_policy_loss / compute_cispo_loss, dispatched by advantage_estimator (loss.py:974-977) | --policy-loss {ppo, cispo} |
| 2b. Granularity | token vs sequence ratio (token = PPO, sequence = GSPO) | advantage_estimator == "gspo" → compute_gspo_kl (loss.py:960) | --is-level {token, sequence} |
| 3. Aggregation | per-token loss → scalar | reducer + --loss-aggregation (THUDM#2090), --pg-loss-divisor (THUDM#2060) | (already its own axis) |
| 4. Regularizers | additive terms: KL-to-ref, entropy, OPD | KL via compute_approx_kl, OPD via advantage shaping (THUDM#2085) | additive coefficients |
| + Correction ρ_t = π_prox/μ | θ-independent, detached off-policy reweight (TIS / IcePop) | vanilla_tis_function, applied outside the surrogate at loss.py:847 (THUDM#2084) | --use-tis / --custom-tis-function-path |
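A minimal argparse sketch of the proposed dispatch flags from the table, assuming they are plain string choices. The defaults, the store_true form of --use-tis, and the build_parser name are assumptions for illustration, not slime's actual argument definitions:

```python
import argparse


def build_parser():
    """Hypothetical parser for the three proposed dispatch flags (plus --use-tis)."""
    parser = argparse.ArgumentParser()
    # Axis 1: credit assignment only.
    parser.add_argument(
        "--advantage-estimator",
        default="grpo",
        choices=["grpo", "ppo", "reinforce_plus_plus", "reinforce_plus_plus_baseline"],
    )
    # Axis 2: the bounding rule that owns w_t.
    parser.add_argument("--policy-loss", default="ppo", choices=["ppo", "cispo"])
    # Axis 2b: granularity of the importance ratio.
    parser.add_argument("--is-level", default="token", choices=["token", "sequence"])
    # Separable correction, orthogonal to all of the above.
    parser.add_argument("--use-tis", action="store_true")
    return parser


# "GAE advantage + CISPO surrogate": not expressible with the single legacy flag.
args = build_parser().parse_args(["--advantage-estimator", "ppo", "--policy-loss", "cispo"])
print(args.advantage_estimator, args.policy_loss, args.is_level, args.use_tis)
```

Because each flag is an independent choices= list, all 4 × 2 × 2 = 16 advantage/surrogate/granularity combinations parse, and an unlisted value such as --policy-loss reinforce is rejected by argparse until it is added.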

Two facts make this the correct split, not just a tidier one:

  1. The surrogate owns the IS weight. w_t = π_θ/μ is the only correctness-mandated factor; PPO/CISPO/GSPO differ only in how they bound it (and "ratio vs score form" is not a free knob — gating-clip needs ratio form to zero clipped-token gradients; truncation needs score form to keep them). So bounding is one axis; the ratio is not a separate one.
  2. The correction is genuinely separable. ρ_t = π_prox/μ is θ-independent and detached, so multiplying it onto the per-token loss is gradient-identical whether done inside or outside the surrogate (∇(ρ·L) = ρ·∇L). slime already does it outside (loss.py:847) — which is the clean choice and means a new surrogate gets TIS for free.
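Both facts can be checked on a single scalar token, using central finite differences in place of autograd. This is a sketch under simplifying assumptions: fd_grad, ppo_ratio_loss and cispo_score_loss_at are hypothetical helpers (not slime's compute_policy_loss / compute_cispo_loss), CISPO is truncated from above only, and EPS = 0.2 is an illustrative value:

```python
import math

EPS = 0.2  # illustrative clip/truncation width


def fd_grad(f, x, h=1e-6):
    """Central finite difference: d f / d x at x."""
    return (f(x + h) - f(x - h)) / (2 * h)


def ppo_ratio_loss(logp, old_logp, adv, eps=EPS):
    """Ratio form + gating-clip: -min(r*A, clip(r, 1-eps, 1+eps)*A)."""
    r = math.exp(logp - old_logp)
    r_clipped = min(max(r, 1 - eps), 1 + eps)
    return -min(r * adv, r_clipped * adv)


def cispo_score_loss_at(logp0, old_logp, adv, eps=EPS):
    """Score form + detached truncation, evaluated around logp0.

    The IS weight is computed once at logp0 and frozen (the "detach"), so the
    returned function only differentiates through the explicit log-prob.
    Upper truncation only: a simplification of CISPO's clip.
    """
    w = min(math.exp(logp0 - old_logp), 1 + eps)
    return lambda logp: -w * adv * logp


old_logp, adv = 0.0, 1.0

# Token with r = 1.5 > 1 + eps: PPO's gate zeroes the gradient, CISPO keeps it.
x_hi = math.log(1.5)
g_ppo_hi = fd_grad(lambda x: ppo_ratio_loss(x, old_logp, adv), x_hi)
g_cispo_hi = fd_grad(cispo_score_loss_at(x_hi, old_logp, adv), x_hi)

# Token with r = 1.1 inside the trust region: both forms give -w*A = -1.1.
x_in = math.log(1.1)
g_ppo_in = fd_grad(lambda x: ppo_ratio_loss(x, old_logp, adv), x_in)
g_cispo_in = fd_grad(cispo_score_loss_at(x_in, old_logp, adv), x_in)

# Separable correction: a detached constant rho scales the gradient exactly.
rho = 0.7
g_corrected = fd_grad(lambda x: rho * ppo_ratio_loss(x, old_logp, adv), x_in)

print(f"clipped token:  ppo={g_ppo_hi:.4f} cispo={g_cispo_hi:.4f}")
print(f"in-region:      ppo={g_ppo_in:.4f} cispo={g_cispo_in:.4f}")
print(f"rho*grad check: {g_corrected:.4f} vs {rho * g_ppo_in:.4f}")
```

The clipped token is where the two bounding rules genuinely differ; inside the trust region they coincide, and multiplying by the detached ρ rescales the gradient without changing which tokens are gated.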

Evidence that --advantage-estimator is overloaded

  • compute_advantages_and_returns (loss.py:716) maps grpo, gspo, and cispo to the same get_grpo_returns — i.e. gspo/cispo make zero difference to the advantage; they only change the surrogate/granularity. They are misfiled.
  • The surrogate is then re-selected from the same flag at loss.py:974-977 (cispo/else-PPO), and granularity again at loss.py:960 (gspo). One string, three decisions.
  • Result: --advantage-estimator cispo silently means "GRPO advantage + CISPO surrogate + token granularity", and there is no way to say "GAE advantage + CISPO surrogate".

Target state

forward → π_θ, entropy
  ├─ advantage         A_t              (--advantage-estimator)         [axis 1]
  ├─ log-ratio         lr = reduce(log π_θ − log π_ref) at --is-level  [axis 2b]
  ├─ surrogate         L_t = policy_loss(lr, A_t, log_probs, eps)      [axis 2]   ← --policy-loss
  ├─ correction        L_t *= ρ_t ;  reducer rebuilt with rej-mask     [+corr]    ← --use-tis (already outside)
  ├─ aggregate         reduce L_t                                      [axis 3]   ← --loss-aggregation
  └─ regularizers      loss = L + Σ coef·term (kl_to_ref, entropy, opd) [axis 4]

Three independent flags (--advantage-estimator, --policy-loss, --is-level) replace the one overloaded one; every valid combination becomes expressible, and adding a surrogate is one function + one registry line, composing for free with every advantage, granularity, correction, and regularizer.
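A single-process, pure-Python sketch of that composition, computing forward values only (no autograd). Every name here (grpo_advantages, log_ratios, POLICY_LOSSES, total_loss) is hypothetical; the GRPO-style normalisation, the token-mean aggregation, and the upper-only CISPO truncation are simplifying assumptions rather than slime's exact math. The point is structural: surrogate and granularity are separate arguments, and the correction and regularizers wrap whatever surrogate is selected.

```python
import math
from statistics import mean, pstdev


def grpo_advantages(rewards, eps=1e-6):
    """Axis 1 (assumed GRPO-style): group-normalised scalar per sequence."""
    mu, sd = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sd + eps) for r in rewards]


def log_ratios(logp, old_logp, is_level):
    """Axis 2b: token log-ratios, or the GSPO-style sequence log-ratio
    (mean of token log-ratios) broadcast back to every token."""
    out = []
    for seq_lp, seq_old in zip(logp, old_logp):
        tok = [a - b for a, b in zip(seq_lp, seq_old)]
        if is_level == "token":
            out.append(tok)
        elif is_level == "sequence":
            out.append([mean(tok)] * len(tok))
        else:
            raise ValueError(f"unknown is_level {is_level!r}")
    return out


def ppo_token_loss(lr, adv, logp, eps):
    """Axis 2, ratio form with gating-clip."""
    r = math.exp(lr)
    return -min(r * adv, min(max(r, 1 - eps), 1 + eps) * adv)


def cispo_token_loss(lr, adv, logp, eps):
    """Axis 2, forward value of -sg(min(r, 1+eps)) * A * log pi (upper truncation only)."""
    return -min(math.exp(lr), 1 + eps) * adv * logp


POLICY_LOSSES = {"ppo": ppo_token_loss, "cispo": cispo_token_loss}


def total_loss(rewards, logp, old_logp, rho, policy_loss="ppo", is_level="token",
               eps=0.2, regularizers=()):
    """Compose the axes; `regularizers` is a sequence of (coef, term_value) pairs."""
    advs = grpo_advantages(rewards)                     # axis 1
    lrs = log_ratios(logp, old_logp, is_level)          # axis 2b
    surrogate = POLICY_LOSSES[policy_loss]              # axis 2
    per_token = []
    for seq_lr, seq_lp, seq_rho, adv in zip(lrs, logp, rho, advs):
        for lr, lp, w in zip(seq_lr, seq_lp, seq_rho):
            per_token.append(w * surrogate(lr, adv, lp, eps))  # correction: detached rho
    pg = sum(per_token) / len(per_token)                # axis 3: token-mean
    return pg + sum(c * t for c, t in regularizers)     # axis 4


rewards = [1.0, 0.0]
logp = [[-0.5, -1.0], [-0.7]]
old_logp = [[-0.6, -1.0], [-0.7]]
rho = [[1.0, 1.0], [1.0]]
for pl in ("ppo", "cispo"):
    for lvl in ("token", "sequence"):
        print(pl, lvl, round(total_loss(rewards, logp, old_logp, rho, pl, lvl), 4))
```

Adding a surrogate in this sketch is one function plus one POLICY_LOSSES entry; it then composes with both granularities, the ρ correction, and any regularizer without touching total_loss.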

How existing work maps onto the axes

| PR | Axis | Status |
|---|---|---|
| THUDM#2067 CISPO | 2 (surrogate) | merged; currently filed under --advantage-estimator (the keystone PR re-homes it to --policy-loss) |
| THUDM#2083 REINFORCE | 2 (surrogate) | open; joins --policy-loss when rebased onto the keystone |
| THUDM#2084 TIS hook | correction | open; already applied outside the surrogate ✅ |
| THUDM#2090 loss-aggregation | 3 (aggregation) | open; already its own axis ✅ |
| THUDM#2060 pg-loss-divisor | 3 (aggregation) | landed/closed |
| THUDM#2085 OPD temperature | 4 (regularizer) | open |

The RFC asks for nothing to be reverted — it gives these increments a shared frame and stops the next one from deepening the overload.

Proposed increments (each standalone and behaviour-preserving)

  1. Split the surrogate and granularity axes out of --advantage-estimator (the keystone, implemented in the accompanying PR): add --policy-loss {ppo,cispo} and --is-level {token,sequence}; route the surrogate/granularity dispatch off these instead of --advantage-estimator; keep a one-release deprecation shim mapping the legacy values (gspo → --is-level sequence, cispo → --policy-loss cispo). Dispatch-only: the loss math is untouched and legacy configs reproduce identically (unit tests cover the shim). REINFORCE (feat(rl): add REINFORCE advantage estimator, THUDM/slime#2083) becomes a --policy-loss reinforce value when it rebases onto this.
  2. (optional) register_policy_loss registry — make the surrogate pluggable (decorator + load_function fallthrough, argparse choices from the registry), matching the existing --custom-*-function-path style.
  3. (optional) decoupled reference structure: --reference-structure {coupled,decoupled}, so the surrogate denominator can be the behaviour policy μ while a proximal anchor π_prox centres the trust region; composes with the existing TIS correction. (slime is coupled-only today: loss.py:908.)
  4. (optional) additive regularizer assembly — collect KL-to-ref / entropy / OPD as additive terms rather than scattered sites; keep OPD's current advantage-shaping form selectable (see fork below).
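A sketch of increments 1 and 2 together, assuming parsed flags arrive as a plain dict. register_policy_loss is the name proposed above; POLICY_LOSS_REGISTRY, resolve_policy_loss, LEGACY_ESTIMATORS and apply_deprecation_shim are hypothetical, the surrogate bodies are simplified, and the importlib fallthrough only stands in for load_function rather than reproducing it. The legacy targets follow the evidence above that gspo and cispo both use the GRPO advantage.

```python
import importlib
import math
import warnings

# Hypothetical registry for increment 2: surrogate name -> per-token loss fn.
POLICY_LOSS_REGISTRY = {}


def register_policy_loss(name):
    """Decorator that adds a surrogate under `name` (duplicates are rejected)."""
    def deco(fn):
        if name in POLICY_LOSS_REGISTRY:
            raise ValueError(f"policy loss {name!r} is already registered")
        POLICY_LOSS_REGISTRY[name] = fn
        return fn
    return deco


@register_policy_loss("ppo")
def ppo_loss(lr, adv, logp, eps=0.2):
    r = math.exp(lr)
    return -min(r * adv, min(max(r, 1 - eps), 1 + eps) * adv)


@register_policy_loss("cispo")
def cispo_loss(lr, adv, logp, eps=0.2):
    return -min(math.exp(lr), 1 + eps) * adv * logp


def resolve_policy_loss(name):
    """Registry lookup, falling through to a dotted 'module.attr' import path."""
    if name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[name]
    module_path, _, attr = name.rpartition(".")
    if not module_path:
        raise KeyError(f"unknown policy loss {name!r}; known: {sorted(POLICY_LOSS_REGISTRY)}")
    return getattr(importlib.import_module(module_path), attr)


# Increment 1 shim: legacy --advantage-estimator values that really selected a
# surrogate or a granularity; both used the GRPO advantage.
LEGACY_ESTIMATORS = {
    "gspo": {"advantage_estimator": "grpo", "is_level": "sequence"},
    "cispo": {"advantage_estimator": "grpo", "policy_loss": "cispo"},
}


def apply_deprecation_shim(args):
    """Return a copy of `args` with a legacy estimator value re-homed (warn + auto-map)."""
    legacy = args.get("advantage_estimator")
    if legacy not in LEGACY_ESTIMATORS:
        return dict(args)
    mapped = {**args, **LEGACY_ESTIMATORS[legacy]}
    warnings.warn(
        f"--advantage-estimator {legacy} is deprecated; use --advantage-estimator grpo "
        f"--policy-loss {mapped.get('policy_loss')} --is-level {mapped.get('is_level')}",
        DeprecationWarning,
        stacklevel=2,
    )
    return mapped


print(sorted(POLICY_LOSS_REGISTRY))  # argparse choices could be built from this
print(apply_deprecation_shim({"advantage_estimator": "gspo", "policy_loss": "ppo", "is_level": "token"}))
```

How the shim should treat a conflicting explicit flag (e.g. --advantage-estimator cispo --policy-loss ppo) is left open; in this sketch the legacy mapping simply wins.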

Open questions / forks

  1. Deprecation vs hard break. Re-homing gspo/cispo out of --advantage-estimator changes existing configs. The keystone PR uses a one-release deprecation shim (warn + auto-map); the alternative is a clean break with a migration note. Recommend the shim (implemented).
  2. OPD placement. OPD today shapes the advantage (gradient flows via the detached old-logprob ratio); moving it to an additive regularizer changes the gradient path and thus the objective. Keep advantage-shaping OPD as a selectable mode rather than silently changing the recipe.
  3. Granularity ceiling. A closed {token,sequence} enum can't express GMPO-style geometric ratios without a third value or an escape hatch. Accept the closed enum now, or add --custom-is-level?

Non-goals

  • No change to the Megatron/CP machinery, the reducer internals, value/SFT losses, or any numerics of existing presets (the keystone is dispatch-only; legacy presets reproduce identically, with OPD the one called-out exception if axis 4 is ever re-homed).
  • No new algorithm in this RFC — only the structure that makes adding them composition rather than branching.
