From 4bdb9e6cfdc9bb22eba07d6ede2824f089274ec2 Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Tue, 31 Mar 2026 00:54:50 +0000 Subject: [PATCH] Add NorMuon optimizer (row-wise adaptive normalization for Muon) NorMuon extends Muon with row-wise second moment tracking and adaptive normalization after Newton-Schulz orthogonalization, ensuring balanced neuron utilization with negligible memory overhead. Reference: Li et al., "NorMuon: Making Muon more efficient and scalable" (arxiv:2510.05491), 2025 --- optax/contrib/__init__.py | 3 + optax/contrib/_normuon.py | 341 +++++++++++++++++++++++ optax/contrib/_normuon_benchmark_test.py | 179 ++++++++++++ optax/contrib/_normuon_test.py | 148 ++++++++++ 4 files changed, 671 insertions(+) create mode 100644 optax/contrib/_normuon.py create mode 100644 optax/contrib/_normuon_benchmark_test.py create mode 100644 optax/contrib/_normuon_test.py diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..0c9b3ff07 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -54,6 +54,9 @@ from optax.contrib._muon import MuonDimensionNumbers from optax.contrib._muon import MuonState from optax.contrib._muon import scale_by_muon +from optax.contrib._normuon import normuon +from optax.contrib._normuon import NorMuonState +from optax.contrib._normuon import scale_by_normuon from optax.contrib._privacy import differentially_private_aggregate from optax.contrib._privacy import DifferentiallyPrivateAggregateState from optax.contrib._privacy import dpsgd diff --git a/optax/contrib/_normuon.py b/optax/contrib/_normuon.py new file mode 100644 index 000000000..9943f4f4b --- /dev/null +++ b/optax/contrib/_normuon.py @@ -0,0 +1,341 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""NorMuon optimizer.""" + +import math +from typing import Any, Callable, Literal, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp + +from optax._src import alias +from optax._src import base +from optax._src import combine +from optax._src import numerics +from optax._src import transform +from optax._src import utils +from optax.contrib._muon import _DEFAULT_NS_COEFFS +from optax.contrib._muon import _is_weight_dim_nums +from optax.contrib._muon import _NS_COEFFS_PRESET_DICT +from optax.contrib._muon import MuonDimensionNumbers +from optax.contrib._muon import orthogonalize_via_newton_schulz +from optax.contrib._muon import scale_by_shape +from optax.contrib._muon import WeightDimNumOrFn +from optax.transforms import _masking +import optax.tree + + +class NorMuonState(NamedTuple): + """State for the NorMuon algorithm.""" + count: jax.typing.ArrayLike # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + ns_coeffs: jax.typing.ArrayLike + + +def scale_by_normuon( + ns_coeffs: Union[ + tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, + jax.typing.ArrayLike], + tuple[ + tuple[ + jax.typing.ArrayLike, jax.typing.ArrayLike, + jax.typing.ArrayLike + ], + ..., + ], + ] = _DEFAULT_NS_COEFFS, + ns_steps: jax.typing.ArrayLike = 5, + beta: jax.typing.ArrayLike = 0.95, + beta2: jax.typing.ArrayLike = 0.95, + eps: jax.typing.ArrayLike = 1e-8, + mu_dtype: Optional[jax.typing.DTypeLike] = None, + *, + nesterov: bool = True, + preconditioning: Literal[ + 'frobenius', 'spectral', 'aol', 'schatten' + ] = 'frobenius', + weight_dimension_numbers: WeightDimNumOrFn | None = None, + normuon_scale: jax.typing.ArrayLike = 0.2, +) -> base.GradientTransformation: + r"""Rescale updates according to the NorMuon algorithm. + + NorMuon extends Muon with row-wise adaptive normalization after the + Newton-Schulz orthogonalization step. This balances neuron utilization + with negligible memory overhead compared to Muon. + + Args: + ns_coeffs: Coefficients for the Newton-Schulz method. + ns_steps: Number of Newton-Schulz iterations. + Ignored if ``ns_coeffs`` is a tuple of tuples. + beta: Decay rate for the exponentially weighted average of grads. + beta2: Decay rate for the row-wise second moment estimates. + eps: Term added to denominators to improve numerical stability. + mu_dtype: Data type of the momentum accumulator. + nesterov: Whether to use Nesterov momentum. + preconditioning: Which preconditioning method to use before NS iterations. + weight_dimension_numbers: An optional tree with the same structure as the + params of ``MuonDimensionNumbers``s, specifying how to reshape the + parameters before and after the orthogonalization OR a callable returning + such a tree. None implies that all parameters are 2D matrices. + normuon_scale: Adaptive learning rate coefficient (default 0.2). + + Returns: + A :class:`optax.GradientTransformation` object. + + References: + Li et al., `NorMuon: Making Muon more efficient and scalable + `_, 2025 + """ + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = optax.tree.zeros_like(params, dtype=mu_dtype) + # nu stores row-wise second moments: shape (m,) for a (m, n) param. + nu = jax.tree.map(lambda x: jnp.zeros(x.shape[:-1], dtype=mu_dtype), + params) + ns_coeffs_ = jnp.asarray(ns_coeffs) + + if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3: + raise ValueError( + f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}' + ) + if ns_coeffs_.ndim == 2: + if ns_coeffs_.shape[0] > ns_steps: + raise ValueError(f'Not enough coeffs to perform {ns_steps} steps') + ns_coeffs_ = ns_coeffs_[-ns_steps:] + + return NorMuonState( + count=jnp.zeros([], jnp.int32), + mu=mu, + nu=nu, + ns_coeffs=ns_coeffs_, + ) + + def update_fn(updates, state, params=None): + del params + if callable(weight_dimension_numbers): + resolved_weight_dim_nums = weight_dimension_numbers(updates) + else: + resolved_weight_dim_nums = weight_dimension_numbers + + mu = optax.tree.update_moment(updates, state.mu, beta, 1) + count_inc = numerics.safe_increment(state.count) + if nesterov: + mu_hat = jax.tree.map( + lambda m, g: beta * m + (1 - beta) * g, + optax.tree.bias_correction( + mu, beta, numerics.safe_increment(count_inc) + ), + optax.tree.bias_correction(updates, beta, count_inc), + ) + else: + mu_hat = optax.tree.bias_correction(mu, beta, count_inc) + + # Apply Newton-Schulz orthogonalization. + ortho = jax.tree.map( + lambda x, dim_num: orthogonalize_via_newton_schulz( + x, state.ns_coeffs, ns_steps, preconditioning, eps, dim_num), + mu_hat, resolved_weight_dim_nums, is_leaf=_is_weight_dim_nums) + + # Row-wise second moment tracking. + def _update_nu(o, nu_prev): + row_sq = jnp.mean(o ** 2, axis=-1) + return beta2 * nu_prev + (1 - beta2) * row_sq + + new_nu = jax.tree.map(_update_nu, ortho, state.nu) + + # Row-wise normalization and adaptive scaling (paper Algorithm 1). + def _normalize(o, nu_new): + o_hat = o / (jnp.sqrt(nu_new[..., None]) + eps) + m_n = math.prod(o.shape[-2:]) if o.ndim >= 2 else o.shape[-1] + frob = jnp.linalg.norm(o_hat, ord='fro') + scale = normuon_scale * jnp.sqrt(m_n) / (frob + eps) + return o_hat * scale + + new_updates = jax.tree.map(_normalize, ortho, new_nu) + + mu = optax.tree.cast(mu, mu_dtype) + return new_updates, NorMuonState( + count=count_inc, + mu=mu, + nu=new_nu, + ns_coeffs=state.ns_coeffs, + ) + + return base.GradientTransformation(init_fn, update_fn) + + +def normuon( + learning_rate: base.ScalarOrSchedule, + ns_coeffs: Union[ + tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, + jax.typing.ArrayLike], + tuple[ + tuple[ + jax.typing.ArrayLike, jax.typing.ArrayLike, + jax.typing.ArrayLike + ], + ..., + ], + str, + ] = _DEFAULT_NS_COEFFS, + ns_steps: jax.typing.ArrayLike = 5, + beta: jax.typing.ArrayLike = 0.95, + beta2: jax.typing.ArrayLike = 0.95, + eps: jax.typing.ArrayLike = 1e-8, + weight_decay: jax.typing.ArrayLike = 0.0, + weight_decay_mask: Optional[ + Union[Any, Callable[[base.Params], Any]] + ] = None, + mu_dtype: Optional[jax.typing.DTypeLike] = None, + *, + nesterov: bool = True, + preconditioning: Literal[ + 'frobenius', 'spectral', 'aol', 'schatten' + ] = 'frobenius', + adam_b1: jax.typing.ArrayLike = 0.9, + adam_b2: jax.typing.ArrayLike = 0.999, + adam_eps_root: jax.typing.ArrayLike = 0.0, + adam_weight_decay: jax.typing.ArrayLike = 0.0, + adam_learning_rate: base.ScalarOrSchedule | None = None, + muon_weight_dimension_numbers: WeightDimNumOrFn | None = None, + normuon_scale: jax.typing.ArrayLike = 0.2, + consistent_rms: jax.typing.ArrayLike | None = None, +) -> base.GradientTransformation: + r"""NorMuon: Muon with row-wise adaptive normalization. + + NorMuon extends the Muon optimizer with row-wise adaptive normalization + applied after Newton-Schulz orthogonalization. This ensures balanced + neuron utilization with negligible memory overhead compared to Muon. + + Like Muon, NorMuon is only defined for 2D parameters (matrices). Non-2D + parameters are passed through an AdamW optimizer. + + Args: + learning_rate: A global scaling factor, either fixed or evolving along + iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. + ns_coeffs: Coefficients for the Newton-Schulz method (can be a string + indicator for a preset). Existing presets: ``muon``, ``dion``. + ns_steps: Number of Newton-Schulz iterations. + Ignored if ``ns_coeffs`` is a tuple of tuples. + beta: Decay rate for the exponentially weighted average of grads. + beta2: Decay rate for the row-wise second moment estimates. + eps: Term added to the denominator to improve numerical stability. + weight_decay: Strength of the weight decay regularization. + weight_decay_mask: A tree with same structure as (or a prefix of) the + params PyTree, or a Callable that returns such a pytree given the + params/updates. The leaves should be booleans, ``True`` for + leaves/subtrees you want to apply the weight decay to, and ``False`` + for those you want to skip. + mu_dtype: Data type of the momentum accumulator. + nesterov: Whether to use Nesterov momentum. + preconditioning: Which preconditioning method to use before NS iterations. + adam_b1: Exponential decay rate for Adam's first moment estimates. + adam_b2: Exponential decay rate for Adam's second moment estimates. + adam_eps_root: Epsilon to stabilize division in Adam, square root version. + adam_weight_decay: Weight decay factor for Adam. + adam_learning_rate: Auxiliary learning rate for the Adam optimizer. + If ``None``, the learning rate for Adam defaults to the same as NorMuon. + muon_weight_dimension_numbers: An optional tree of + ``MuonDimensionNumbers``s, specifying how to reshape the parameters for + orthogonalization. A ``None`` value indicates that the parameter is not + a NorMuon parameter and will be optimized with Adam. If not provided, + NorMuon is applied to all 2D parameters. + normuon_scale: Adaptive learning rate coefficient (default 0.2). + consistent_rms: An optional float to activate consistent RMS scaling. + + Returns: + The corresponding :class:`optax.GradientTransformation`. + + References: + Li et al., `NorMuon: Making Muon more efficient and scalable + `_, 2025 + """ + + if adam_learning_rate is None: + adam_learning_rate = learning_rate + + if isinstance(ns_coeffs, str): + if ns_coeffs not in _NS_COEFFS_PRESET_DICT: + raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}') + ns_coeffs_ = _NS_COEFFS_PRESET_DICT[ns_coeffs] + else: + ns_coeffs_ = ns_coeffs + + # None at root indicates the default 2D rule. + if muon_weight_dimension_numbers is None: + param_labels = lambda params: jax.tree.map( + lambda x: 'normuon' if x.ndim == 2 else 'adam', params + ) + muon_weight_dimension_numbers = MuonDimensionNumbers() + else: + def param_labels(params): + dim_nums = (muon_weight_dimension_numbers(params) + if callable(muon_weight_dimension_numbers) + else muon_weight_dimension_numbers) + populate_subtree_ = lambda dim_num, x: jax.tree.map( + lambda y: 'normuon' if dim_num is not None else 'adam', x) + return jax.tree.map( + populate_subtree_, dim_nums, params, + is_leaf=lambda x: x is None or _is_weight_dim_nums(x)) + + def muon_weight_dim_nums_fn(params): + dim_nums = (muon_weight_dimension_numbers(params) + if callable(muon_weight_dimension_numbers) + else muon_weight_dimension_numbers) + mask = jax.tree.map( + lambda label: label == 'normuon', param_labels(params)) + is_leaf = lambda x: (x is None or _is_weight_dim_nums(x) + or isinstance(x, _masking.MaskedNode)) + populate_subtree_ = lambda dim_nums, submask: jax.tree.map( + lambda m: dim_nums if m else _masking.MaskedNode(), submask) + return jax.tree.map(populate_subtree_, dim_nums, mask, is_leaf=is_leaf) + + return combine.partition( + transforms={ + 'normuon': combine.chain( + scale_by_normuon( + ns_coeffs=ns_coeffs_, + ns_steps=ns_steps, + beta=beta, + beta2=beta2, + eps=eps, + mu_dtype=mu_dtype, + nesterov=nesterov, + preconditioning=preconditioning, + weight_dimension_numbers=muon_weight_dim_nums_fn, + normuon_scale=normuon_scale, + ), + scale_by_shape( + weight_dimension_numbers=muon_weight_dim_nums_fn, + consistent_rms=consistent_rms, + ), + transform.add_decayed_weights(weight_decay, weight_decay_mask), + transform.scale_by_learning_rate(learning_rate), + ), + 'adam': alias.adamw( + learning_rate=adam_learning_rate, + b1=adam_b1, + b2=adam_b2, + eps=eps, + eps_root=adam_eps_root, + weight_decay=adam_weight_decay, + mu_dtype=mu_dtype, + nesterov=nesterov, + ), + }, + param_labels=param_labels, + ) diff --git a/optax/contrib/_normuon_benchmark_test.py b/optax/contrib/_normuon_benchmark_test.py new file mode 100644 index 000000000..ebdae2421 --- /dev/null +++ b/optax/contrib/_normuon_benchmark_test.py @@ -0,0 +1,179 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark tests comparing NorMuon and Muon training convergence.""" + +from absl.testing import absltest +import jax +import jax.numpy as jnp +from optax._src import update +from optax.contrib import _muon +from optax.contrib import _normuon + + +def _make_mlp_params(key): + """Create parameters for a 2-layer MLP: 32 -> 64 -> 1.""" + k1, k2 = jax.random.split(key, 2) + return { + 'w1': jax.random.normal(k1, (32, 64)) * 0.1, + 'b1': jnp.zeros(64), + 'w2': jax.random.normal(k2, (64, 1)) * 0.1, + 'b2': jnp.zeros(1), + } + + +def _mlp_forward(params, x): + """Forward pass for a 2-layer MLP with tanh activation.""" + h = jnp.tanh(x @ params['w1'] + params['b1']) + return h @ params['w2'] + params['b2'] + + +def _make_data(key, batch_size=64, input_dim=32): + """Generate synthetic regression data.""" + k1, k2 = jax.random.split(key) + x = jax.random.normal(k1, (batch_size, input_dim)) + y = jnp.sum(x[:, :3], axis=-1, keepdims=True) + 0.1 * jax.random.normal( + k2, (batch_size, 1) + ) + return x, y + + +def _train(optimizer, params, x, y, steps=500): + """Train the MLP and return losses and final params.""" + state = optimizer.init(params) + losses = [] + + def loss_fn(p): + pred = _mlp_forward(p, x) + return jnp.mean((pred - y) ** 2) + + @jax.jit + def step_fn(params, state): + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = optimizer.update(grads, state, params) + new_params = update.apply_updates(params, updates) + return new_params, new_state, loss + + for _ in range(steps): + params, state, loss = step_fn(params, state) + losses.append(float(loss)) + return losses, params + + +class NorMuonBenchmarkTest(absltest.TestCase): + """Benchmark tests comparing NorMuon and Muon optimizers.""" + + def setUp(self): + super().setUp() + self.data_key = jax.random.key(42) + self.param_key = jax.random.key(0) + self.x, self.y = _make_data(self.data_key) + + def test_normuon_vs_muon_convergence(self): + """Both optimizers should converge; NorMuon within 5x of Muon.""" + params = _make_mlp_params(self.param_key) + + muon_opt = _muon.muon(learning_rate=0.01) + normuon_opt = _normuon.normuon(learning_rate=0.01) + + muon_losses, _ = _train(muon_opt, params, self.x, self.y, steps=500) + normuon_losses, _ = _train( + normuon_opt, params, self.x, self.y, steps=500 + ) + + # Print loss at key steps for visibility. + for name, losses in [('Muon', muon_losses), ('NorMuon', normuon_losses)]: + for step in [0, 100, 200, 300, 400, 499]: + print(f'{name} step {step}: loss={losses[step]:.6f}') + + # Both should converge: final loss < 10% of initial loss. + self.assertLess( + muon_losses[-1], + 0.1 * muon_losses[0], + 'Muon did not converge.', + ) + self.assertLess( + normuon_losses[-1], + 0.1 * normuon_losses[0], + 'NorMuon did not converge.', + ) + + # NorMuon should not be more than 2x worse than Muon. + self.assertLess( + normuon_losses[-1], + 2.0 * muon_losses[-1], + 'NorMuon final loss is more than 2x worse than Muon.', + ) + + def test_normuon_no_side_effects(self): + """NorMuon training should have no NaN/Inf and mostly decrease.""" + params = _make_mlp_params(self.param_key) + normuon_opt = _normuon.normuon(learning_rate=0.01) + losses, final_params = _train( + normuon_opt, params, self.x, self.y, steps=500 + ) + + # No NaN or Inf in any loss value. + for i, loss in enumerate(losses): + self.assertTrue( + jnp.isfinite(loss), f'Non-finite loss at step {i}: {loss}' + ) + + # Loss should be monotonically decreasing with tolerance for noise. + # Check that loss at every 50-step window is lower than the previous. + window = 50 + for i in range(window, len(losses), window): + avg_prev = sum(losses[i - window : i]) / window + avg_curr = sum(losses[i : min(i + window, len(losses))]) / max( + 1, min(window, len(losses) - i) + ) + self.assertLess( + avg_curr, + avg_prev * 1.5, + f'Loss not decreasing around step {i}: ' + f'avg_prev={avg_prev:.6f}, avg_curr={avg_curr:.6f}', + ) + + # All final parameters should be finite. + for name, p in final_params.items(): + self.assertTrue( + jnp.all(jnp.isfinite(p)), + f'Non-finite values in final param {name}', + ) + + def test_normuon_mixed_params_training(self): + """All params (2D weights and 1D biases) should be updated.""" + init_params = _make_mlp_params(self.param_key) + normuon_opt = _normuon.normuon(learning_rate=0.01) + _, final_params = _train( + normuon_opt, init_params, self.x, self.y, steps=500 + ) + + # Every parameter should have changed from its initial value. + for name, init_val in init_params.items(): + self.assertFalse( + jnp.allclose(init_val, final_params[name], atol=1e-8), + f'Parameter {name} was not updated during training.', + ) + + # All final parameters should remain finite. + for name, p in final_params.items(): + self.assertTrue( + jnp.all(jnp.isfinite(p)), + f'Non-finite values in final param {name}', + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax/contrib/_normuon_test.py b/optax/contrib/_normuon_test.py new file mode 100644 index 000000000..76470b27a --- /dev/null +++ b/optax/contrib/_normuon_test.py @@ -0,0 +1,148 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the NorMuon optimizer.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from optax._src import update +from optax.contrib import _normuon + + +class NorMuonTest(parameterized.TestCase): + + def test_basic_normuon(self): + """Test that normuon() runs and produces finite outputs.""" + key = jax.random.key(0) + params = {'w': jax.random.normal(key, (8, 6))} + opt = _normuon.normuon(learning_rate=1e-3) + state = opt.init(params) + grad = params + updates, new_state = opt.update(grad, state, params=params) + self.assertEqual(updates['w'].shape, (8, 6)) + self.assertTrue(jnp.all(jnp.isfinite(updates['w']))) + del new_state + + def test_mixed_params(self): + """Test that 2D params go through NorMuon and 1D through Adam.""" + key = jax.random.key(1) + params = { + 'w': jax.random.normal(key, (10, 5)), + 'b': jax.random.normal(key, (5,)), + } + opt = _normuon.normuon(learning_rate=1e-3) + state = opt.init(params) + grad = params + updates, _ = opt.update(grad, state, params=params) + self.assertEqual(updates['w'].shape, (10, 5)) + self.assertEqual(updates['b'].shape, (5,)) + self.assertTrue(jnp.all(jnp.isfinite(updates['w']))) + self.assertTrue(jnp.all(jnp.isfinite(updates['b']))) + + def test_convergence(self): + """Test that NorMuon can optimize a simple quadratic.""" + key = jax.random.key(2) + target = jax.random.normal(key, (4, 4)) + + def loss_fn(params): + return jnp.sum((params['w'] - target) ** 2) + + opt = _normuon.normuon(learning_rate=1e-2) + params = {'w': jnp.zeros((4, 4))} + state = opt.init(params) + + initial_loss = loss_fn(params) + for _ in range(200): + grad = jax.grad(loss_fn)(params) + updates, state = opt.update(grad, state, params=params) + params = update.apply_updates(params, updates) + + final_loss = loss_fn(params) + self.assertLess(final_loss, initial_loss * 0.5) + + @parameterized.product( + shape=[(6, 4), (8, 8), (3, 10)], + ) + def test_scale_by_normuon_direct(self, shape): + """Test scale_by_normuon directly on a 2D input.""" + key = jax.random.key(3) + params = jax.random.normal(key, shape) + opt = _normuon.scale_by_normuon() + state = opt.init(params) + grad = jax.random.normal(key, shape) + updates, new_state = opt.update(grad, state) + self.assertEqual(updates.shape, shape) + self.assertTrue(jnp.all(jnp.isfinite(updates))) + # Check that nu state has the right shape (rows only). + self.assertEqual(new_state.nu.shape, (shape[0],)) + + @parameterized.named_parameters( + ('small', 1e-7), + ('large', 1e7), + ) + def test_numerical_stability(self, scale): + """Test NorMuon with very small and very large inputs.""" + key = jax.random.key(4) + params = jax.random.normal(key, (8, 4)) * scale + opt = _normuon.normuon(learning_rate=1e-3) + state = opt.init(params) + grad = params + updates, _ = opt.update(grad, state, params=params) + self.assertTrue( + jnp.all(jnp.isfinite(updates['w'])) + if isinstance(updates, dict) + else jnp.all(jnp.isfinite(updates)) + ) + + def test_normuon_state_structure(self): + """Test that NorMuonState has the expected fields.""" + params = jnp.ones((4, 3)) + opt = _normuon.scale_by_normuon() + state = opt.init(params) + self.assertIsInstance(state, _normuon.NorMuonState) + self.assertEqual(state.count, 0) + self.assertEqual(state.mu.shape, (4, 3)) + self.assertEqual(state.nu.shape, (4,)) + + def test_nesterov_flag(self): + """Test that nesterov=True and False produce different momentum states.""" + key = jax.random.key(5) + params = jax.random.normal(key, (6, 4)) + + opt_nest = _normuon.scale_by_normuon(nesterov=True) + opt_no_nest = _normuon.scale_by_normuon(nesterov=False) + + state_nest = opt_nest.init(params) + state_no_nest = opt_no_nest.init(params) + + grad = jax.random.normal(jax.random.key(6), (6, 4)) + + # Run two steps so the momentum accumulates differently. + _, state_nest = opt_nest.update(grad, state_nest) + _, state_no_nest = opt_no_nest.update(grad, state_no_nest) + + grad2 = jax.random.normal(jax.random.key(7), (6, 4)) + updates_nest, _ = opt_nest.update(grad2, state_nest) + updates_no_nest, _ = opt_no_nest.update(grad2, state_no_nest) + + # The nu states should be identical (same ortho output), but the + # momentum paths differ, leading to different ortho inputs and thus + # different final updates. + self.assertFalse(jnp.allclose(updates_nest, updates_no_nest, atol=1e-6)) + + +if __name__ == '__main__': + absltest.main()