Skip to content

[megatron] fix: return 3-tuple under calculate_per_token_loss to fix MoE aux/z-loss grad blowup at CP>1#6836

Open
EricMarcus-ai wants to merge 2 commits into
verl-project:mainfrom
kaiko-ai:fix/megatron-moe-zloss-per-token-cp
Open

[megatron] fix: return 3-tuple under calculate_per_token_loss to fix MoE aux/z-loss grad blowup at CP>1#6836
EricMarcus-ai wants to merge 2 commits into
verl-project:mainfrom
kaiko-ai:fix/megatron-moe-zloss-per-token-cp

Conversation

@EricMarcus-ai

@EricMarcus-ai EricMarcus-ai commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Fixes CP correctness bug in Megatron + Verl interaction.

When you run CP>1, Megatron-Bridge silently enables a per-token loss mode (calculate_per_token_loss=True), which expects a different normalization contract. That makes sense, the sequence is now distributed over ranks, so normalization has to be done carefully. In that mode, you are supposed to return the num_tokens for each micro-batch part (the non-masked tokens) to Megatron, so it can normalize correctly across ranks. verl did not do this, so Megatron saw 0 total tokens and skipped normalization → grad_norm explodes. Nothing in the stack warns you this is happening.

The grad_norm explosion only showed up in the z/aux losses, because verl doesn't normalize those itself, that's Megatron's job. The "normal" (policy) losses look fine in terms of grad_norm because verl normalizes them itself, but they are also silently slightly wrong for any heterogeneous CP>1 — they get normalized as if tokens were uniformly distributed across CP ranks, since verl never passed num_tokens.

Fix: when calculate_per_token_loss is enabled, the Megatron engine's loss path returns the 3-tuple (loss_sum, num_tokens, output) Megatron's pipeline schedule expects, using a globally all-reduced routed-token count so the reduction reproduces the intended per-token mean. This affects any CP>1 + Megatron run in verl.

Closes #6609.

Checklist Before Starting

Test

Regression test (tests/models/test_moe_zloss_per_token_loss.py): runs one PPO update through a tiny Qwen3-MoE Megatron engine twice,calculate_per_token_loss False vs True — and asserts grad_norm is invariant to the flag. Pre-fix, True blows grad_norm up by ~10⁴; post-fix the two match. Also verified in production runs.

API and Usage Example

No CLI/config signature changes. When the per-token regime is active (CP>1), it adds two fail-closed guards that raise a clear ValueError:

  • loss_agg_mode='seq-mean-token-mean' is rejected. The fix scales the loss by the global token count and lets Megatron's finalize_model_grads divide it back out — an exact cancel that preserves whatever normalization verl already applied. That holds for token-mean, seq-mean-token-sum, and seq-mean-token-sum-norm. It does not hold for seq-mean-token-mean: its inner per-sequence 1/n_s divides by a CP-local shard token count, which the outer cancel can't recover. Use one of the other three (default is token-mean).
  • use_remove_padding=False (BSHD) is rejected. Verl doesn't pass a padding_mask to the MoE router, so it normalizes the aux/z loss by the padding-inclusive token count (logits.shape[0] = B*S) while gradients are divided by the real token count — a padding-ratio mis-normalization. THD packs the padding, so logits.shape[0] already equals the actual token count. Use THD (use_remove_padding=True) or disable CP.

Design & Code Changes

verl/workers/engine/megatron/transformer_impl.py:

  • forward_backward_batch: when calculate_per_token_loss is set, all-reduce attention_mask.sum() over the DP group and change it as routed_num_tokens (the global router/attention-scope token count, distinct from the existing loss-token batch_num_tokens, which is why Guard 2 requires THD).
  • postprocess_micro_batch_func: in the per-token regime, return (loss * routed_num_tokens / dp_size, local_num_tokens, output) so Megatron's finalize_model_grads divides by the accumulated global token count and the aux/z ×num_tokens pre-scaling cancels. The legacy 2-tuple path is unchanged for calculate_per_token_loss=False.

Things we could remove from this PR / feedback

  • The regression test isn't strictly necessary. It exists to demonstrate and lock the invariant; it needs a GPU + a local tokenizer, so if we prefer not to carry a GPU unit test for this, we can drop it.
  • The explanatory comments in transformer_impl.py are a bit verbose. The normalization contract is subtle (silent flag flip, cross-rank cancellation), so we kept a thorough rationale inline, but we can remove if we prefer.

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation — n/a (no user-facing config/API change).
  • Add unit or end-to-end test(s) — single-GPU regression test added; CP-correctness validated by the production run (a multi-GPU CP test would need a multi-rank fixture).
  • Once your PR is ready for CI, send a message in the ci-request channel.
  • If your PR is related to the recipe submodule; n/a.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses an issue where the MoE router z-loss gradient blows up under context parallelism by properly supporting Megatron's per-token loss regime. It calculates the global routed-token count in forward_backward_batch and updates postprocess_micro_batch_func to return the required 3-tuple (local_sum, local_num_tokens, output) when calculate_per_token_loss is enabled. A test is also added to verify the fix. The review feedback correctly identifies a potential KeyError in forward_backward_batch when attention_mask is missing from data, and suggests falling back to response_mask to prevent runtime failures.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread verl/workers/engine/megatron/transformer_impl.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Megatron][MoE] grad_norm explosion under context parallel (CP>1): loss_func must return 3-tuple when calculate_per_token_loss=True

1 participant