Add option to normalize NS input in double precision #238
Signed-off-by: Hao Wu <skyw@nvidia.com>
Greptile Summary
This PR adds a
Confidence Score: 4/5
Safe to merge for normal LLM training workloads; the new fp64 norm path works correctly across the expected input range but has an unguarded cast that can silently return inf for artificially extreme inputs.
The main optimizer code correctly avoids device-to-host syncs (branching is on a Python bool, not a CUDA tensor), and the in-place clamp for the fp32 path is now correct. The concern is the normalize_in_double=True path: after computing the norm in fp64 and casting to float32, no eps guard is applied in either the distributed or local branch. For inputs whose Frobenius norm is below float32's minimum subnormal, the division silently produces inf.
The test also has a dtype mismatch in its reference comparison that prevents it from actually verifying the normalization output, and its setup assertions trigger device-to-host syncs.
Both changed files warrant attention: muon_utils.py for the missing guard in the fp64 normalize path, and test_muon_utils.py for the device-to-host syncs and the float32/float64 dtype mismatch that silently invalidates the test's correctness assertion.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["newton_schulz(x, normalize_in_double)"] --> B{tp_group?}
B -- Yes --> C["distributed_normalize_p2(x, eps, group, normalize_in_double)"]
B -- No --> D{normalize_in_double?}
C --> E{normalize_in_double?}
E -- No --> F["x_sq = x (fp32)"]
E -- Yes --> G["x_sq = x.double() (fp64)"]
F --> H["x_sq_sum = (x_sq²).sum()"]
G --> H
H --> I["all_reduce(x_sq_sum)"]
I --> J["norm = sqrt(x_sq_sum).to(fp32)"]
J --> K{normalize_in_double?}
K -- No --> L["norm.clamp_min_(eps) ✓"]
K -- Yes --> M["⚠ no eps guard — norm may be 0 after cast"]
L --> N["return x / norm"]
M --> N
D -- No --> O["F.normalize(x, eps=eps) ✓"]
D -- Yes --> P["norm = vector_norm(x, dtype=fp64).to(fp32)"]
P --> Q["⚠ no eps guard — norm may be 0 after cast"]
Q --> R["X = x / norm"]
O --> S["NS iterations"]
R --> S
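The local (non-tensor-parallel) branch of the flowchart can be sketched as below. This is a minimal illustration, not the code in muon_utils.py: the function name `normalize_local`, the default `eps`, and the use of `torch.linalg.vector_norm` plus `clamp_min` as a stand-in for `F.normalize` on the fp32 path are all assumptions made for the example.

```python
import torch


def normalize_local(
    x: torch.Tensor, eps: float = 1e-7, normalize_in_double: bool = False
) -> torch.Tensor:
    """Hypothetical sketch: scale x to unit Frobenius norm."""
    # The branch is on a Python bool, so it never inspects tensor values.
    if normalize_in_double:
        # fp64 path: accumulate the norm in double, then cast back.
        # As in the flowchart, no eps guard is applied on this path.
        norm = torch.linalg.vector_norm(x, dtype=torch.float64).to(x.dtype)
    else:
        # fp32 path: norm in the input dtype, clamped away from zero.
        norm = torch.linalg.vector_norm(x).clamp_min(eps)
    return x / norm


torch.manual_seed(0)
x = torch.randn(8, 16, dtype=torch.float32)
out_fp32 = normalize_local(x)
out_fp64 = normalize_local(x, normalize_in_double=True)
print(torch.linalg.vector_norm(out_fp32), torch.linalg.vector_norm(out_fp64))
```

For an ordinary input like this one, both paths should return a float32 tensor whose Frobenius norm is close to 1; the two paths are only expected to differ for inputs at the edge of the float32 range.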
Reviews (5): Last reviewed commit: "remove verbose logging"
Signed-off-by: Hao Wu <skyw@nvidia.com>
/ok to test 854dfd9
Test Results
81 files ±0, 151 suites ±0, 1m 45s ⏱️ +4s
Results for commit 3c9a69b. ± Comparison against base commit 46eda5a.
This pull request removes 4 and adds 1 tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results.
Signed-off-by: Hao Wu <skyw@nvidia.com>
/ok to test 451176d
norm_ref = torch.linalg.vector_norm(x, dtype=torch.double)
assert norm_ref != 0
out = muon_utils.newton_schulz(x, steps=0, normalize_in_double=True)
torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)
norm_ref is a 0-dimensional torch.float64 tensor (no keepdim, no .to()). Dividing the float32 tensor x by a float64 tensor follows PyTorch's type-promotion rules and produces a float64 result, while out is float32. torch.testing.assert_close checks dtype by default (check_dtype=True), so this call raises on the dtype mismatch before it ever checks numerical values — the normalization behaviour is never actually verified. Cast norm_ref to float32 before dividing to avoid the promotion.
Suggested change:
- torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)
+ torch.testing.assert_close(x / norm_ref.to(x.dtype), out, atol=0, rtol=1e-6)
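The promotion behaviour at issue can be checked directly. PyTorch's type-promotion documentation states that a zero-dimensional tensor takes part in promotion only when it belongs to a higher category (for example floating point versus integer) than the dimensioned operands. The sketch below uses an assumed 4x4 random input and prints the dtypes involved, so the claim can be confirmed on the installed PyTorch version before relying on it; the variable names other than `x` and `norm_ref` are made up for the example.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 4, dtype=torch.float32)

# With no `dim` argument the result is a zero-dimensional float64 tensor.
norm_ref = torch.linalg.vector_norm(x, dtype=torch.double)

# float32 matrix divided by a zero-dimensional float64 tensor.
ref_mixed = x / norm_ref
# Explicit cast of the divisor, as in the suggested change.
ref_cast = x / norm_ref.to(x.dtype)
# For contrast: the same value as a one-element 1-D float64 tensor.
ref_promoted = x / norm_ref.reshape(1)

print(norm_ref.dim(), norm_ref.dtype)
print(ref_mixed.dtype, ref_cast.dtype, ref_promoted.dtype)
```

If the first quotient prints as float32, the dtype check in `torch.testing.assert_close` passes with or without the explicit cast, and `.to(x.dtype)` mainly documents intent. The one-element 1-D divisor is included because it is a dimensioned operand and so does take part in promotion.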
Signed-off-by: Hao Wu <skyw@nvidia.com>
/ok to test 3c9a69b
Codecov Report
❌ Patch coverage is
"if" clause on calculated norm would trigger device to host sync, trying to do it on device will take multiple path, in which case none of them are better than just use double. A custom kernel can do it but is out of scope.
There is no zero-division guard for the fp64 path, as we don't imagine a case where LLM training is done in fp64. The square of an fp32 value won't underflow in fp64.
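The underflow argument can be checked numerically. The sketch below is an illustration with an assumed input, a 4x4 float32 tensor filled with the smallest positive normal float32 value, and it computes the sum of squares explicitly rather than calling the PR's code.

```python
import torch

# Smallest positive normal float32 value (about 1.18e-38).
tiny = torch.finfo(torch.float32).tiny
x = torch.full((4, 4), tiny, dtype=torch.float32)

# In float32 each square (about 1.4e-76) is far below the smallest
# representable float32 value, so the sum of squares collapses to zero.
sq_sum_fp32 = (x * x).sum()

# In float64 the same squares are comfortably representable.
x64 = x.double()
sq_sum_fp64 = (x64 * x64).sum()

# The square root is back in float32 range, so the cast keeps it nonzero.
norm = sq_sum_fp64.sqrt().to(torch.float32)
normalized = x / norm

print(sq_sum_fp32.item(), sq_sum_fp64.item(), norm.item())
```

With 16 equal elements the norm is four times the element value, so every entry of `normalized` is 0.25 and the result has unit Frobenius norm, even though the float32 sum of squares is zero.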