fix(loss): first-class --pg-loss-divisor for Dr.GRPO#1328
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a dynamic, constant-divisor policy-gradient loss reducer for the Dr.GRPO algorithm, replacing a hardcoded divisor with a configurable --drgrpo-divisor argument and adding support for context parallelism under the "thd" layout. Feedback on the changes suggests raising a NotImplementedError when context parallelism is used with the unsupported "bshd" format to prevent silent correctness issues, and refactoring a unit test with monkeypatch to ensure it runs deterministically across all environments.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| cp_size = get_parallel_state().cp.size | ||
|
|
||
| if cp_size == 1: | ||
| local_loss_masks = loss_masks | ||
| else: | ||
| local_loss_masks = slice_loss_masks_for_local_cp(loss_masks, total_lengths, response_lengths) |
There was a problem hiding this comment.
If cp_size > 1 and qkv_format == "bshd", the custom reducer will silently fall back to using "thd" slicing because slice_loss_masks_for_local_cp defaults to "thd". This can lead to silent correctness issues or crashes. Since max_seq_lens is not forwarded to the custom reducer factory, context parallelism with "bshd" format is fundamentally unsupported. We should raise a NotImplementedError to fail loud.
cp_size = get_parallel_state().cp.size
if cp_size == 1:
local_loss_masks = loss_masks
else:
qkv_format = getattr(resolved_args, "qkv_format", "thd")
if qkv_format == "bshd":
raise NotImplementedError(
"get_pg_loss_reducer (Dr.GRPO) does not support cp_size > 1 with qkv_format='bshd' "
"because max_seq_lens is not available at the reducer factory call site."
)
local_loss_masks = slice_loss_masks_for_local_cp(loss_masks, total_lengths, response_lengths)There was a problem hiding this comment.
Good catch, accepted. Confirmed the silent-mis-slice is real: this reducer's CP branch calls slice_loss_masks_for_local_cp(loss_masks, total_lengths, response_lengths) with qkv_format defaulting to 'thd' and max_seq_lens=None, while cp_utils.slice_with_cp derives the zigzag chunk_size from token_len for thd but from the padded max_seq_len for bshd — so a bshd+CP run would be chunked with the wrong offsets. Since the 4-positional call site in loss_hub/losses.py forwards no max_seq_lens, bshd CP can't be supported here yet (it's the documented follow-up). I've made it fail loud: cp_size>1 with args.qkv_format != 'thd' now raises NotImplementedError with a message pointing to qkv_format='thd', matching the fail-loud contract. Added test_bshd_under_cp_raises_not_implemented to pin it, and updated the module/factory docstrings.
| def test_missing_args_raises(): | ||
| """No args available (no Megatron globals) fails loud rather than guessing. | ||
|
|
||
| Asserts the documented CPU-only behaviour: when Megatron is not importable, | ||
| ``_resolve_args(None)`` returns ``None`` and ``_resolve_divisor`` raises | ||
| ``ValueError`` instead of silently falling back to a default divisor. Skip | ||
| where Megatron's global args happen to be importable, since ``get_args()`` | ||
| behaviour there is environment-dependent and exercised on GPU CI. | ||
| """ | ||
| try: | ||
| import megatron.training.global_vars # noqa: F401 | ||
|
|
||
| pytest.skip("Megatron is importable; missing-args fallback path is GPU-CI only.") | ||
| except ImportError: | ||
| pass | ||
|
|
||
| response_lengths = [2] | ||
| total_lengths = [3] | ||
| loss_masks = [torch.ones(2, dtype=torch.float32)] | ||
| with pytest.raises(ValueError): | ||
| get_pg_loss_reducer(total_lengths, response_lengths, loss_masks) |
There was a problem hiding this comment.
The test test_missing_args_raises currently skips itself if Megatron is importable, which reduces test coverage in environments where Megatron is installed (such as GPU CI). We can make this test deterministic and run in all environments by using monkeypatch to mock _resolve_args to return None.
| def test_missing_args_raises(): | |
| """No args available (no Megatron globals) fails loud rather than guessing. | |
| Asserts the documented CPU-only behaviour: when Megatron is not importable, | |
| ``_resolve_args(None)`` returns ``None`` and ``_resolve_divisor`` raises | |
| ``ValueError`` instead of silently falling back to a default divisor. Skip | |
| where Megatron's global args happen to be importable, since ``get_args()`` | |
| behaviour there is environment-dependent and exercised on GPU CI. | |
| """ | |
| try: | |
| import megatron.training.global_vars # noqa: F401 | |
| pytest.skip("Megatron is importable; missing-args fallback path is GPU-CI only.") | |
| except ImportError: | |
| pass | |
| response_lengths = [2] | |
| total_lengths = [3] | |
| loss_masks = [torch.ones(2, dtype=torch.float32)] | |
| with pytest.raises(ValueError): | |
| get_pg_loss_reducer(total_lengths, response_lengths, loss_masks) | |
| def test_missing_args_raises(monkeypatch): | |
| """No args available (no Megatron globals) fails loud rather than guessing.""" | |
| monkeypatch.setattr(custom_reducer, "_resolve_args", lambda args: None) | |
| response_lengths = [2] | |
| total_lengths = [3] | |
| loss_masks = [torch.ones(2, dtype=torch.float32)] | |
| with pytest.raises(ValueError, match="cannot read '--drgrpo-divisor'"): | |
| get_pg_loss_reducer(total_lengths, response_lengths, loss_masks) |
There was a problem hiding this comment.
Accepted. Replaced the import megatron ... pytest.skip(...) guard with monkeypatch.setattr(custom_reducer, '_resolve_args', lambda args: None), so the test now exercises the fail-loud path deterministically everywhere (including GPU CI, where it used to skip) rather than depending on whether Megatron globals happen to be importable. The intent is unchanged — it still asserts that an unresolvable-args reducer construction raises ValueError via _resolve_divisor. _resolve_args is a module-level symbol, monkeypatched the same way the suite already stubs get_parallel_state. 12 tests pass.
87805a0 to
7434ac8
Compare
The Dr.GRPO example reducer divided pg_loss by a module-level constant DIVISOR = 1000.0. Per the paper (arXiv:2503.20783, Eq. 2) that denominator should be the model's max context length (e.g. ~40960 for DeepSWE), so every realistic run silently rescaled the policy gradient, and the only way to change it was editing example source. The example also asserted cp_size == 1 and imported megatron.core.mpu, so it crashed under context parallelism and could not run on the FSDP backend. Promote the constant-divisor normalization into core loss aggregation instead of patching the example: - `get_sum_of_sample_mean` gains `divisor: float | None = None`; when set, each sample's masked token sum is divided by the constant instead of its active-token count. A constant denominator factors out of the per-sample sum, so the divisor mode reuses the existing CP-aware token-sum closure: CP-correct by construction, no extra denominator allreduce. - `policy_loss_function` builds the pg_loss reducer from `--pg-loss-divisor` when set; all other metrics keep the default reducer. `calculate_per_token_loss` handling is preserved: the token-sum path never applies the divisor (Megatron normalizes by token count itself). Both Megatron and FSDP backends share this path. - `--pg-loss-divisor` (float, default None = behavior unchanged) is validated at startup: non-positive or NaN values fail loud, as does combining it with --custom-pg-loss-reducer-function-path (which would silently ignore the divisor). - examples/DrGRPO shrinks to flag-based usage docs; the example reducer and its hardcoded constant are deleted. - tests/fast/backends/training_utils/test_pg_loss_divisor.py pins the contract: default identity, constant-mode values, per-token-loss exemption, CP rank-sum invariance, startup validation, and end-to-end wiring through policy_loss_function. Auto-registers in stage-a-cpu. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
7434ac8 to
a120dec
Compare
|
Superseded by #1350, which adds |
Add --loss-aggregation {sample_mean,prompt_mean,token_mean,constant}
(+ --loss-aggregation-divisor L for constant) selecting how pg_loss is
aggregated across a training step. The aggregation rescopes pg_loss only,
reusing the same seam as --custom-pg-loss-reducer-function-path: every
diagnostic metric (pg_clipfrac, ppo_kl, entropy_loss, kl_loss) keeps the
default sample-mean reducer so it stays interpretable and comparable.
Modes follow the ScaleRL taxonomy (arXiv:2510.13786 section 3.2):
- sample_mean (default, GRPO): per-rollout token-weighted mean. Rides the
sample_denoms seam in get_sum_of_sample_mean; byte-identical to the prior
default (the snapshot suite and a frozen-reducer oracle pin this).
- prompt_mean (DAPO): per-prompt-group token-weighted mean via step-level
sample_denoms grouped by Sample.group_index, plumbed like loss_masks so it
is CP- and DP-correct.
- token_mean: aliased onto --calculate-per-token-loss at validate time, so
the whole loss-scaling/reporting path stays per-token.
- constant (Dr.GRPO, arXiv:2503.20783): masked token sum / L via a new
constant_divisor branch in cp_utils. Subsumes radixark#1328's --pg-loss-divisor.
L is validated > 0 at startup only for constant. The custom reducer hook
still takes precedence over --loss-aggregation.
Supersedes radixark#1328.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add --loss-aggregation {sample_mean,prompt_mean,token_mean,constant}
(+ --loss-aggregation-divisor L for constant) selecting how pg_loss is
aggregated across a training step. The aggregation rescopes pg_loss only,
reusing the same seam as --custom-pg-loss-reducer-function-path: every
diagnostic metric (pg_clipfrac, ppo_kl, entropy_loss, kl_loss) keeps the
default sample-mean reducer so it stays interpretable and comparable.
Modes follow the ScaleRL taxonomy (arXiv:2510.13786 section 3.2):
- sample_mean (default, GRPO): per-rollout token-weighted mean. Rides the
sample_denoms seam in get_sum_of_sample_mean; byte-identical to the prior
default (the snapshot suite and a frozen-reducer oracle pin this).
- prompt_mean (DAPO): per-prompt-group token-weighted mean via step-level
sample_denoms grouped by Sample.group_index, plumbed like loss_masks so it
is CP- and DP-correct.
- token_mean: aliased onto --calculate-per-token-loss at validate time, so
the whole loss-scaling/reporting path stays per-token.
- constant (Dr.GRPO, arXiv:2503.20783): masked token sum / L via a new
constant_divisor branch in cp_utils. Subsumes radixark#1328's --pg-loss-divisor.
Fail-loud guards (each combination would otherwise silently mis-normalize):
- constant requires --loss-aggregation-divisor L > 0, validated at startup.
- constant and prompt_mean each reject --calculate-per-token-loss at startup:
the per-token path forces the outer Megatron average + metrics onto
/num_tokens, which would renormalize away the fixed L (constant) or the
per-group denominator (prompt_mean).
- prompt_mean requires a single training step per rollout: a step is a
contiguous slice, so with >1 step a prompt group can straddle the boundary
and be normalized against a partially-present group total. get_data_iterator
fails loud when num_steps_per_rollout > 1 (dynamic GBS already pins it to 1).
- prompt_mean fails loud if sample_denoms are missing (e.g. a custom convert
path dropped them) rather than degrading to a per-sample mean.
The custom reducer hook still takes precedence over --loss-aggregation.
The constant divisor L combines with the standard /global_batch_size step
average for an effective /(L * global_batch_size), the same outer structure as
every other mode; L sets the data-independent per-token scale.
Supersedes radixark#1328.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add --loss-aggregation {sample_mean,prompt_mean,token_mean,constant}
(+ --loss-aggregation-divisor L for constant) selecting how pg_loss is
aggregated across a training step. The aggregation rescopes pg_loss only,
reusing the same seam as --custom-pg-loss-reducer-function-path: every
diagnostic metric (pg_clipfrac, ppo_kl, entropy_loss, kl_loss) keeps the
default sample-mean reducer so it stays interpretable and comparable.
Modes follow the ScaleRL taxonomy (arXiv:2510.13786 section 3.2):
- sample_mean (default, GRPO): per-rollout token-weighted mean. Rides the
sample_denoms seam in get_sum_of_sample_mean; byte-identical to the prior
default (the snapshot suite and a frozen-reducer oracle pin this).
- prompt_mean (DAPO): per-prompt-group token-weighted mean via step-level
sample_denoms grouped by Sample.group_index, plumbed like loss_masks so it
is CP- and DP-correct.
- token_mean: the per-token global mean; the canonical spelling of the
legacy --calculate-per-token-loss flag (get_pg_loss_reducer returns the
default reducer for both, and that reducer is the per-token sum exactly when
--calculate-per-token-loss is set), so the two are one axis, not two.
- constant (Dr.GRPO, arXiv:2503.20783): masked token sum / L via a new
constant_divisor branch in cp_utils. Subsumes radixark#1328's --pg-loss-divisor.
--calculate-per-token-loss is kept as the backward-compatible alias for
token_mean (existing Megatron-style recipes rely on it). _validate_loss_-
aggregation_args reconciles the two spellings onto one axis at startup:
token_mean sets the flag (forward), and the flag with the default sample_mean
is relabeled to token_mean (reverse) -- no behavior change, since sample_mean
already used the per-token default reducer when the flag was set; this just
removes the silent sample_mean->token_mean surprise so the reported objective
is honest. Docs steer new recipes to --loss-aggregation token_mean.
Fail-loud guards (each combination would otherwise silently mis-normalize):
- constant requires --loss-aggregation-divisor L > 0, validated at startup.
- constant and prompt_mean each reject --calculate-per-token-loss at startup:
the per-token path forces the outer Megatron average + metrics onto
/num_tokens, which would renormalize away the fixed L (constant) or the
per-group denominator (prompt_mean).
- prompt_mean requires a single training step per rollout: a step is a
contiguous slice, so with >1 step a prompt group can straddle the boundary
and be normalized against a partially-present group total. get_data_iterator
fails loud when num_steps_per_rollout > 1 (dynamic GBS already pins it to 1).
- prompt_mean fails loud if sample_denoms are missing (e.g. a custom convert
path dropped them) rather than degrading to a per-sample mean.
The custom reducer hook still takes precedence over --loss-aggregation.
The constant divisor L combines with the standard /global_batch_size step
average for an effective /(L * global_batch_size), the same outer structure as
every other mode; L sets the data-independent per-token scale.
Supersedes radixark#1328.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add --loss-aggregation {sample_mean,prompt_mean,token_mean,constant}
(+ --loss-aggregation-divisor L for constant) selecting how pg_loss is
aggregated across a training step. The aggregation rescopes pg_loss only,
reusing the same seam as --custom-pg-loss-reducer-function-path: every
diagnostic metric (pg_clipfrac, ppo_kl, entropy_loss, kl_loss) keeps the
default sample-mean reducer so it stays interpretable and comparable.
Modes follow the ScaleRL taxonomy (arXiv:2510.13786 section 3.2):
- sample_mean (default, GRPO): per-rollout token-weighted mean. Rides the
sample_denoms seam in get_sum_of_sample_mean; byte-identical to the prior
default (the snapshot suite and a frozen-reducer oracle pin this).
- prompt_mean (DAPO): per-prompt-group token-weighted mean via step-level
sample_denoms grouped by Sample.group_index, plumbed like loss_masks so it
is CP- and DP-correct.
- token_mean: the per-token global mean; the canonical spelling of the
legacy --calculate-per-token-loss flag (get_pg_loss_reducer returns the
default reducer for both, and that reducer is the per-token sum exactly when
--calculate-per-token-loss is set), so the two are one axis, not two.
- constant (Dr.GRPO, arXiv:2503.20783): masked token sum / L via a new
constant_divisor branch in cp_utils. Subsumes radixark#1328's --pg-loss-divisor.
--calculate-per-token-loss is kept as the backward-compatible alias for
token_mean (existing Megatron-style recipes rely on it). _validate_loss_-
aggregation_args reconciles the two spellings onto one axis at startup:
token_mean sets the flag (forward), and the flag with the default sample_mean
is relabeled to token_mean (reverse) -- no behavior change, since sample_mean
already used the per-token default reducer when the flag was set; this just
removes the silent sample_mean->token_mean surprise so the reported objective
is honest. Docs steer new recipes to --loss-aggregation token_mean.
Fail-loud guards (each combination would otherwise silently mis-normalize):
- constant requires --loss-aggregation-divisor L > 0, validated at startup.
- constant and prompt_mean each reject --calculate-per-token-loss at startup:
the per-token path forces the outer Megatron average + metrics onto
/num_tokens, which would renormalize away the fixed L (constant) or the
per-group denominator (prompt_mean).
- prompt_mean requires a single training step per rollout: a step is a
contiguous slice, so with >1 step a prompt group can straddle the boundary
and be normalized against a partially-present group total. get_data_iterator
fails loud when num_steps_per_rollout > 1 (dynamic GBS already pins it to 1).
- prompt_mean fails loud if sample_denoms are missing (e.g. a custom convert
path dropped them) rather than degrading to a per-sample mean.
The custom reducer hook still takes precedence over --loss-aggregation.
The constant divisor L combines with the standard /global_batch_size step
average for an effective /(L * global_batch_size), the same outer structure as
every other mode; L sets the data-independent per-token scale.
Supersedes radixark#1328.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Problem
The Dr.GRPO example reducer in
examples/DrGRPO/custom_reducer.pydividedpg_lossby a module-level constantDIVISOR = 1000.0hardcoded in the source. Per the Dr.GRPO paper (arXiv:2503.20783, Eq. 2), that denominator is supposed to be the model's max context length (DeepSWE uses ~40960). With the hardcoded1000, every realistic run silently rescaled the policy gradient by the wrong constant, and the only way to change it was editing the example source. There was no validation, so the misconfiguration was undetectable — training just proceeded with a wrong loss scale. The reducer alsoassertedcp_size == 1(crashed under context parallelism) and importedmegatron.core.mpu(unusable on the FSDP backend).Before vs After
Same micro-batch — one sample,
x = [1.0, 2.0, 3.0], all-onesloss_mask, intended max context length4(toy value; real runs use e.g.40960):Before — the divisor is the source constant
1000.0; there is no flag:After — the divisor is a first-class option of the built-in reducer:
--pg-loss-divisor 4 # = the model's max context length in real runsWith the flag unset (the default), behavior is byte-identical to today: the per-sample active-token mean,
6 / 3 = 2.0here. Dr.GRPO deliberately replaces that data-dependent denominator with the constant (that is the whole point of Eq. 2). A misconfigured value now fails loud at startup instead of training silently:Fix
Promote the constant-divisor normalization into core loss aggregation instead of patching the example — it is one normalization option of the existing reducer, not a separate algorithm implementation:
get_sum_of_sample_mean(miles/backends/training_utils/cp_utils.py) gainsdivisor: float | None = None. When set, each sample's masked token sum is divided by the constant instead of its active-token count. A constant denominator factors out of the per-sample sum, so the divisor mode is simply the existing CP-aware token-sum closure scaled by1/divisor— no duplicated mask/chunking logic.policy_loss_function(miles/backends/training_utils/loss_hub/losses.py) builds the pg_loss reducer from--pg-loss-divisorwhen set. Onlypg_lossis affected; every other metric (pg_clipfrac,ppo_kl,entropy_loss, …) keeps the defaultsum_of_sample_mean, and under TIS/RS the divisor reducer uses the same modified masks as the existing custom-reducer path.--pg-loss-divisor(float, default None = current behavior, byte-identical) is added next to--custom-pg-loss-reducer-function-pathinmiles/utils/arguments.py.miles_validate_argsrejects non-positive (and NaN) values at startup, and rejects combining the flag with--custom-pg-loss-reducer-function-path(the custom hook fully replaces the reduction, so the divisor would be silently ignored — the exact failure mode this PR removes).calculate_per_token_losshandling is preserved: the token-sum path never applies the divisor (Megatron applies its own/num_tokens), avoiding double normalization.examples/DrGRPOshrinks to flag-based usage docs; the example reducer and its hardcoded constant are deleted, and the stale references indocs/user-guide/customization.md/examples/README.mdare updated.Why this is the right fix
cp_size == 1value:sum_over_ranks( sum(x_local * mask_local) / divisor ) == sum(x_full * mask_full) / divisor. The old example simply asserted CP away.loss_function→policy_loss_function, so the option works on both; the old example importedmegatron.core.mpuand was Megatron-only.get_sum_of_sample_meanreturns the exact same closures as before; the default loss path and all metrics are untouched.tests/fast/backends/training_utils/test_pg_loss_divisor.py(auto-registers in thestage-a-cpusuite by directory convention, CPU-only) pins the contract: default identity (divisor unset = per-sample active-token mean); constant-mode values (one shared denominator, masked-out tokens drop from the numerator only); per-token-loss exemption; CP rank-sum invariance (per-rank chunked contributions sum back to the single-rank value); fail-loud startup validation; and end-to-end wiring throughpolicy_loss_function(pg_loss scales byL/Dwhileppo_kl/pg_clipfracare unchanged). Each test was verified to fail on the pre-fix code or on targeted breakages (dropped wiring, multiply-instead-of-divide, divisor leaking into the per-token path, removed validation).