[megatron] fix: return 3-tuple under calculate_per_token_loss to fix MoE aux/z-loss grad blowup at CP>1#6836
Conversation
…MoE aux/z-loss grad blowup at CP>1
There was a problem hiding this comment.
Code Review
This pull request addresses an issue where the MoE router z-loss gradient blows up under context parallelism by properly supporting Megatron's per-token loss regime. It calculates the global routed-token count in forward_backward_batch and updates postprocess_micro_batch_func to return the required 3-tuple (local_sum, local_num_tokens, output) when calculate_per_token_loss is enabled. A test is also added to verify the fix. The review feedback correctly identifies a potential KeyError in forward_backward_batch when attention_mask is missing from data, and suggests falling back to response_mask to prevent runtime failures.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
What does this PR do?
Fixes CP correctness bug in Megatron + Verl interaction.
When you run CP>1, Megatron-Bridge silently enables a per-token loss mode (
calculate_per_token_loss=True), which expects a different normalization contract. That makes sense, the sequence is now distributed over ranks, so normalization has to be done carefully. In that mode, you are supposed to return thenum_tokensfor each micro-batch part (the non-masked tokens) to Megatron, so it can normalize correctly across ranks. verl did not do this, so Megatron saw 0 total tokens and skipped normalization → grad_norm explodes. Nothing in the stack warns you this is happening.The grad_norm explosion only showed up in the z/aux losses, because verl doesn't normalize those itself, that's Megatron's job. The "normal" (policy) losses look fine in terms of grad_norm because verl normalizes them itself, but they are also silently slightly wrong for any heterogeneous CP>1 — they get normalized as if tokens were uniformly distributed across CP ranks, since verl never passed
num_tokens.Fix: when
calculate_per_token_lossis enabled, the Megatron engine's loss path returns the 3-tuple(loss_sum, num_tokens, output)Megatron's pipeline schedule expects, using a globally all-reduced routed-token count so the reduction reproduces the intended per-token mean. This affects any CP>1 + Megatron run in verl.Closes #6609.
Checklist Before Starting
is:pr calculate_per_token_loss(none open) — the bug is tracked in [Megatron][MoE] grad_norm explosion under context parallel (CP>1): loss_func must return 3-tuple when calculate_per_token_loss=True #6609, which has no PR attached. The issue's suggested location (verl/workers/actor/megatron_actor.py) no longer exists after the worker→engine migration; this fix targets the live engine path.[{modules}] {type}: {description}→[megatron] fix: ...Test
Regression test (
tests/models/test_moe_zloss_per_token_loss.py): runs one PPO update through a tiny Qwen3-MoE Megatron engine twice,calculate_per_token_lossFalse vs True — and assertsgrad_normis invariant to the flag. Pre-fix,Trueblowsgrad_normup by ~10⁴; post-fix the two match. Also verified in production runs.API and Usage Example
No CLI/config signature changes. When the per-token regime is active (CP>1), it adds two fail-closed guards that raise a clear
ValueError:loss_agg_mode='seq-mean-token-mean'is rejected. The fix scales the loss by the global token count and lets Megatron'sfinalize_model_gradsdivide it back out — an exact cancel that preserves whatever normalization verl already applied. That holds fortoken-mean,seq-mean-token-sum, andseq-mean-token-sum-norm. It does not hold forseq-mean-token-mean: its inner per-sequence1/n_sdivides by a CP-local shard token count, which the outer cancel can't recover. Use one of the other three (default istoken-mean).use_remove_padding=False(BSHD) is rejected. Verl doesn't pass apadding_maskto the MoE router, so it normalizes the aux/z loss by the padding-inclusive token count (logits.shape[0] = B*S) while gradients are divided by the real token count — a padding-ratio mis-normalization. THD packs the padding, sologits.shape[0]already equals the actual token count. Use THD (use_remove_padding=True) or disable CP.Design & Code Changes
verl/workers/engine/megatron/transformer_impl.py:forward_backward_batch: whencalculate_per_token_lossis set, all-reduceattention_mask.sum()over the DP group and change it asrouted_num_tokens(the global router/attention-scope token count, distinct from the existing loss-tokenbatch_num_tokens, which is why Guard 2 requires THD).postprocess_micro_batch_func: in the per-token regime, return(loss * routed_num_tokens / dp_size, local_num_tokens, output)so Megatron'sfinalize_model_gradsdivides by the accumulated global token count and the aux/z×num_tokenspre-scaling cancels. The legacy 2-tuple path is unchanged forcalculate_per_token_loss=False.Things we could remove from this PR / feedback
transformer_impl.pyare a bit verbose. The normalization contract is subtle (silent flag flip, cross-rank cancellation), so we kept a thorough rationale inline, but we can remove if we prefer.Checklist Before Submitting
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel.recipesubmodule; n/a.