[Dev] Numerical fix for moe single grouped weight with fp8 fp4 primary weight and grad norm spikes #5464

Open: zhongbozhu wants to merge 17 commits into NVIDIA:dev from zhongbozhu:dev_fix_single_weight

Conversation

@zhongbozhu

@zhongbozhu zhongbozhu commented Jun 24, 2026

Contributor
  • I, the PR author, have personally reviewed every line of this PR.

What does this PR do ?

Fix moe_single_grouped_weight for BF16, MXFP8, and NVFP4 training, with FP8/FP4 primary weights turned on or off.

Mirror PR to main: #5487

Unit tests with numerical checks pass; E2E validation is pending. test_single_grouped_mxfp8_train_eval_train_matches_train_only is a new test that exercises reuse_grad_buf_for_mxfp8_param_ag rigorously, for example by checking train-eval-train switches.

Unit test coverage matrix:

| Precision | Primary Weight Path | Grad Accum Fusion | Comparison | Notes / Transformer Config |
|---|---|---|---|---|
| BF16 | BF16 primary weight | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp8=None, fp4=None, gradient_accumulation_fusion=False |
| BF16 | BF16 primary weight | On | single grouped weight on compared with single grouped weight off | bf16=True, fp8=None, fp4=None, gradient_accumulation_fusion=True |
| MXFP8 | BF16 primary weight, MXFP8 compute | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, gradient_accumulation_fusion=False |
| MXFP8 | BF16 primary weight, MXFP8 compute | On | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, gradient_accumulation_fusion=True |
| MXFP8 | MXFP8 primary weight | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=True, reuse_grad_buf_for_mxfp8_param_ag=True, gradient_accumulation_fusion=False |
| MXFP8 | MXFP8 primary weight | On | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=True, reuse_grad_buf_for_mxfp8_param_ag=True, gradient_accumulation_fusion=True |
| NVFP4 | BF16 primary weight, NVFP4 compute | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=False, gradient_accumulation_fusion=False |
| NVFP4 | BF16 primary weight, NVFP4 compute | On | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=False, gradient_accumulation_fusion=True |
| NVFP4 | NVFP4 primary weight | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=True, gradient_accumulation_fusion=False |
| NVFP4 | NVFP4 primary weight | On | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=True, gradient_accumulation_fusion=True |
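
For readers who want to see how the ten rows above relate to each other, here is a rough sketch that enumerates them as plain config dicts. The function name build_coverage_matrix is hypothetical and the dicts only mirror the keys listed in the table; this is not the parametrization used in the actual test file, which may differ.

```python
from itertools import product


def build_coverage_matrix():
    """Enumerate the coverage table as config dicts (illustrative only)."""
    rows = []

    # BF16: primary weight is always BF16, only grad-accum fusion varies.
    for fusion in (False, True):
        rows.append({
            "bf16": True, "fp8": None, "fp4": None,
            "gradient_accumulation_fusion": fusion,
        })

    # MXFP8: in the table, the grad-buffer reuse flag follows fp8_param_gather.
    for param_gather, fusion in product((False, True), (False, True)):
        rows.append({
            "bf16": True, "fp8": "e4m3", "fp8_recipe": "mxfp8",
            "fp8_param_gather": param_gather,
            "reuse_grad_buf_for_mxfp8_param_ag": param_gather,
            "gradient_accumulation_fusion": fusion,
        })

    # NVFP4: primary weight is NVFP4 only when fp4_param_gather is on.
    for param_gather, fusion in product((False, True), (False, True)):
        rows.append({
            "bf16": True, "fp4": "e2m1", "fp4_recipe": "nvfp4",
            "fp4_param_gather": param_gather,
            "gradient_accumulation_fusion": fusion,
        })

    return rows


matrix = build_coverage_matrix()
```

Each dict corresponds to one row; the comparison in every row is the same (single grouped weight on vs. off), so it is not repeated in the dicts.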

Env: 1 x GB200 node with 4 GPUs; the unit test uses only 2 parallel ranks.

Command:

torchrun --nproc_per_node=2 --log-dir /tmp/mcore-single-weight-ut --tee 0:3 --redirects 3 -m pytest -s -q tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_mxfp8_single_weight_torch_dist_checkpoint_matches_discrete_baseline[save-only-single]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_mxfp8_single_weight_torch_dist_checkpoint_matches_discrete_baseline[save-single-load-discrete]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_mxfp8_single_weight_torch_dist_checkpoint_matches_discrete_baseline[save-discrete-load-single]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_mxfp8_train_eval_train_matches_train_only
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[True-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[True-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-True-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-True-nvfp4]

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact @NVIDIA/mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment @NVIDIA/mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

@zhongbozhu zhongbozhu requested review from a team as code owners June 24, 2026 00:19
@copy-pr-bot

copy-pr-bot Bot commented Jun 24, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@zhongbozhu zhongbozhu requested review from WanZzzzzz and kunlunl June 24, 2026 00:31
@kunlunl

kunlunl commented Jun 24, 2026

Contributor

/claude strict-review

Comment thread megatron/core/optimizer/distrib_optimizer.py Outdated
Comment thread megatron/core/fp8_utils.py
Comment thread megatron/core/transformer/moe/experts.py
Comment thread megatron/core/fp4_utils.py
@claude

claude Bot commented Jun 24, 2026

Contributor

Code Review Summary

CRITICAL: 0 | IMPORTANT: 2 | SUGGESTION: 3

Overall Assessment

This is a well-structured fix for moe_single_grouped_weight across BF16, MXFP8, and NVFP4 with thorough test coverage. The data-flow through quantized/non-quantized param paths in the DDP buffer and distributed optimizer is correct. The register_grouped_linear_params refactor properly addresses the root cause (TE overwriting DDP-managed parameters with fresh meta tensors). The torch.no_grad() additions are necessary to prevent autograd tracking on buffer management ops with tensor subclasses.

Risk level: Low-Medium. The changes are narrowly scoped to the GroupedTensor integration paths and gated behind moe_single_grouped_weight. The FSDP guard is a good safeguard. The numerical parity tests cover the full precision × param-gather × grad-accum-fusion matrix.

Key Findings

IMPORTANT — Unused _unwrap_parameter_data on DistributedOptimizer (distrib_optimizer.py:1118-1121)
Added as a @staticmethod but never called. Duplicates the function in fp8_utils.py. Should be removed (inline comment posted with suggestion block).

IMPORTANT — is_nvfp4tensor not updated to unwrap Parameters (fp4_utils.py:58-60)
The PR updates is_float8tensor and is_mxfp8tensor to handle torch.nn.Parameter-wrapped TE subclasses via _is_instance_or_param_data, but is_nvfp4tensor still uses plain isinstance. This inconsistency could misclassify a Parameter-wrapped NVFP4Tensor in _param_uses_quantized_storage. The fix is straightforward — this file already imports _is_instance_or_param_data indirectly through the fp8_utils imports added in this PR. (Couldn't post inline since these lines aren't in the diff.)

Suggested fix:

def is_nvfp4tensor(tensor: torch.Tensor) -> bool:
    """Check if a tensor is a Transformer Engine NVFP4Tensor."""
    return HAVE_TE_FP4_TENSOR_CLASS and _is_instance_or_param_data(tensor, FP4_TENSOR_CLASS)
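
To illustrate why a plain isinstance check can miss a wrapped tensor, here is a torch-free sketch. FakeNVFP4Tensor, FakeParameter, and the body of is_instance_or_param_data are stand-ins invented for this example; the real helper in fp8_utils.py, and the way torch.nn.Parameter wraps TE subclasses, may differ.

```python
class FakeNVFP4Tensor:
    """Stand-in for a quantized tensor subclass (hypothetical)."""


class FakeParameter:
    """Stand-in for a parameter wrapper that holds the tensor in .data."""

    def __init__(self, data):
        self.data = data


def is_instance_or_param_data(obj, cls):
    """Return True if obj is a cls, or wraps a cls in its .data attribute."""
    if isinstance(obj, cls):
        return True
    return isinstance(getattr(obj, "data", None), cls)


wrapped = FakeParameter(FakeNVFP4Tensor())

# The plain check only sees the wrapper type and misses the payload.
plain_check = isinstance(wrapped, FakeNVFP4Tensor)
# The unwrapping check also looks at the wrapped payload.
unwrapping_check = is_instance_or_param_data(wrapped, FakeNVFP4Tensor)
```

In this toy setup the plain check returns False for the wrapped object while the unwrapping check returns True, which is the kind of misclassification the finding above is concerned about.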

Suggestions (posted inline)

  • copy_tensor_to_quantized_param: document that the plain copy_ fallback relies on TE's overridden method
  • register_grouped_linear_params: consider clearing stale "weight" in the per-index branch for symmetry
  • modify_grouped_nvfp4_rowwise_storage: add comment explaining why member views are refreshed eagerly (vs. lazily in the MXFP8 counterpart)

bucket.layerwise_params_list[local_rank]
).detach()
local_slot_view.copy_(flat_local_params)
with torch.no_grad():

Contributor

Why is this with torch.no_grad() needed?

Contributor Author

Removed; I believe they are redundant.

Contributor

We need torch.no_grad() when the mutation is intentional and should not affect gradients.
It looks like you removed it everywhere. I am not sure whether this matters, but
when I tried implementing this feature some time ago, I hit the error below in the MXFP8 reuse-grad-buffer case while doing some copying:
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation
Please make sure the unit tests still pass after this change.
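
A minimal standalone reproduction of that error class, assuming only stock PyTorch (the tensor names are made up for this example and none of this is code from the PR): an in-place copy_ into a leaf tensor that requires grad raises, while the same copy under torch.no_grad() succeeds and leaves requires_grad untouched.

```python
import torch

# A leaf tensor that requires grad, playing the role of a parameter.
param = torch.zeros(4, requires_grad=True)
src = torch.ones(4)

# Outside no_grad, autograd rejects in-place writes to such a leaf.
try:
    param.copy_(src)
    raised = False
    message = ""
except RuntimeError as err:
    raised = True
    message = str(err)

# Under no_grad the write is treated as plain buffer management.
with torch.no_grad():
    param.copy_(src)
```

After the second copy, param holds the new values and is still a leaf that requires grad, so later backward passes are unaffected by the write itself.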

Contributor Author

Good catch. There is a bug that shows up in the E2E test but is not captured by the unit tests.

Contributor Author

should be resolved now

@zhongbozhu

Contributor Author

/ok to test 509c7a6

@zhongbozhu

Contributor Author

Note: the GB200 unit test was added in #5477 but has not yet been synced to dev.

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu changed the title from "[Dev] Fix moe single grouped weight feature with fp8 fp4 primary weight support" to "[Dev] Numerical fix for moe single grouped weight with fp8 fp4 primary weight and grad norm spikes" Jun 28, 2026
@zhongbozhu

zhongbozhu commented Jun 28, 2026

Contributor Author

E2E test with Qwen3.5 VL 35B-A3B SFT - branch dev_fix_single_weight

Before this PR, training with moe single weight simply diverged; now it converges well.

The performance benefit comes from lower CPU overhead when quantizing to MXFP8 in the distributed optimizer. In addition, CUDA Graphs can be hard to enable for multimodal SFT as of today.

The green plot (before this PR) had grad norm spikes. With reuse_grad_buf_for_mxfp8_param_ag enabled, the training step right after eval does not clear the param_data buffer: the all-gather was already done during eval, so it is skipped, but the buffer-zeroing operation was skipped along with it.
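
The failure mode can be sketched with a toy, pure-Python model. ToyBucket, its methods, and the fixed flag are all hypothetical names for illustration; this is not the Megatron code path, only the control flow described above: one buffer serves both the param all-gather and grad accumulation, and zeroing is (incorrectly) tied to the gather.

```python
class ToyBucket:
    """Toy buffer shared by the param all-gather and grad accumulation."""

    def __init__(self, params):
        self.params = list(params)
        self.buf = [0.0] * len(self.params)
        self.gather_pending = True  # set again after every optimizer step

    def _gather(self):
        # The all-gather lands parameter values in the shared buffer.
        self.buf = list(self.params)
        self.gather_pending = False

    def _zero(self):
        self.buf = [0.0] * len(self.params)

    def eval_forward(self):
        if self.gather_pending:
            self._gather()  # leaves param data sitting in the buffer

    def train_step(self, grads, lr, fixed):
        if self.gather_pending:
            self._gather()
            self._zero()
        elif fixed:
            self._zero()  # zero even when the gather was skipped
        # Gradients accumulate on top of whatever the buffer holds.
        self.buf = [b + g for b, g in zip(self.buf, grads)]
        grad_norm = sum(b * b for b in self.buf) ** 0.5
        self.params = [p - lr * b for p, b in zip(self.params, self.buf)]
        self.gather_pending = True
        return grad_norm


def run(schedule, fixed):
    """Run a schedule of 'train' / 'eval' phases, return grad norms."""
    bucket = ToyBucket([10.0, -10.0])
    norms = []
    for phase in schedule:
        if phase == "eval":
            bucket.eval_forward()
        else:
            norms.append(bucket.train_step([0.1, 0.1], lr=0.1, fixed=fixed))
    return norms


train_only = run(["train", "train"], fixed=False)
buggy = run(["train", "eval", "train"], fixed=False)
patched = run(["train", "eval", "train"], fixed=True)
```

In this toy, the buggy train-eval-train schedule accumulates the second step's gradients on top of stale parameter values, so its second grad norm is far larger than in the train-only run, while the patched schedule matches train-only. That is the same shape of comparison as test_single_grouped_mxfp8_train_eval_train_matches_train_only.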

image

@zhongbozhu

Contributor Author

E2E performance benefit shown in Nsys: time spent looping over the MoE weights in the optimizer's master weights and quantizing them to MXFP8, discrete weights vs. single weight.

Discrete
image

Single
image


3 participants