Add option to normalize NS input in double precision #238

Open
skyw wants to merge 7 commits into main from skyw/handle_tiny_values_in_muon

Conversation

@skyw

@skyw skyw commented Jun 26, 2026

Contributor

An "if" clause on the calculated norm would trigger a device-to-host sync. Handling it on device would require taking multiple code paths, and none of those is better than simply using double precision. A custom kernel could do it, but that is out of scope.

No zero-division guard is added for the fp64 path, as we don't foresee LLM training being done in fp64. The square of an fp32 value won't underflow in fp64.
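
The underflow described above can be reproduced without a GPU. The sketch below is an illustration only, not code from this PR: it emulates fp32 rounding with Python's `struct` module (the helper name `to_fp32` is hypothetical) and compares a sum of squares that is rounded to fp32 at every step against one accumulated in fp64.

```python
import math
import struct


def to_fp32(value: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


# Four tiny entries, each representable in fp32 (fp32 min normal is ~1.18e-38).
x = [to_fp32(1e-30)] * 4

# fp32 path: every square (~1e-60) is below the smallest fp32 subnormal
# (~1.4e-45), so it rounds to zero before it can be accumulated.
sum_sq_fp32 = 0.0
for v in x:
    sum_sq_fp32 = to_fp32(sum_sq_fp32 + to_fp32(v * v))
norm_fp32 = to_fp32(math.sqrt(sum_sq_fp32))

# fp64 path: squares stay in fp64, and only the final norm is cast to fp32.
sum_sq_fp64 = sum(v * v for v in x)
norm_fp64 = to_fp32(math.sqrt(sum_sq_fp64))

print(norm_fp32)  # 0.0, so x / norm would divide by zero
print(norm_fp64)  # about 2e-30, so x / norm is about 0.5 per entry
```

The fp64 norm is well inside the fp32 normal range, which is why casting it back to fp32 before the division loses nothing in this case.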

skyw added 2 commits June 26, 2026 15:02
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 26, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 26, 2026

Contributor

Greptile Summary

This PR adds a normalize_in_double flag to newton_schulz and distributed_normalize_p2 that computes the Frobenius norm in float64 before casting back to float32, preventing fp32 underflow when squaring very small parameter entries. It also replaces the previous scale-invariance regression test with a targeted underflow test.

  • The core fp64 accumulation path is correct for both the distributed and local cases; the clamp_min_ guard for the existing fp32 path is now properly in-place.
  • When normalize_in_double=True, no eps guard is applied after casting the fp64 norm back to float32, leaving a silent inf return path for inputs whose fp64 norm is below float32's minimum subnormal (~1.4 × 10⁻⁴⁵).
  • The new test exercises the underflow scenario but contains implicit device-to-host syncs in its setup assertions, and a float32/float64 dtype mismatch in the reference comparison that causes assert_close to fail on dtype before checking values (already noted in a prior review thread).

Confidence Score: 4/5

Safe to merge for normal LLM training workloads; the new fp64 norm path works correctly across the expected input range but has an unguarded cast that can silently return inf for artificially extreme inputs.

The main optimizer code correctly avoids device-to-host syncs (branching is on a Python bool, not a CUDA tensor), and the in-place clamp for the fp32 path is now correct. The concern is the normalize_in_double=True path: after computing the norm in fp64 and casting to float32, no eps guard is applied in either the distributed or local branch. For inputs whose Frobenius norm is below float32's minimum subnormal, the division silently produces inf. The test also has a dtype mismatch in its reference comparison that prevents it from actually verifying the normalization output, and its setup assertions trigger device-to-host syncs.

Both changed files warrant attention: muon_utils.py for the missing guard in the fp64 normalize path, and test_muon_utils.py for the device-to-host syncs and the float32/float64 dtype mismatch that silently invalidates the test's correctness assertion.
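
One way to close the unguarded path is to clamp the norm after the cast. The following is a minimal sketch under stated assumptions, not the PR's implementation: the function name `normalize_fp64_guarded` is hypothetical, and using `torch.finfo(x.dtype).tiny` as the floor is an assumption chosen so that small but normal-range norms (the case this PR targets) pass through unchanged.

```python
import torch


def normalize_fp64_guarded(x: torch.Tensor) -> torch.Tensor:
    """Normalize x by its Frobenius norm, accumulated in fp64 (hypothetical helper)."""
    # Accumulate the norm in fp64, then cast back to the input dtype.
    norm = torch.linalg.vector_norm(x, dtype=torch.float64).to(x.dtype)
    # Assumed guard: floor the norm at the smallest normal value of x.dtype,
    # so an all-zero input gives 0 / tiny = 0 instead of 0 / 0 = nan.
    # clamp_min runs on the tensor's device, so no host sync is involved.
    norm = norm.clamp_min(torch.finfo(x.dtype).tiny)
    return x / norm


tiny = torch.full((4, 4), 1e-30, dtype=torch.float32)
zeros = torch.zeros(4, 4, dtype=torch.float32)

print(torch.linalg.vector_norm(normalize_fp64_guarded(tiny)))
print(torch.isfinite(normalize_fp64_guarded(zeros)).all())
```

An all-zero input is the simplest case that reaches the guard. The trade-off is that a norm in the subnormal range would be scaled by the floor rather than by its true value, so whether such a guard is wanted depends on the inputs the optimizer is expected to see.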

Important Files Changed

Filename | Overview
emerging_optimizers/orthogonalized_optimizers/muon_utils.py | Adds normalize_in_double option to accumulate the Frobenius norm in float64 for both distributed and local paths; the new clamp_min_ guard for the fp32 path is correct (in-place), but the fp64 path has no guard after casting the fp64 norm back to float32, which can produce inf for extremely small inputs.
tests/test_muon_utils.py | Replaces scale-invariance regression test with a targeted fp64-underflow test; the test assertions on lines 127 and 129 trigger implicit device-to-host syncs; additionally the reference comparison at line 131 divides a float32 tensor by a float64 scalar, producing a float64 result that will fail assert_close's dtype check before exercising normalization correctness (flagged in a prior thread).

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["newton_schulz(x, normalize_in_double)"] --> B{tp_group?}

    B -- Yes --> C["distributed_normalize_p2(x, eps, group, normalize_in_double)"]
    B -- No --> D{normalize_in_double?}

    C --> E{normalize_in_double?}
    E -- No --> F["x_sq = x (fp32)"]
    E -- Yes --> G["x_sq = x.double() (fp64)"]
    F --> H["x_sq_sum = (x_sq²).sum()"]
    G --> H
    H --> I["all_reduce(x_sq_sum)"]
    I --> J["norm = sqrt(x_sq_sum).to(fp32)"]
    J --> K{normalize_in_double?}
    K -- No --> L["norm.clamp_min_(eps) ✓"]
    K -- Yes --> M["⚠ no eps guard — norm may be 0 after cast"]
    L --> N["return x / norm"]
    M --> N

    D -- No --> O["F.normalize(x, eps=eps) ✓"]
    D -- Yes --> P["norm = vector_norm(x, dtype=fp64).to(fp32)"]
    P --> Q["⚠ no eps guard — norm may be 0 after cast"]
    Q --> R["X = x / norm"]
    O --> S["NS iterations"]
    R --> S
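
The distributed branch of the flowchart can be emulated on a single process. This is a rough sketch, not the PR's `distributed_normalize_p2`: the helper `sharded_norm` is a hypothetical name, the all-reduce is replaced by a plain sum over a list of shards, and the eps clamp on the fp32 branch is left out to keep it short.

```python
import torch


def sharded_norm(shards: list[torch.Tensor], normalize_in_double: bool) -> torch.Tensor:
    """Emulate the distributed norm on one process (hypothetical helper)."""
    partial_sums = []
    for shard in shards:
        # Each rank squares and sums its own shard, optionally in fp64.
        x_sq = shard.double() if normalize_in_double else shard
        partial_sums.append((x_sq * x_sq).sum())
    # Stand-in for all_reduce(SUM) across the tensor-parallel group.
    total = torch.stack(partial_sums).sum()
    # The norm is cast back to the shard dtype before it is used as a divisor.
    return total.sqrt().to(shards[0].dtype)


# Two "ranks", each holding a float32 shard of tiny values.
shards = [torch.full((2, 8), 1e-30), torch.full((2, 8), 1e-30)]

print(sharded_norm(shards, normalize_in_double=False))  # squares underflow in fp32
print(sharded_norm(shards, normalize_in_double=True))
```

Only the scalar partial sums cross ranks in either mode, so the fp64 option changes the precision of the reduction without changing its communication pattern.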

Reviews (5): Last reviewed commit: "remove verbose logging"

Comment thread emerging_optimizers/orthogonalized_optimizers/muon_utils.py Outdated
Comment thread emerging_optimizers/orthogonalized_optimizers/muon_utils.py Outdated
Comment thread tests/test_muon_utils.py Outdated
skyw added 2 commits June 26, 2026 16:09
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Comment thread emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw requested a review from FDecaYed June 26, 2026 23:21
@skyw

skyw commented Jun 26, 2026

Contributor Author

/ok to test 854dfd9

@skyw skyw changed the title from "Handle tiny values in muon better" to "Add option to normalize NS input in double precision" Jun 26, 2026
@github-actions

github-actions Bot commented Jun 27, 2026


Test Results

   81 files  ±0    151 suites  ±0   1m 45s ⏱️ +4s
1 163 tests  - 3  1 163 ✅  - 3  0 💤 ±0  0 ❌ ±0 
2 708 runs   - 6  2 708 ✅  - 6  0 💤 ±0  0 ❌ ±0 

Results for commit 3c9a69b. ± Comparison against base commit 46eda5a.

This pull request removes 4 and adds 1 tests. Note that renamed tests count towards both.
__main__.TestNewtonSchulz ‑ test_newtonschulz_small_eps0 (0.01)
__main__.TestNewtonSchulz ‑ test_newtonschulz_small_eps1 (1e-06)
__main__.TestNewtonSchulz ‑ test_newtonschulz_small_eps2 (1e-09)
__main__.TestNewtonSchulz ‑ test_newtonschulz_small_eps3 (1e-12)
__main__.TestNewtonSchulz ‑ test_preserve_values_with_underflowed_norm_in_fp64

♻️ This comment has been updated with latest results.

Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 27, 2026

Contributor Author

/ok to test 451176d

Comment thread tests/test_muon_utils.py
norm_ref = torch.linalg.vector_norm(x, dtype=torch.double)
assert norm_ref != 0
out = muon_utils.newton_schulz(x, steps=0, normalize_in_double=True)
torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)

Contributor


P1 norm_ref is a 0-dimensional torch.float64 tensor (no keepdim, no .to()). Dividing the float32 tensor x by a float64 tensor follows PyTorch's type-promotion rules and produces a float64 result, while out is float32. torch.testing.assert_close checks dtype by default (check_dtype=True), so this call raises on the dtype mismatch before it ever checks numerical values — the normalization behaviour is never actually verified. Cast norm_ref to float32 before dividing to avoid the promotion.

Suggested change
- torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)
+ torch.testing.assert_close(x / norm_ref.to(x.dtype), out, atol=0, rtol=1e-6)
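
Whether this division actually promotes depends on the operand shapes, so it is worth checking before relying on the diagnosis. Under PyTorch's documented type-promotion rules, a zero-dimensional tensor raises the result dtype only when it belongs to a higher category (for example floating point versus integer) than the dimensioned operands. A minimal check, assuming `x` is a dimensioned float32 tensor such as a weight matrix (the shape and fill value below are made up for illustration):

```python
import torch

# Assumed stand-in for the test input: a dimensioned float32 tensor.
x = torch.full((4, 4), 1e-30, dtype=torch.float32)
norm_ref = torch.linalg.vector_norm(x, dtype=torch.double)

print(norm_ref.dim(), norm_ref.dtype)     # shape and dtype of the reference norm
print((x / norm_ref).dtype)               # dimensioned float32 / 0-dim float64
print((x / norm_ref.to(x.dtype)).dtype)   # the same division after the suggested cast
print((x.sum() / norm_ref).dtype)         # 0-dim float32 / 0-dim float64
```

If the second line reports float32, `assert_close` will not trip on dtype for such inputs, and the explicit `.to(x.dtype)` mainly documents intent. The cast is harmless either way, and it does matter when both operands are zero-dimensional.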

Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 27, 2026

Contributor Author

/ok to test 3c9a69b

@codecov

codecov Bot commented Jun 27, 2026


Codecov Report

❌ Patch coverage is 91.66667% with 1 line in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
...optimizers/orthogonalized_optimizers/muon_utils.py | 91.66% | 0 Missing and 1 partial ⚠️
