[Feature] Layer-wise MoE auxiliary loss (split finalize) and optional async router D2H offload #1528
Open
tina-wen wants to merge 12 commits into InternLM:main from
Conversation
pppppM
reviewed
Mar 3, 2026
Force-pushed fb3ae25 to 95e9d62
Collaborator
@claude review
Contributor
Claude:
## Summary
Adds Domino EP support and layer-wise MoE balance loss computation for InternS1 Pro VL training. The layer-wise approach splits the expert balance loss per layer to reduce the peak memory cost of accumulating all router outputs, with async CPU offloading for router tensors.
## Issues
Critical
Warning
Nit
## Verdict
REQUEST_CHANGES
Force-pushed 290dead to 2244c01
Force-pushed 1968ecc to 298ebf4
HAOCHENYE
reviewed
Apr 27, 2026
Collaborator
HAOCHENYE
left a comment
Just hold, the train engine refactor will optimize this.
Force-pushed eb2469e to aaac094
Force-pushed aaac094 to 3d1a6b6
Force-pushed e42ca82 to a258bee
Force-pushed a258bee to 134bd9e
Eliminate duplicate per-layer tokens_per_expert storage / duplicate all_reduce between AuxLossContext and BalancingLossContext, and hoist non-pad index computation out of the per-layer accumulate path.
- AuxLossContext owns the sole tokens_per_expert accumulator and produces both local and globally-reduced views once at finalize, passed into BalancingLossContext.finalize.
- BalancingLossContext drops local_load_list; its accumulate now only stores routing_weights_sum (sum's backward does not save its input, so the per-layer [non_pad, n_experts] activation is not pinned by this accumulator).
- AuxLossContext.accumulate receives already-selected [non_pad, n_experts] tensors; the mask / dim / unsqueeze(0) hack and the per-layer nonzero(mask) recomputation are gone.
- num_experts_per_tok moves into AuxLossConfig / AuxLoss.__init__ instead of being threaded through every accumulate call.
- non_pad_token is read from nonpad_indices.numel() to avoid the per-step GPU->CPU sync via mask.sum().item().
- isinstance(list) branches in accumulate / finalize collapsed via _as_list.

Verified bit-exact loss parity with the legacy BalancingLoss / ZLoss forward paths on local-average inputs.
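The sync-avoidance point in the non_pad_token bullet can be illustrated with a small sketch (hypothetical variable names; on CUDA, `mask.sum().item()` forces a device-to-host sync every step, whereas `.numel()` on an index tensor that is computed anyway for row selection only reads tensor metadata):

```python
import torch

# Sketch (assumed names/shapes): count non-pad tokens without an extra
# per-step GPU->CPU sync.
mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)

# Old pattern: .item() copies the scalar to the host and synchronizes.
non_pad_slow = int(mask.sum().item())

# New pattern: the non-pad indices are needed anyway to select router
# rows; .numel() is metadata, so no additional device sync is triggered.
nonpad_indices = mask.nonzero(as_tuple=False).squeeze(-1)
non_pad_fast = nonpad_indices.numel()

assert non_pad_slow == non_pad_fast == 3
```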
…3 MoE configs

Qwen3MoE235BA22Config and Qwen3_5_VLTextMoE35BA3BConfig redeclared router_async_offload and aux_loss_cfg with the same defaults already provided by the MoEConfig base. Remove the duplicates and the now-unused AuxLossConfig imports.
Remove the AuxLoss empty wrapper subclass and the AuxLossKwargs class along with the unused `device` field on AuxLossConfig / AuxLossKwargs. AuxLossConfig now builds an AuxLossContext directly, and the configured device defaults to the current device of the call site (no caller previously read this field).
Migrate Qwen3VLTextMoE._forward to the new aux_loss accumulate / finalize flow and remove the legacy components that the migration leaves dead:
- BalancingLossContext.forward and ZLossContext.forward
- the standalone BalancingLoss / ZLoss nn.Module classes (only referenced by an isinstance check in the profiler that never matched any module the model actually constructs)
- MoE._select_non_pad_router_logits and MoE._cal_tokens_per_expert
- the matching prober wrappers and before/after hook stubs

After this change the new layer-wise accumulate / finalize is the single MoE auxiliary-loss code path across Qwen3MoE and Qwen3VLTextMoE.
- _is_view_op: match the op base name (``func.overloadpacket.__name__``) exactly against the view-op set instead of using substring containment, which over-matched cases like ``"select" in "select_backward"`` or ``"t" in "exp.default"`` and let non-view ops skip the wait path.
- async_offload_to_cpu: only take the storage-level memcpy fast path when the source tensor is contiguous, owns its full storage, and starts at offset 0. The previous predicate (``storage().size() != numel()``) missed non-contiguous strided tensors with matching numel and would have copied a garbage layout into the pinned CPU buffer.
- Switch ``storage()`` to ``untyped_storage()`` for the storage-level copy to use the current API.
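The over-matching bug in the first bullet can be shown with a minimal pure-Python illustration (the function names here are hypothetical; the fixed code compares ``func.overloadpacket.__name__`` exactly):

```python
# Substring containment vs exact match for classifying view ops.
VIEW_OPS = {"view", "select", "slice", "t", "transpose", "permute", "expand"}

def is_view_op_old(op_name: str) -> bool:
    # Buggy: substring containment over-matches unrelated ops.
    return any(v in op_name for v in VIEW_OPS)

def is_view_op_new(base_name: str) -> bool:
    # Fixed: exact match against the op's base name.
    return base_name in VIEW_OPS

assert is_view_op_old("select_backward")      # false positive: not a view op
assert is_view_op_old("exp.default")          # false positive via "t"
assert not is_view_op_new("select_backward")  # correctly rejected
assert is_view_op_new("select")               # still accepted
```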
The two ``return_router_results`` branches in ``MoE._forward`` / ``MoE._forward_with_micro_batches`` / ``Qwen3VLTextMoE._forward`` reached a working state, but the prior ``# raise NotImplementedError`` placeholders and their truncated ``# TODO`` companions stayed behind and now actively mislead. Drop the placeholders and reword the TODOs.
…nsors per layer

Adapted from Megatron-LM's MoEAuxLossAutoScaler. Each MoE layer's z-loss contribution is computed inline as a per-layer scalar with autograd attached, then injected into the main forward graph via a passthrough autograd.Function on a carrier hidden_states tensor. When backward through the main loss reaches the AuxLossScaler node, it injects ones_like(z_loss_l) and triggers that layer's logsumexp backward inline, releasing its saved [non_pad, n_experts] activation in lockstep with the main backward instead of pinning all layers' saved tensors until a global finalize.
- ZLossContext.accumulate now returns a per-layer scalar; the per-layer logsum / token_count lists are gone, replaced by a running detached scalar used only for logging.
- AuxLossContext.accumulate takes a carrier hidden_states and returns the AuxLossScaler-augmented version. The caller must replace its handle.
- MoE.forward / _forward_with_micro_batches / _forward (incl. MTP layers) and Qwen3VLTextMoE._forward thread hidden_states through aux_loss.accumulate. The Domino path pins to MB0's stream so each per-layer z-loss receives ones exactly once.
- The cross-rank token count needed by z_loss_global_average is computed once per forward in a small int64 all_reduce, hoisted out of the layer loop.

Bit-exact loss values and gradient parity verified against the previous list-based path on a synthetic multi-layer microcase.
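The passthrough pattern described above can be sketched as a small autograd.Function (a simplified adaptation of the Megatron-LM MoEAuxLossAutoScaler idea; the class name, toy z-loss, and shapes below are illustrative, not the PR's actual code):

```python
import torch

class AuxLossScaler(torch.autograd.Function):
    """Identity on the carrier tensor; injects ones into the aux loss on backward."""

    @staticmethod
    def forward(ctx, hidden_states, aux_loss):
        ctx.save_for_backward(aux_loss)
        return hidden_states  # carrier passes through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        (aux_loss,) = ctx.saved_tensors
        # Carrier gradient is untouched; the aux loss receives ones, so its
        # subgraph (e.g. a per-layer logsumexp) is backpropagated inline here.
        return grad_output, torch.ones_like(aux_loss)

# Toy usage: the aux loss's inputs get gradients even though only the
# main loss is backpropagated explicitly.
x = torch.randn(4, 8, requires_grad=True)          # carrier hidden_states
logits = torch.randn(4, 3, requires_grad=True)     # per-layer router logits
aux = torch.logsumexp(logits, dim=-1).square().mean()  # z-loss-like scalar
x2 = AuxLossScaler.apply(x, aux)                   # caller replaces its handle
main_loss = x2.sum()
main_loss.backward()

assert logits.grad is not None                     # aux subgraph was freed/backproped
assert torch.allclose(x.grad, torch.ones_like(x))  # carrier grad is pass-through
```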
…outer_logits

Previously every MoE forward populated output['router_logits'] / output['router_weights'] on every layer, then conditionally cleared only router_logits to None at the end. Two issues:
- The per-layer dict population (and the optional D2H offload) ran even when the caller never read the result.
- router_weights leaked: it was always returned even when router_logits was scrubbed.

Hoist a single keep_router flag, and gate dict initialization, the per-layer push, and the final detach/unsqueeze on it across MoE._forward, MoE._micro_batch_forward, and Qwen3VLTextMoE._forward. The micro-batch path additionally drops the unused router_weights_list bookkeeping (the weights side was never returned).
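A schematic of the hoisted keep_router gating (hypothetical function and layer signatures, not the PR's code):

```python
# Router outputs are collected only when the caller asked for them, and
# logits/weights are returned consistently (no more leaking weights alone).
def moe_forward(hidden, layers, return_router_results=False):
    keep_router = return_router_results  # hoisted once, used everywhere
    output = {}
    if keep_router:  # gated dict initialization
        output["router_logits"] = []
        output["router_weights"] = []
    for layer in layers:
        hidden, logits, weights = layer(hidden)
        if keep_router:  # per-layer push gated on the single flag
            output["router_logits"].append(logits)
            output["router_weights"].append(weights)
    output["hidden_states"] = hidden
    return output

# Toy layers: each increments the carrier and emits dummy router outputs.
layers = [lambda h: (h + 1, "logits", "weights")] * 2
assert "router_logits" not in moe_forward(0, layers)
assert moe_forward(0, layers, True)["router_logits"] == ["logits", "logits"]
```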
…pper The maybe_offload_tensor helper added two indirections (kwargs unpack + module boundary) for what is really one decision: should we async-D2H this router tensor or not? The flag and the offload stream both live on MoE, so the gating naturally belongs as a private MoE method. Inline it as MoE._maybe_offload_router and drop the wrapper from utils/router_offload.py.
Split bal loss
Summary
Layer-wise split finalize for MoE auxiliary losses and optional async router D2H offload.
What changed
- Implements layer-wise accumulation + finalize for balancing and z-loss in layer_moe_loss.py.
- Adds lazy async D2H offload for router tensors in router_offload.py and integrates it in moe.py.

Benefits
Reduces peak GPU memory by offloading router logits/weights to pinned CPU while keeping tensor semantics.
Async offload is lazy (waits only when CPU data is actually needed), enabling overlap of D2H copy and GPU work to reduce wall-clock overhead.
Enabling router_async_offload=True lowers memory at the cost of possible host-device transfer latency.
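A minimal sketch of what such a lazy async D2H offload can look like, assuming a pinned CPU buffer, an optional copy stream, and a CUDA event that is synchronized only at first read (illustrative class, not the PR's router_offload.py implementation):

```python
import torch

class LazyOffload:
    """Illustrative lazy D2H offload: enqueue the copy now, wait only on read."""

    def __init__(self, tensor: torch.Tensor, stream=None):
        # Pinned memory enables a truly async device-to-host copy
        # (only pin when CUDA is actually available).
        self.cpu = torch.empty(
            tensor.shape, dtype=tensor.dtype, pin_memory=torch.cuda.is_available()
        )
        if tensor.is_cuda:
            self.event = torch.cuda.Event()
            with torch.cuda.stream(stream or torch.cuda.current_stream()):
                self.cpu.copy_(tensor, non_blocking=True)  # async D2H into pinned buffer
                self.event.record()
        else:
            # CPU fallback: plain synchronous copy, nothing to wait on.
            self.cpu.copy_(tensor)
            self.event = None

    def get(self) -> torch.Tensor:
        # Lazy wait: block only when the CPU data is actually needed,
        # letting the D2H copy overlap with subsequent GPU work.
        if self.event is not None:
            self.event.synchronize()
        return self.cpu

t = torch.arange(6.0).reshape(2, 3)
off = LazyOffload(t)
assert torch.equal(off.get(), t)
```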