[Main] Numerical fix for moe single grouped weight with fp8 fp4 primary weight and grad norm spikes #5487

Open
zhongbozhu wants to merge 16 commits into NVIDIA:main from zhongbozhu:main_fix_single_weight
@zhongbozhu zhongbozhu commented Jun 24, 2026

  • I, the PR author, have personally reviewed every line of this PR.

What does this PR do ?

Mirrors: #5464

Unit tests with numerical checks passed, pending E2E validation.

test_single_grouped_mxfp8_train_eval_train_matches_train_only is a newly introduced test that rigorously exercises reuse_grad_buf_for_mxfp8_param_ag, for example by adding checks across train-eval-train switches.

Unit test coverage matrix:

| Precision | Primary Weight Path | Grad Accum Fusion | Comparison | Notes / Transformer Config |
|---|---|---|---|---|
| BF16 | BF16 primary weight | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp8=None, fp4=None, gradient_accumulation_fusion=False |
| BF16 | BF16 primary weight | On | single grouped weight on compared with single grouped weight off | bf16=True, fp8=None, fp4=None, gradient_accumulation_fusion=True |
| MXFP8 | BF16 primary weight, MXFP8 compute | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, gradient_accumulation_fusion=False |
| MXFP8 | BF16 primary weight, MXFP8 compute | On | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, gradient_accumulation_fusion=True |
| MXFP8 | MXFP8 primary weight | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=True, reuse_grad_buf_for_mxfp8_param_ag=True, gradient_accumulation_fusion=False |
| MXFP8 | MXFP8 primary weight | On | single grouped weight on compared with single grouped weight off | bf16=True, fp8="e4m3", fp8_recipe="mxfp8", fp8_param_gather=True, reuse_grad_buf_for_mxfp8_param_ag=True, gradient_accumulation_fusion=True |
| NVFP4 | BF16 primary weight, NVFP4 compute | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=False, gradient_accumulation_fusion=False |
| NVFP4 | BF16 primary weight, NVFP4 compute | On | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=False, gradient_accumulation_fusion=True |
| NVFP4 | NVFP4 primary weight | Off | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=True, gradient_accumulation_fusion=False |
| NVFP4 | NVFP4 primary weight | On | single grouped weight on compared with single grouped weight off | bf16=True, fp4="e2m1", fp4_recipe="nvfp4", fp4_param_gather=True, gradient_accumulation_fusion=True |
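
For readers who want to see how the ten rows of the matrix above relate to each other, here is a minimal sketch that enumerates them. The helper names `build_config` and `coverage_matrix` are hypothetical and are not taken from the test file; only the config keys and values are copied from the matrix. It produces plain dictionaries and does not construct a real transformer config.

```python
from itertools import product


def build_config(precision, low_precision_primary, grad_accum_fusion):
    """Return the config keys listed in the matrix for one row (hypothetical helper)."""
    cfg = {"bf16": True}
    if precision == "bf16":
        cfg.update(fp8=None, fp4=None)
    elif precision == "mxfp8":
        cfg.update(
            fp8="e4m3",
            fp8_recipe="mxfp8",
            fp8_param_gather=low_precision_primary,
            # In the matrix this flag is on exactly when MXFP8 is the primary weight.
            reuse_grad_buf_for_mxfp8_param_ag=low_precision_primary,
        )
    elif precision == "nvfp4":
        cfg.update(
            fp4="e2m1",
            fp4_recipe="nvfp4",
            fp4_param_gather=low_precision_primary,
        )
    else:
        raise ValueError(f"unknown precision: {precision}")
    cfg["gradient_accumulation_fusion"] = grad_accum_fusion
    return cfg


def coverage_matrix():
    """Enumerate (precision, low_precision_primary, fusion, config) for every row."""
    rows = []
    for precision, primary, fusion in product(
        ("bf16", "mxfp8", "nvfp4"), (False, True), (False, True)
    ):
        # BF16 has no low-precision primary weight variant in the matrix.
        if precision == "bf16" and primary:
            continue
        rows.append((precision, primary, fusion, build_config(precision, primary, fusion)))
    return rows


if __name__ == "__main__":
    for precision, primary, fusion, cfg in coverage_matrix():
        print(precision, primary, fusion, cfg)
```

Each row is then run twice in the tests, once with single grouped weight on and once with it off, and the results are compared.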

Env: 1 x GB200 node with 4 GPUs; the unit test uses only 2 parallel ranks.

Command:

torchrun --nproc_per_node=2 --log-dir /tmp/mcore-single-weight-ut --tee 0:3 --redirects 3 -m pytest -s -q tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_mxfp8_train_eval_train_matches_train_only
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_with_primary_param_gather[True-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_without_primary_param_gather[True-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[False-True-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-False-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-False-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-False-nvfp4]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-True-bf16]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-True-mxfp8]
[default0]:PASSED tests/unit_tests/transformer/moe/test_moe_single_grouped_weight_numerics.py::TestMoESingleGroupedWeightNumerics::test_single_grouped_weight_parity_module_grouped_linear[True-True-nvfp4]

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact @NVIDIA/mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment @NVIDIA/mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

@zhongbozhu zhongbozhu requested review from a team as code owners June 24, 2026 20:10
copy-pr-bot Bot commented Jun 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@zhongbozhu

/ok to test 9df8f4e

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu force-pushed the main_fix_single_weight branch from 9df8f4e to 7973f73 Compare June 26, 2026 23:18
@zhongbozhu zhongbozhu marked this pull request as ready for review June 28, 2026 05:35
@zhongbozhu zhongbozhu changed the title from "[Main] Fix moe single grouped weight feature with fp8 fp4 primary weight support" to "[Main] Numerical fix for moe single grouped weight with fp8 fp4 primary weight and grad norm spikes" Jun 28, 2026
zhongbozhu commented Jun 28, 2026

E2E test with Qwen3.5 VL 35B-A3B SFT - branch main_fix_single_weight

Before this PR, training with MoE single grouped weight simply diverged; with this PR it converges well.

The performance benefit comes from lower CPU overhead when quantizing to MXFP8 in the distributed optimizer. In addition, CUDA Graphs can be hard to enable for multimodal SFT as of today.

The green plot (before this PR) shows grad norm spikes. With reuse_grad_buf_for_mxfp8_param_ag enabled, the training step right after eval did not clear the param_data buffer: the all-gather had already been done during eval, so it was skipped, and unfortunately the buffer-zeroing operation was skipped along with it.

image
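
To make the failure mode described above concrete, here is a toy model in pure Python. It is a sketch under stated assumptions, not Megatron-LM code: the class `ToySharedBuffer`, the functions `train_step`, `eval_step`, and `run`, and the exact ordering of operations are all hypothetical. The only thing it borrows from the description is the idea that one buffer holds both the staged parameters for the all-gather and the accumulated gradients, and that zeroing was tied to the all-gather.

```python
import math


class ToySharedBuffer:
    """One buffer reused for staged params (all-gather input) and for grads."""

    def __init__(self, size, zero_even_if_gather_skipped):
        self.data = [0.0] * size
        self.needs_gather = False
        # False models the behavior before the fix, True models it after.
        self.zero_even_if_gather_skipped = zero_even_if_gather_skipped

    def stage_params(self, params):
        self.data = list(params)
        self.needs_gather = True

    def gather(self):
        # Stand-in for the param all-gather: read the staged values.
        self.needs_gather = False
        return list(self.data)

    def zero(self):
        self.data = [0.0] * len(self.data)


def eval_step(buf, weights):
    # Eval needs fresh weights but has no backward pass, so it never zeroes.
    if buf.needs_gather:
        weights[:] = buf.gather()


def train_step(buf, weights, grads, lr=0.1):
    if buf.needs_gather:
        weights[:] = buf.gather()
        buf.zero()
    elif buf.zero_even_if_gather_skipped:
        buf.zero()  # the fix: clear the buffer even when the gather is skipped
    # Backward pass accumulates grads into the shared buffer.
    buf.data = [b + g for b, g in zip(buf.data, grads)]
    grad_norm = math.sqrt(sum(x * x for x in buf.data))
    # Optimizer step, then stage the updated params for the next gather.
    buf.stage_params([w - lr * g for w, g in zip(weights, buf.data)])
    return grad_norm


def run(schedule, fixed):
    buf = ToySharedBuffer(2, zero_even_if_gather_skipped=fixed)
    weights = [1.0, 2.0]
    norms = []
    for phase in schedule:
        if phase == "train":
            norms.append(train_step(buf, weights, [0.3, 0.4]))
        else:
            eval_step(buf, weights)
    return norms


if __name__ == "__main__":
    print("train only      :", run(["train", "train"], fixed=False))
    print("train-eval-train:", run(["train", "eval", "train"], fixed=False))
    print("with fix        :", run(["train", "eval", "train"], fixed=True))
```

In this toy, the train-only schedule reports the same grad norm at every step, while the unfixed train-eval-train schedule reports a much larger norm on the step after eval because gradients are added on top of stale parameter values. Comparing the two schedules is the same shape of check that test_single_grouped_mxfp8_train_eval_train_matches_train_only is described as performing.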

@zhongbozhu

E2E performance benefit shown in Nsys: time spent looping over the MoE weights in the optimizer's master weights and quantizing them to MXFP8, discrete weights vs. single weight.

Discrete
image

Single
image
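
The difference between the two traces can be illustrated with a toy sketch that counts host-side quantize calls. Everything here is an assumption made for illustration: `toy_block_quantize` is a simplified stand-in and not an MXFP8 cast (it keeps one power-of-two scale per 32-element block, as MX formats do, but rounds to integers instead of e4m3), and `quantize_discrete` and `quantize_single` are hypothetical names. No timing is claimed; the measured benefit is the one shown in the Nsys traces.

```python
import math

BLOCK = 32  # MX formats share one scale per 32-element block


def toy_block_quantize(values, stats):
    """Toy block-scaled quantize-dequantize; counts one host-side call."""
    stats["calls"] += 1
    out = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Power-of-two scale chosen so the block fits under 448 (e4m3 max).
        exp = math.ceil(math.log2(amax / 448.0)) if amax > 0 else 0
        scale = 2.0 ** exp
        out.extend(round(v / scale) * scale for v in block)
    return out


def quantize_discrete(expert_weights, stats):
    """One quantize call per expert weight (discrete weights)."""
    return [toy_block_quantize(w, stats) for w in expert_weights]


def quantize_single(expert_weights, stats):
    """One quantize call over all experts stored back to back (single weight)."""
    flat = [v for w in expert_weights for v in w]
    quantized = toy_block_quantize(flat, stats)
    size = len(expert_weights[0])
    return [quantized[i:i + size] for i in range(0, len(quantized), size)]


if __name__ == "__main__":
    # 8 toy experts, 64 elements each (a multiple of BLOCK so blocks never span experts).
    experts = [[math.sin(0.37 * i + e) for i in range(64)] for e in range(8)]
    discrete_stats, single_stats = {"calls": 0}, {"calls": 0}
    a = quantize_discrete(experts, discrete_stats)
    b = quantize_single(experts, single_stats)
    print("discrete calls:", discrete_stats["calls"])
    print("single calls  :", single_stats["calls"])
    print("same values   :", a == b)
```

Because each expert's size is a multiple of the block size in this toy, both paths quantize identical blocks and return identical values; only the number of host-side calls changes with the number of experts.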
