
[Feature] Layer-wise MoE auxiliary loss (split finalize) and optional async router D2H offload#1528

Open
tina-wen wants to merge 12 commits into InternLM:main from tina-wen:split_bal_loss

Conversation

@tina-wen
Contributor

@tina-wen tina-wen commented Mar 3, 2026

Summary

Layer-wise split finalize for MoE auxiliary losses and optional async router D2H offload.

What changed

Implements layer-wise accumulation + finalize for balancing and z-loss in layer_moe_loss.py.
Adds lazy async D2H offload for router tensors in router_offload.py and integrates it in moe.py.

Benefits

Reduces peak GPU memory by offloading router logits/weights to pinned CPU while keeping tensor semantics.
Async offload is lazy (waits only when CPU data is actually needed), enabling overlap of D2H copy and GPU work to reduce wall-clock overhead.
Enabling router_async_offload=True lowers peak memory at the cost of possible host-device transfer latency; the sketch below illustrates the pattern.
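
A minimal sketch of the lazy offload pattern described above. The function names, the explicit stream argument, and the `materialize` helper are assumptions for illustration, not the actual router_offload.py API: copy into a pinned CPU buffer on a side stream, record a CUDA event, and synchronize only when the CPU data is consumed.

```python
import torch

def async_offload_to_cpu(src: torch.Tensor, copy_stream: torch.cuda.Stream):
    """Start a non-blocking D2H copy on a side stream; return (cpu_buf, event)."""
    # Pinned destination memory is required for a truly asynchronous D2H copy.
    cpu_buf = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
    # Order the copy after whatever produced `src` on the current stream.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        cpu_buf.copy_(src, non_blocking=True)
        # Keep `src`'s memory alive until the copy stream finishes with it.
        src.record_stream(copy_stream)
        done = torch.cuda.Event()
        done.record()
    return cpu_buf, done

def materialize(cpu_buf: torch.Tensor, done: torch.cuda.Event) -> torch.Tensor:
    # Lazy wait: block only where the CPU data is actually consumed, so the
    # D2H copy overlaps with the GPU work issued in between.
    done.synchronize()
    return cpu_buf
```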

@tina-wen tina-wen force-pushed the split_bal_loss branch 3 times, most recently from fb3ae25 to 95e9d62 Compare April 7, 2026 07:41
@HAOCHENYE
Collaborator

@claude review

@claude
Contributor

claude Bot commented Apr 13, 2026

Summary

Adds Domino EP support and layer-wise MoE balance loss computation for InternS1 Pro VL training. The layer-wise approach splits expert balance loss per-layer to reduce peak memory from accumulating all router outputs, with async CPU offloading for router tensors.

Issues

Critical

  • moe.py:825 — non_pad_token=router_weights.shape[1] uses the wrong tensor when layer balancing is enabled. router_weights holds the last MoE layer's raw output (seq, num_experts), so shape[1] = num_experts, not the non-padding token count. Produces incorrect loss in non-distributed mode. Fix: non_pad_token=int(seq_ctx.mask.sum().item()).
  • modeling_qwen3_vl.py:141-146 — _prepare_llm_inputs is annotated as returning MoEModelOutputs but actually returns torch.Tensor.

Warning

  • modeling_qwen3_vl.py:207-236 — Replaced seq_ctx.copy(...) with explicit SequenceContext(...) construction, silently dropping fields like device (defaults to "cpu"), block_table, image_grid_thw, etc. Use copy() to preserve all original fields.
  • moe.py:753 — Z-loss is silently disabled when layer_balancing_loss is enabled. If both are configured, users get no warning that z-loss is being skipped.
  • layer_moe_loss.py:220-225 — maybe_offload_tensor synchronously waits on the async D2H copy, so the async machinery provides no overlap benefit. Either defer the wait to maybe_wait_offload_tensor, or simplify to plain .cpu().

Nit

  • modeling_qwen3_vl.py:200 — forward signature exceeds the 119-char line limit.
  • moe.py:460-464 — mask_list is a tensor, not a list; the naming is inconsistent with the cat_hidden_states / cat_position_ids convention.

Verdict

REQUEST_CHANGES

@tina-wen tina-wen force-pushed the split_bal_loss branch 4 times, most recently from 290dead to 2244c01 Compare April 23, 2026 12:39
@tina-wen tina-wen changed the title [Feature] Domino EP support and training optimizations for InternS1 Pro VL [Feature] Layer-wise MoE auxiliary loss (split finalize) and optional async router D2H offload Apr 23, 2026
@tina-wen tina-wen force-pushed the split_bal_loss branch 6 times, most recently from 1968ecc to 298ebf4 Compare April 27, 2026 07:31
Collaborator

@HAOCHENYE HAOCHENYE left a comment


Just hold; the train engine refactor will optimize this.

HAOCHENYE and others added 10 commits May 6, 2026 04:36
Eliminate duplicate per-layer tokens_per_expert storage / duplicate all_reduce
between AuxLossContext and BalancingLossContext, and hoist non-pad index
computation out of the per-layer accumulate path.

- AuxLossContext owns the sole tokens_per_expert accumulator and produces
  both local and globally-reduced views once at finalize, passed into
  BalancingLossContext.finalize.
- BalancingLossContext drops local_load_list; its accumulate now only stores
  routing_weights_sum (sum's backward does not save input, so the per-layer
  [non_pad, n_experts] activation is not pinned by this accumulator).
- AuxLossContext.accumulate receives already-selected [non_pad, n_experts]
  tensors; mask / dim / unsqueeze(0) hack and the per-layer nonzero(mask)
  recomputation are gone.
- num_experts_per_tok moves into AuxLossConfig / AuxLoss __init__ instead of
  being threaded through every accumulate call.
- non_pad_token is read from nonpad_indices.numel() to avoid the per-step
  GPU->CPU sync via mask.sum().item().
- isinstance(list) branches in accumulate / finalize collapsed via _as_list.

Verified bit-exact loss parity with the legacy BalancingLoss/ZLoss forward
paths on local-average inputs.
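
As a rough illustration of the single-accumulator accumulate/finalize shape this commit describes (class and method names follow the commit text, but the bodies here are assumptions, not xtuner's actual aux_loss.py):

```python
import torch
import torch.distributed as dist

class AuxLossContext:
    # Illustrative sketch only; the real class carries more state.
    def __init__(self, num_experts: int, top_k: int, device: torch.device):
        self.top_k = top_k
        self.tokens_per_expert = torch.zeros(num_experts, dtype=torch.int64, device=device)

    def accumulate(self, router_weights: torch.Tensor) -> None:
        # router_weights arrives already selected to [non_pad, n_experts], so
        # no per-layer mask / nonzero(mask) recomputation happens here.
        _, topk_idx = router_weights.topk(self.top_k, dim=-1)
        self.tokens_per_expert += torch.bincount(
            topk_idx.flatten(), minlength=self.tokens_per_expert.numel()
        )

    def finalize(self):
        # Local and globally-reduced views are produced once, here, rather
        # than with one all_reduce per layer.
        local = self.tokens_per_expert.clone()
        global_ = self.tokens_per_expert
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(global_)
        return local, global_
```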
…3 MoE configs

Qwen3MoE235BA22Config and Qwen3_5_VLTextMoE35BA3BConfig redeclared
router_async_offload and aux_loss_cfg with the same defaults already provided
by the MoEConfig base. Remove the duplicates and the now-unused AuxLossConfig
imports.
Remove the AuxLoss empty wrapper subclass and the AuxLossKwargs class along
with the unused `device` field on AuxLossConfig / AuxLossKwargs. AuxLossConfig
now builds an AuxLossContext directly, and the configured device defaults to
the current device of the call site (no caller previously read this field).
Migrate Qwen3VLTextMoE._forward to the new aux_loss accumulate / finalize
flow and remove the legacy components that the migration leaves dead:

- BalancingLossContext.forward and ZLossContext.forward
- The standalone BalancingLoss / ZLoss nn.Module classes (only referenced
  by an isinstance check in the profiler that never matched any module
  the model actually constructs)
- MoE._select_non_pad_router_logits and MoE._cal_tokens_per_expert
- The matching prober wrappers and before/after hook stubs

After this change the new layer-wise accumulate / finalize is the single
MoE auxiliary loss code path across Qwen3MoE and Qwen3VLTextMoE.
- _is_view_op: match the op base name (``func.overloadpacket.__name__``) exactly
  against the view-op set instead of substring containment, which over-matched
  cases like ``"select" in "select_backward"`` or ``"t" in "exp.default"`` and
  let non-view ops skip the wait path.
- async_offload_to_cpu: only take the storage-level memcpy fast path when the
  source tensor is contiguous, owns its full storage, and starts at offset 0.
  The previous predicate (``storage().size() != numel()``) missed
  non-contiguous strided tensors with matching numel and would have copied
  garbage layout into the pinned CPU buffer.
- Switch ``storage()`` to ``untyped_storage()`` for the storage-level copy to
  use the current API.
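
A minimal sketch of the two corrected predicates, assuming an illustrative view-op set and helper names (the real implementations in the PR differ in scope):

```python
import torch

# Exact-name matching; substring containment over-matched, e.g.
# "select" in "select_backward" or "t" in "exp.default".
_VIEW_OPS = {"view", "reshape", "select", "slice", "t", "transpose", "permute"}

def _is_view_op(func) -> bool:
    return func.overloadpacket.__name__ in _VIEW_OPS

def _storage_memcpy_ok(src: torch.Tensor) -> bool:
    # The storage-level fast path is only safe when the tensor is contiguous,
    # owns its full storage, and starts at offset 0. Testing only
    # storage().size() != numel() misses non-contiguous strided tensors whose
    # numel happens to match.
    return (
        src.is_contiguous()
        and src.storage_offset() == 0
        and src.untyped_storage().size() == src.numel() * src.element_size()
    )
```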
The two ``return_router_results`` branches in ``MoE._forward`` /
``MoE._forward_with_micro_batches`` / ``Qwen3VLTextMoE._forward`` reached a
working state, but the prior ``# raise NotImplementedError`` placeholders and
their truncated ``# TODO`` companions stayed behind and now actively mislead.
Drop the placeholders and reword the TODOs.
…nsors per layer

Adapted from Megatron-LM's MoEAuxLossAutoScaler. Each MoE layer's z-loss
contribution is computed inline as a per-layer scalar with autograd attached,
then injected into the main forward graph via a passthrough autograd.Function
on a carrier hidden_states tensor. When backward through the main loss reaches
the AuxLossScaler node it injects ones_like(z_loss_l) and triggers that
layer's logsumexp backward inline, releasing its saved [non_pad, n_experts]
activation in lockstep with the main backward instead of pinning all layers'
saved tensors until a global finalize.

- ZLossContext.accumulate now returns a per-layer scalar; the per-layer logsum
  / token_count lists are gone, replaced by a running detached scalar used
  only for logging.
- AuxLossContext.accumulate takes a carrier hidden_states and returns the
  AuxLossScaler-augmented version. Caller must replace its handle.
- MoE.forward / _forward_with_micro_batches / _forward (incl. MTP layers) and
  Qwen3VLTextMoE._forward thread hidden_states through aux_loss.accumulate.
  Domino path pins to MB0's stream so each per-layer z-loss receives ones
  exactly once.
- The cross-rank token count needed by z_loss_global_average is computed once
  per forward in a small int64 all_reduce, hoisted out of the layer loop.

Bit-exact loss values and gradient parity verified against the previous
list-based path on a synthetic multi-layer microcase.
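
For reference, a minimal sketch of the passthrough autograd.Function pattern this commit adapts from Megatron-LM's MoEAuxLossAutoScaler (the class body here is an assumption; the PR's version may scale differently):

```python
import torch

class AuxLossScaler(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden_states: torch.Tensor, aux_loss: torch.Tensor):
        ctx.save_for_backward(aux_loss)
        return hidden_states  # identity on the carrier tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (aux_loss,) = ctx.saved_tensors
        # When the main backward reaches this node, inject ones into the
        # per-layer scalar; that fires its logsumexp backward inline and
        # releases the layer's saved [non_pad, n_experts] activation in
        # lockstep with the main backward.
        return grad_output, torch.ones_like(aux_loss)

# Usage inside a layer's forward (illustrative):
#   hidden_states = AuxLossScaler.apply(hidden_states, z_loss_l)
```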
…outer_logits

Previously every MoE forward populated output['router_logits'] / output['router_weights']
on every layer, then conditionally cleared only router_logits to None at the end. Two
issues:
- The per-layer dict population (and the optional D2H offload) ran even when the caller
  never read the result.
- router_weights leaked: it was always returned even when router_logits was scrubbed.

Hoist a single keep_router flag, gate dict initialization, per-layer push, and final
detach/unsqueeze on it across MoE._forward, MoE._micro_batch_forward, and
Qwen3VLTextMoE._forward. The micro-batch path additionally drops the unused
router_weights_list bookkeeping (the weights side was never returned).
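
A minimal sketch of the hoisted gate, with illustrative names rather than the exact xtuner identifiers:

```python
import torch
from torch import nn

def moe_forward(layers: nn.ModuleList, hidden_states: torch.Tensor,
                return_router_results: bool) -> dict:
    # One hoisted flag replaces per-layer dict population plus a final scrub.
    keep_router = return_router_results
    router_logits_list = [] if keep_router else None

    for layer in layers:
        hidden_states, router_logits = layer(hidden_states)
        if keep_router:
            # Per-layer push (and any optional D2H offload) runs only when the
            # caller actually asked for router results.
            router_logits_list.append(router_logits.detach())

    output = {"hidden_states": hidden_states}
    if keep_router:
        output["router_logits"] = router_logits_list
    return output
```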
…pper

The maybe_offload_tensor helper added two indirections (kwargs unpack + module
boundary) for what is really one decision: should we async-D2H this router
tensor or not? The flag and the offload stream both live on MoE, so the gating
naturally belongs as a private MoE method. Inline it as MoE._maybe_offload_router
and drop the wrapper from utils/router_offload.py.
