Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,21 @@ def get_coefficient_iterator(
return islice(base, steps)


def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distributed.ProcessGroup) -> torch.Tensor:
"""Normalize a tensor in a distributed way."""
x_sq_sum = (x * x).sum()
def distributed_normalize_p2(
x: torch.Tensor, eps: float, group: torch.distributed.ProcessGroup, normalize_in_double: bool = False
) -> torch.Tensor:
"""Normalize a tensor by its distributed Frobenius norm.

When ``normalize_in_double`` is set, the squared sum is accumulated in float64 so that tiny
entries do not underflow to zero when squared in float32.
"""
x_sq = x.double() if normalize_in_double else x
x_sq_sum = (x_sq * x_sq).sum()
torch.distributed.all_reduce(x_sq_sum, op=torch.distributed.ReduceOp.SUM, group=group)
return x / torch.sqrt(x_sq_sum).clamp_min(eps)
norm = torch.sqrt(x_sq_sum).to(x.dtype)
if not normalize_in_double:
norm.clamp_min_(eps)
return x / norm
Comment thread
skyw marked this conversation as resolved.


def newton_schulz(
Expand All @@ -145,6 +155,7 @@ def newton_schulz(
transpose: bool | None = None,
tp_group: torch.distributed.ProcessGroup | None = None,
use_syrk: bool = False,
normalize_in_double: bool = False,
) -> torch.Tensor:
"""Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.

Expand Down Expand Up @@ -177,6 +188,10 @@ def newton_schulz(
If None, will be determined based on the size of the tensor.
tp_group: The process group for communication if input is distributed.
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
normalize_in_double: Whether to reduce the Frobenius norm in float64. This keeps the squared
sum out of float32 underflow for inputs with very small entries, at the cost of a float64
reduction. Without customized kernels, manually handle scaling without triggering a device to host
sync are usually more expensive than using double.

Returns:
The orthogonalization of x.
Expand All @@ -192,13 +207,17 @@ def newton_schulz(
if transpose:
x = x.mT

# Ensure spectral norm is at most 1.
# NOTE: ``eps`` is a divide-by-zero guard; it must stay well below any realistic ``||x||_F``
# yet remain fp32-safe when squared. See issue #229.
# Ensure spectral norm is at most 1 by normalizing with the Frobenius norm. Reducing in float64
# (``normalize_in_double``) keeps the squared sum out of float32 underflow for tiny-norm inputs.
if tp_group is not None:
X = distributed_normalize_p2(x, eps, tp_group)
X = distributed_normalize_p2(x, eps, tp_group, normalize_in_double)
else:
X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps) # type: ignore[arg-type]
if not normalize_in_double:
X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps) # type: ignore[arg-type]
else:
# eps is ignored when normalize in double.
norm = torch.linalg.vector_norm(x, dim=(-2, -1), keepdim=True, dtype=torch.float64).to(x.dtype)
X = x / norm

if coefficient_type in _COEFFICIENT_SETS:
coefficient_sets = _COEFFICIENT_SETS[coefficient_type]
Expand Down
30 changes: 8 additions & 22 deletions tests/test_muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,14 @@ def test_newtonschulz5_close_to_reference(self, dim1, dim2):
rtol=1e-7,
)

@parameterized.parameters(1e-2, 1e-6, 1e-9, 1e-12)
def test_newtonschulz_small_eps(self, scale):
"""Orthogonalization depends only on direction, so scaling the input must not change the output.

Regression test for issue #229: a too-large ``eps`` in the internal ``F.normalize`` divides
small-norm inputs by ``eps`` instead of their norm, silently degenerating the output. The
orthogonalized result for ``x`` and ``scale * x`` must match for any ``scale > 0``.
"""
x = torch.randn(256, 256, device=self.device, dtype=torch.float32)
x = x / x.norm() # unit Frobenius norm direction
ref = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic")
out = muon_utils.newton_schulz(scale * x, steps=5, coefficient_type="quintic")
torch.testing.assert_close(
out,
ref,
atol=1e-4,
rtol=1e-5,
msg=lambda m: (
f"newton_schulz not scale-invariant at input scale {scale}: "
f"||out||_F={out.norm().item():.4f} vs ||ref||_F={ref.norm().item():.4f}\n{m}"
),
)
def test_preserve_values_with_underflowed_norm_in_fp64(self):
scale = 1e-30
x = torch.randn(256, 256, device=self.device, dtype=torch.float32) * scale
assert torch.linalg.vector_norm(x) == 0 # should underflow
norm_ref = torch.linalg.vector_norm(x, dtype=torch.double)
assert norm_ref != 0
out = muon_utils.newton_schulz(x, steps=0, normalize_in_double=True)
torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 norm_ref is a 0-dimensional torch.float64 tensor (no keepdim, no .to()). Dividing the float32 tensor x by a float64 tensor follows PyTorch's type-promotion rules and produces a float64 result, while out is float32. torch.testing.assert_close checks dtype by default (check_dtype=True), so this call raises on the dtype mismatch before it ever checks numerical values — the normalization behaviour is never actually verified. Cast norm_ref to float32 before dividing to avoid the promotion.

Suggested change
torch.testing.assert_close(x / norm_ref, out, atol=0, rtol=1e-6)
torch.testing.assert_close(x / norm_ref.to(x.dtype), out, atol=0, rtol=1e-6)


@parameterized.parameters(
(2, 256, 256),
Expand Down
Loading