From c1e14874802e4019ffd4b96b9ddb9cb05d24280e Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Tue, 9 Jun 2026 10:48:57 +0530 Subject: [PATCH 1/6] contrib: add SOAP optimizer (arXiv:2409.11321) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SOAP (Improving and Stabilizing Shampoo using Adam) combines Shampoo's Kronecker-factored preconditioning with Adam's per-coordinate adaptivity. For each 2-D weight matrix it maintains left/right Kronecker factors whose eigenbases define a rotation; Adam's moment buffers are tracked in that rotated space and the final update is unprojected back. Parameters with fewer than two dimensions fall back to standard Adam. New files: optax/contrib/_soap.py – scale_by_soap and soap implementations optax/contrib/_soap_test.py – 14 unit tests (shapes, orthogonality, convergence, JIT stability, Adam fallback, weight decay, mu_dtype, frequency control) Modified files: optax/contrib/__init__.py – export soap / scale_by_soap / ScaleBySOAPState optax/contrib/_common_test.py – add soap to the shared optimizer test matrix and mark precondition_frequency as static --- optax/contrib/__init__.py | 3 + optax/contrib/_common_test.py | 7 +- optax/contrib/_soap.py | 351 ++++++++++++++++++++++++++++++++++ optax/contrib/_soap_test.py | 319 ++++++++++++++++++++++++++++++ 4 files changed, 679 insertions(+), 1 deletion(-) create mode 100644 optax/contrib/_soap.py create mode 100644 optax/contrib/_soap_test.py diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..8b4761930 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -70,6 +70,9 @@ from optax.contrib._schedule_free import schedule_free_eval_params from optax.contrib._schedule_free import schedule_free_sgd from optax.contrib._schedule_free import ScheduleFreeState +from optax.contrib._soap import scale_by_soap +from optax.contrib._soap import ScaleBySOAPState +from optax.contrib._soap import soap from optax.contrib._sophia import hutchinson_estimator_diag_hessian from optax.contrib._sophia import HutchinsonState from optax.contrib._sophia import sophia diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index 510ef3159..ebb42860c 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -66,6 +66,10 @@ 'opt_name': 'sophia', 'opt_kwargs': {'learning_rate': 1e-2} }, + { + 'opt_name': 'soap', + 'opt_kwargs': {'learning_rate': 1e-2}, + }, { 'opt_name': 'galore', 'opt_kwargs': {'learning_rate': 1e-2, 'rank': 8} @@ -358,7 +362,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( # inject_hyperparams. static_args = [] for uninjectable_hparam in ['warmup_steps', 'num_betas', 'clip_value_fn', - 'ns_steps', 'rank', 'update_proj_gap']: + 'ns_steps', 'rank', 'update_proj_gap', + 'precondition_frequency']: if uninjectable_hparam in inspect.signature(factory).parameters.keys(): static_args.append(uninjectable_hparam) static_args = tuple(static_args) diff --git a/optax/contrib/_soap.py b/optax/contrib/_soap.py new file mode 100644 index 000000000..bdcf8464d --- /dev/null +++ b/optax/contrib/_soap.py @@ -0,0 +1,351 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SOAP: Improving and Stabilizing Shampoo using Adam. + +Implementation of "SOAP: Improving and Stabilizing Shampoo using Adam" +(https://arxiv.org/abs/2409.11321) by Nikhil Vyas, Depen Morwani, Rosie Zhao, +Itai Shapira, David Brandfonbrener, Lucas Janson, and Sham Kakade. +""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp +from optax._src import base +from optax._src import combine +from optax._src import numerics +from optax._src import transform +from optax._src import utils +from optax.transforms import _adding +import optax.tree + + +class ScaleBySOAPState(NamedTuple): + """State for the SOAP optimizer.""" + + count: jax.Array # shape=(), dtype=jnp.int32 + # Kronecker factors and their eigenbases, stored in float32. + # Leaves have shape (m, m) / (n, n) for 2D params, or (0,) for others. + left_factor: base.Updates + right_factor: base.Updates + left_basis: base.Updates + right_basis: base.Updates + # Adam moment buffers maintained in the rotated subspace for 2D params, + # or in the original space for non-2D params. + mu: base.Updates + nu: base.Updates + + +def scale_by_soap( + b1: jax.typing.ArrayLike = 0.9, + b2: jax.typing.ArrayLike = 0.999, + eps: jax.typing.ArrayLike = 1e-8, + precondition_frequency: int = 10, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + r"""Scale updates using the SOAP preconditioner. + + See :func:`optax.contrib.soap` for full details. + + Args: + b1: Decay rate for the first moment (momentum) estimates. + b2: Decay rate for the second moment and Kronecker factor estimates. + eps: Small constant added to the denominator for numerical stability. + precondition_frequency: Number of steps between eigenbasis recomputations. + Lower values track gradient geometry more closely at higher cost per step. + Must be a positive Python int (not a JAX-traced value). + mu_dtype: Optional dtype for the first moment buffer. Useful for + reducing memory in mixed-precision training. If ``None``, the dtype is + inferred from the parameters. + + Returns: + A :class:`optax.GradientTransformation`. + """ + if not isinstance(precondition_frequency, int) or precondition_frequency < 1: + raise ValueError( + f'`precondition_frequency` must be a positive int, got' + f' {precondition_frequency!r}.' + ) + + # Normalize to float32 so that Python-float and JAX-float32 closures compute + # (1 - b2) identically. Without this, Python-float b2=0.999 gives + # 1-b2=0.001 (Python arithmetic) while JAX float32(0.999) gives + # 1-b2=0.0009999871 (float32 arithmetic), causing Kronecker factors to differ + # and the null-space eigenvectors of R to be numerically arbitrary (any + # orthonormal basis of a degenerate subspace is valid, so even tiny matrix + # differences produce completely different eigenvectors). + b1 = jnp.asarray(b1, dtype=jnp.float32) + b2 = jnp.asarray(b2, dtype=jnp.float32) + eps = jnp.asarray(eps, dtype=jnp.float32) + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params: base.Params) -> ScaleBySOAPState: + param_leaves = jax.tree.leaves(params) + if not param_leaves: + # Called with placeholder params (e.g. from tree_map_params). Return an + # empty state that matches the empty tree structure. + empty = params + return ScaleBySOAPState( + count=jnp.zeros([], jnp.int32), + left_factor=empty, + right_factor=empty, + left_basis=empty, + right_basis=empty, + mu=empty, + nu=empty, + ) + + def _init_factors(p): + if p.ndim == 2: + m, n = p.shape + return ( + jnp.zeros((m, m), dtype=jnp.float32), + jnp.zeros((n, n), dtype=jnp.float32), + jnp.eye(m, dtype=jnp.float32), + jnp.eye(n, dtype=jnp.float32), + ) + return ( + jnp.zeros((0,), dtype=jnp.float32), + jnp.zeros((0,), dtype=jnp.float32), + jnp.zeros((0,), dtype=jnp.float32), + jnp.zeros((0,), dtype=jnp.float32), + ) + + factor_tuples = jax.tree.map(_init_factors, params) + lf, rf, lb, rb = jax.tree.transpose( + jax.tree.structure(params), + jax.tree.structure((0, 0, 0, 0)), + factor_tuples, + ) + + return ScaleBySOAPState( + count=jnp.zeros([], jnp.int32), + left_factor=lf, + right_factor=rf, + left_basis=lb, + right_basis=rb, + mu=optax.tree.zeros_like(params, dtype=mu_dtype), + nu=optax.tree.zeros_like(params), + ) + + def update_fn( + updates: base.Updates, + state: ScaleBySOAPState, + params: Optional[base.Params] = None, + ) -> tuple[base.Updates, ScaleBySOAPState]: + del params + count_inc = numerics.safe_int32_increment(state.count) + should_update_bases = jnp.equal( + jnp.mod(state.count, precondition_frequency), 0 + ) + + def _update_2d(g, l, r, q_l, q_r, mu, nu): + g_f32 = g.astype(jnp.float32) + + l_new = b2 * l + (1.0 - b2) * (g_f32 @ g_f32.T) + r_new = b2 * r + (1.0 - b2) * (g_f32.T @ g_f32) + + def _recompute_bases(args): + l_cur, r_cur = args + # Symmetrize to guard against accumulated floating-point asymmetry. + _, new_q_l = jnp.linalg.eigh((l_cur + l_cur.T) * 0.5) + _, new_q_r = jnp.linalg.eigh((r_cur + r_cur.T) * 0.5) + return new_q_l, new_q_r + + q_l_new, q_r_new = jax.lax.cond( + should_update_bases, + _recompute_bases, + lambda args: (q_l, q_r), + (l_new, r_new), + ) + + g_proj = q_l_new.T @ g_f32 @ q_r_new + + mu_f32 = mu.astype(jnp.float32) + nu_f32 = nu.astype(jnp.float32) + + mu_new = b1 * mu_f32 + (1.0 - b1) * g_proj + nu_new = b2 * nu_f32 + (1.0 - b2) * jnp.square(g_proj) + + mu_hat = mu_new / (1.0 - b1**count_inc) + nu_hat = nu_new / (1.0 - b2**count_inc) + u_proj = mu_hat / (jnp.sqrt(nu_hat) + eps) + + u = (q_l_new @ u_proj @ q_r_new.T).astype(g.dtype) + return ( + u, + l_new, + r_new, + q_l_new, + q_r_new, + mu_new.astype(mu.dtype), + nu_new.astype(nu.dtype), + ) + + def _update_nond(g, mu, nu): + g_f32 = g.astype(jnp.float32) + mu_new = b1 * mu.astype(jnp.float32) + (1.0 - b1) * g_f32 + nu_new = b2 * nu.astype(jnp.float32) + (1.0 - b2) * jnp.square(g_f32) + mu_hat = mu_new / (1.0 - b1**count_inc) + nu_hat = nu_new / (1.0 - b2**count_inc) + u = (mu_hat / (jnp.sqrt(nu_hat) + eps)).astype(g.dtype) + return u, mu_new.astype(mu.dtype), nu_new.astype(nu.dtype) + + def _update_single(g, l, r, q_l, q_r, mu, nu): + if g.ndim == 2: + u, l_new, r_new, q_l_new, q_r_new, mu_new, nu_new = _update_2d( + g, l, r, q_l, q_r, mu, nu + ) + else: + u, mu_new, nu_new = _update_nond(g, mu, nu) + l_new, r_new, q_l_new, q_r_new = l, r, q_l, q_r + return u, l_new, r_new, q_l_new, q_r_new, mu_new, nu_new + + result_tuples = jax.tree.map( + _update_single, + updates, + state.left_factor, + state.right_factor, + state.left_basis, + state.right_basis, + state.mu, + state.nu, + ) + + new_updates, new_lf, new_rf, new_lb, new_rb, new_mu, new_nu = ( + jax.tree.transpose( + jax.tree.structure(updates), + jax.tree.structure((0, 0, 0, 0, 0, 0, 0)), + result_tuples, + ) + ) + + new_mu = optax.tree.cast(new_mu, mu_dtype) + + return new_updates, ScaleBySOAPState( + count=count_inc, + left_factor=new_lf, + right_factor=new_rf, + left_basis=new_lb, + right_basis=new_rb, + mu=new_mu, + nu=new_nu, + ) + + return base.GradientTransformation(init_fn, update_fn) + + +def soap( + learning_rate: base.ScalarOrSchedule, + b1: jax.typing.ArrayLike = 0.9, + b2: jax.typing.ArrayLike = 0.999, + eps: jax.typing.ArrayLike = 1e-8, + weight_decay: jax.typing.ArrayLike = 0.0, + weight_decay_mask: Optional[ + Union[Any, Callable[[base.Params], Any]] + ] = None, + precondition_frequency: int = 10, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + r"""SOAP: Improving and Stabilizing Shampoo using Adam. + + SOAP combines the full-matrix preconditioning of Shampoo with the adaptive + moment estimation of Adam. For each 2D weight matrix :math:`W \in + \mathbb{R}^{m \times n}`, it maintains Kronecker-factor matrices whose + eigenbases define a rotation of the gradient. Adam's moment buffers are then + maintained in this rotated space, so the effective preconditioner is + Kronecker-structured but the per-coordinate adaptivity comes from Adam. + + For a 2D parameter with gradient :math:`G_t`: + + .. math:: + + \begin{align*} + L_t &\leftarrow \beta_2 L_{t-1} + (1 - \beta_2) G_t G_t^\top \\ + R_t &\leftarrow \beta_2 R_{t-1} + (1 - \beta_2) G_t^\top G_t \\ + Q_L, Q_R &\leftarrow \text{eigh}(L_t),\, \text{eigh}(R_t) + \quad (\text{every } k \text{ steps}) \\ + \tilde{G}_t &\leftarrow Q_L^\top G_t Q_R \\ + m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \tilde{G}_t \\ + v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \tilde{G}_t^2 \\ + \hat{m}_t &\leftarrow m_t / (1 - \beta_1^t) \\ + \hat{v}_t &\leftarrow v_t / (1 - \beta_2^t) \\ + \Delta_t &\leftarrow Q_L \bigl(\hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) + \bigr) Q_R^\top + \end{align*} + + Parameters with fewer than 2 dimensions fall back to standard Adam. + + .. note:: + SOAP stores left and right Kronecker factors and their eigenbases for each + 2D parameter, introducing memory overhead of :math:`O(m^2 + n^2)` per + parameter on top of the :math:`O(mn)` Adam moments. For very large weight + matrices this can be substantial; consider using it selectively via + :func:`optax.masked`. + + .. note:: + Eigenbasis recomputation via ``eigh`` adds per-step cost every + ``precondition_frequency`` steps. The default of 10 balances tracking + quality against compute. For large layers, increase this to 50–100. + + Args: + learning_rate: A global scaling factor, either fixed or a schedule; see + :func:`optax.scale_by_learning_rate`. + b1: Decay rate for the first moment (momentum) estimates. + b2: Decay rate for the second moment and Kronecker factor estimates. + eps: Small constant added to the denominator for numerical stability. + weight_decay: Optional :math:`\ell_2` regularization strength. + weight_decay_mask: A tree with the same structure as (or a prefix of) the + params pytree, or a callable that returns such a tree given the params. + Leaves should be booleans indicating which parameters to apply weight + decay to. + precondition_frequency: Number of steps between eigenbasis recomputations + from the Kronecker factors. Must be a positive Python int. + mu_dtype: Optional dtype for the first moment buffer; useful for reducing + memory in mixed-precision training. + + Returns: + A :class:`optax.GradientTransformation`. + + Examples: + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def loss(params): + ... return jnp.sum(jnp.square(params['w'] - jnp.ones((4, 4)))) + >>> params = {'w': jnp.zeros((4, 4)), 'b': jnp.zeros(4)} + >>> solver = optax.contrib.soap(learning_rate=1e-2) + >>> state = solver.init(params) + >>> for _ in range(5): + ... grads = jax.grad(loss)(params) + ... updates, state = solver.update(grads, state, params) + ... params = optax.apply_updates(params, updates) + + References: + Vyas et al., `SOAP: Improving and Stabilizing Shampoo using Adam + `_, 2024 + """ + return combine.chain( + scale_by_soap( + b1=b1, + b2=b2, + eps=eps, + precondition_frequency=precondition_frequency, + mu_dtype=mu_dtype, + ), + _adding.add_decayed_weights(weight_decay, mask=weight_decay_mask), + transform.scale_by_learning_rate(learning_rate), + ) diff --git a/optax/contrib/_soap_test.py b/optax/contrib/_soap_test.py new file mode 100644 index 000000000..4b52b49d7 --- /dev/null +++ b/optax/contrib/_soap_test.py @@ -0,0 +1,319 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax.contrib._soap.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from optax._src import numerics +from optax._src import update +from optax.contrib import _soap +import optax.tree + + +def _parabola_2d(dtype=jnp.float32): + initial = jnp.zeros((4, 4), dtype=dtype) + target = jnp.array( + [[1.0, 2.0, -1.0, 0.5], + [-1.0, 0.0, 2.0, 1.0], + [0.5, -0.5, 1.0, -1.0], + [0.0, 1.0, -0.5, 2.0]], dtype=dtype, + ) + obj_fn = lambda p: jnp.sum(numerics.abs_sq(p - target)) + return initial, target, obj_fn + + +def _mixed_params(dtype=jnp.float32): + """Dict with one 2D param (SOAP path) and one 1D param (Adam fallback).""" + initial = { + 'w': jnp.zeros((3, 3), dtype=dtype), + 'b': jnp.zeros((3,), dtype=dtype), + } + target = { + 'w': jnp.array([[1.0, -1.0, 0.5], [0.0, 2.0, -1.0], [-0.5, 1.0, 0.0]], + dtype=dtype), + 'b': jnp.array([1.0, -1.0, 0.5], dtype=dtype), + } + + def obj_fn(params): + return jnp.sum(numerics.abs_sq(params['w'] - target['w'])) + jnp.sum( + numerics.abs_sq(params['b'] - target['b']) + ) + + return initial, target, obj_fn + + +class ScaleBySOAPTest(parameterized.TestCase): + + def test_state_shapes_2d_param(self): + """Kronecker factors and bases have the expected shapes for 2D params.""" + m, n = 5, 3 + params = jnp.zeros((m, n)) + tx = _soap.scale_by_soap() + state = tx.init(params) + + leaves = jax.tree.leaves(state.left_factor) + self.assertEqual(leaves[0].shape, (m, m)) + leaves = jax.tree.leaves(state.right_factor) + self.assertEqual(leaves[0].shape, (n, n)) + leaves = jax.tree.leaves(state.left_basis) + self.assertEqual(leaves[0].shape, (m, m)) + leaves = jax.tree.leaves(state.right_basis) + self.assertEqual(leaves[0].shape, (n, n)) + leaves = jax.tree.leaves(state.mu) + self.assertEqual(leaves[0].shape, (m, n)) + leaves = jax.tree.leaves(state.nu) + self.assertEqual(leaves[0].shape, (m, n)) + + def test_state_shapes_1d_param(self): + """Non-2D params use empty placeholder arrays for the factor fields.""" + params = jnp.zeros((7,)) + tx = _soap.scale_by_soap() + state = tx.init(params) + + for field_name in ('left_factor', 'right_factor', 'left_basis', + 'right_basis'): + leaves = jax.tree.leaves(getattr(state, field_name)) + self.assertEqual(leaves[0].shape, (0,)) + + leaves = jax.tree.leaves(state.mu) + self.assertEqual(leaves[0].shape, (7,)) + + def test_state_shapes_mixed_params(self): + m, n, d = 4, 3, 5 + params = {'w': jnp.zeros((m, n)), 'b': jnp.zeros((d,))} + tx = _soap.scale_by_soap() + state = tx.init(params) + + lf_leaves = jax.tree.leaves(state.left_factor) + rf_leaves = jax.tree.leaves(state.right_factor) + # jax.tree.leaves traversal order is sorted by key + self.assertEqual(lf_leaves[0].shape, (0,)) # 'b' is 1D + self.assertEqual(lf_leaves[1].shape, (m, m)) # 'w' is 2D + self.assertEqual(rf_leaves[0].shape, (0,)) + self.assertEqual(rf_leaves[1].shape, (n, n)) + + def test_bases_are_orthogonal_after_update(self): + """Eigenbases must satisfy Q Q^T ≈ I after a gradient step.""" + params = jnp.ones((4, 4)) + tx = _soap.scale_by_soap(precondition_frequency=1) + state = tx.init(params) + + grads = jax.random.normal(jax.random.PRNGKey(0), params.shape) + _, new_state = tx.update(grads, state) + + q_l = jax.tree.leaves(new_state.left_basis)[0] + q_r = jax.tree.leaves(new_state.right_basis)[0] + n = q_l.shape[0] + + self.assertEqual(q_l.shape, (n, n)) + self.assertEqual(q_r.shape, (params.shape[1], params.shape[1])) + + eye_l = q_l.T @ q_l + eye_r = q_r.T @ q_r + self.assertTrue(jnp.allclose(eye_l, jnp.eye(n), atol=1e-5)) + self.assertTrue(jnp.allclose(eye_r, jnp.eye(q_r.shape[0]), atol=1e-5)) + + def test_kronecker_factors_are_symmetric(self): + """Left and right Kronecker factors must remain symmetric.""" + params = jnp.zeros((5, 3)) + tx = _soap.scale_by_soap() + state = tx.init(params) + + key = jax.random.PRNGKey(42) + for _ in range(5): + key, subkey = jax.random.split(key) + grads = jax.random.normal(subkey, params.shape) + _, state = tx.update(grads, state) + + l = jax.tree.leaves(state.left_factor)[0] + r = jax.tree.leaves(state.right_factor)[0] + self.assertTrue(jnp.allclose(l, l.T, atol=1e-6)) + self.assertTrue(jnp.allclose(r, r.T, atol=1e-6)) + + def test_precondition_frequency_respected(self): + """Eigenbases should only change at multiples of precondition_frequency. + + With precondition_frequency=k and count starting at 0: + - update call i sees count=i, triggers eigh when i % k == 0. + - So eigh runs at calls 0, k, 2k, ... (0-indexed). + - bases_at_step[i] holds the basis after call i. + - bases_at_step[0..k-1] are all identical (refresh only at call 0). + - bases_at_step[k] is fresh (refresh at call k). + """ + k = 5 + params = jnp.ones((3, 3)) + tx = _soap.scale_by_soap(precondition_frequency=k) + state = tx.init(params) + + key = jax.random.PRNGKey(7) + bases_at_step = [] + for _ in range(k + 2): + key, subkey = jax.random.split(key) + grads = jax.random.normal(subkey, params.shape) + _, state = tx.update(grads, state) + q_l = jax.tree.leaves(state.left_basis)[0].copy() + bases_at_step.append(q_l) + + # Steps 1..k-1 must share the same basis as step 0 (no refresh in between). + for i in range(1, k): + self.assertTrue( + jnp.allclose(bases_at_step[i], bases_at_step[0]), + msg=f'basis unexpectedly changed at step {i}', + ) + + # Step k triggers a refresh; the new basis should differ from the previous + # one (which was computed from a single gradient, making it very unlikely + # to match the new one computed from k accumulated steps). + self.assertFalse( + jnp.allclose(bases_at_step[k], bases_at_step[k - 1]), + msg='basis did not change at the expected precondition step', + ) + + def test_convergence_on_2d_quadratic(self): + """SOAP converges to the minimum of a simple quadratic over 2D params.""" + initial, target, obj_fn = _parabola_2d() + + solver = _soap.soap(learning_rate=1e-2, precondition_frequency=5) + params = initial + state = solver.init(params) + + @jax.jit + def step(params, state): + grads = jax.grad(obj_fn)(params) + updates, state = solver.update(grads, state, params) + return update.apply_updates(params, updates), state + + for _ in range(2000): + params, state = step(params, state) + + self.assertTrue( + jnp.allclose(params, target, atol=0.05), + msg=f'Max deviation: {jnp.max(jnp.abs(params - target)):.4f}', + ) + + def test_convergence_on_mixed_params(self): + """SOAP handles dicts with both 2D and 1D params.""" + initial, target, obj_fn = _mixed_params() + + solver = _soap.soap(learning_rate=1e-2, precondition_frequency=5) + params = initial + state = solver.init(params) + + @jax.jit + def step(params, state): + grads = jax.grad(obj_fn)(params) + updates, state = solver.update(grads, state, params) + return update.apply_updates(params, updates), state + + for _ in range(3000): + params, state = step(params, state) + + for key in target: + self.assertTrue( + jnp.allclose(params[key], target[key], atol=0.05), + msg=f"key={key}, max deviation:" + f" {jnp.max(jnp.abs(params[key] - target[key])):.4f}", + ) + + def test_jit_no_recompilation(self): + """optimizer.update should not retrace on the second call.""" + from optax._src import test_utils # pylint: disable=g-import-not-at-top + + params = {'w': jnp.ones((3, 3)), 'b': jnp.ones(3)} + solver = _soap.soap(learning_rate=1e-3) + state = solver.init(params) + grads = jax.tree.map(jnp.ones_like, params) + + step = jax.jit(lambda p, s: solver.update(grads, s, p)) + _, state = step(params, state) + + with test_utils.log_compilations() as logs: + _ = step(params, state) + + self.assertEmpty(logs, 'soap.update recompiled on the second call') + + @parameterized.parameters( + {'b1': 0.9, 'b2': 0.999}, + {'b1': 0.95, 'b2': 0.99}, + ) + def test_nond_fallback_matches_adam(self, b1, b2): + """For 1D params, scale_by_soap should produce the same updates as Adam.""" + from optax._src import transform # pylint: disable=g-import-not-at-top + + params = jnp.array([-1.0, 2.0, 0.5]) + grads = jnp.array([0.1, -0.2, 0.3]) + eps = 1e-8 + + soap_tx = _soap.scale_by_soap(b1=b1, b2=b2, eps=eps) + # scale_by_soap normalizes b1/b2/eps to float32 internally; pass float32 + # here so both use identical float32 arithmetic for (1-b2) in the EMA. + adam_tx = transform.scale_by_adam( + b1=jnp.float32(b1), b2=jnp.float32(b2), eps=jnp.float32(eps) + ) + + soap_state = soap_tx.init(params) + adam_state = adam_tx.init(params) + + soap_updates, _ = soap_tx.update(grads, soap_state) + adam_updates, _ = adam_tx.update(grads, adam_state) + + self.assertTrue( + jnp.allclose(soap_updates, adam_updates, atol=1e-6), + msg=f'soap={soap_updates}, adam={adam_updates}', + ) + + def test_invalid_precondition_frequency_raises(self): + with self.assertRaisesRegex(ValueError, 'precondition_frequency'): + _soap.scale_by_soap(precondition_frequency=0) + with self.assertRaisesRegex(ValueError, 'precondition_frequency'): + _soap.scale_by_soap(precondition_frequency=-3) + + def test_weight_decay_applied(self): + """Verify that weight_decay causes params to shrink toward zero.""" + initial = jnp.ones((3, 3)) * 5.0 + solver = _soap.soap(learning_rate=1e-2, weight_decay=0.1) + params = initial + state = solver.init(params) + + zero_grads = jnp.zeros_like(params) + + @jax.jit + def step(p, s): + u, s = solver.update(zero_grads, s, p) + return update.apply_updates(p, u), s + + for _ in range(100): + params, state = step(params, state) + + self.assertTrue( + jnp.all(jnp.abs(params) < jnp.abs(initial)), + msg='weight decay did not reduce parameter magnitude', + ) + + def test_mu_dtype_reduces_mu_precision(self): + """mu should be stored in mu_dtype when specified.""" + params = jnp.ones((3, 3)) + tx = _soap.scale_by_soap(mu_dtype=jnp.bfloat16) + state = tx.init(params) + + mu_leaves = jax.tree.leaves(state.mu) + for leaf in mu_leaves: + self.assertEqual(leaf.dtype, jnp.bfloat16) + + +if __name__ == '__main__': + absltest.main() From 32adc1fda96e44c43dd2673726513fa55c946e32 Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Thu, 11 Jun 2026 06:34:13 +0530 Subject: [PATCH 2/6] contrib: add MARS optimizer (arXiv:2411.10438) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MARS (Unleashing the Power of Variance Reduction for Training Large Models) augments Adam with a scaled stochastic recursive momentum correction that reduces gradient variance across steps. For each parameter, the variance-reduced estimate c_t = g_t + γ·(β₁/(1-β₁))·(g_t - g_{t-1}) replaces the raw gradient before the standard Adam moment updates. c_t is optionally clipped to unit L2 norm to prevent the large effective gradient on the first step (when last_grad=0) from dominating. Setting γ=0 recovers Adam exactly. This is the approximate-MARS variant: g_{t-1} is the gradient from the previous step rather than a re-evaluation on the current mini-batch, keeping cost identical to Adam at one gradient evaluation per step. New files: optax/contrib/_mars.py – scale_by_mars and mars implementations optax/contrib/_mars_test.py – 13 unit tests (shapes, state tracking, γ=0/Adam equivalence, clipping on/off, variance-reduction direction check, convergence on quadratic and mixed params, JIT stability, weight decay, mu_dtype, count increment) Modified files: optax/contrib/__init__.py – export mars / scale_by_mars / ScaleByMARSState optax/contrib/_common_test.py – add mars to the shared optimizer test matrix --- optax/contrib/__init__.py | 3 + optax/contrib/_common_test.py | 1 + optax/contrib/_mars.py | 290 ++++++++++++++++++++++++++++++++++ optax/contrib/_mars_test.py | 283 +++++++++++++++++++++++++++++++++ 4 files changed, 577 insertions(+) create mode 100644 optax/contrib/_mars.py create mode 100644 optax/contrib/_mars_test.py diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 8b4761930..46521b84a 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -41,6 +41,9 @@ from optax.contrib._galore import GaLoreDimensionNumbers from optax.contrib._galore import GaLoreState from optax.contrib._galore import scale_by_galore +from optax.contrib._mars import mars +from optax.contrib._mars import ScaleByMARSState +from optax.contrib._mars import scale_by_mars from optax.contrib._madgrad import madgrad from optax.contrib._madgrad import MadgradState from optax.contrib._madgrad import scale_by_madgrad diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index ebb42860c..8f4a29fb5 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -50,6 +50,7 @@ {'opt_name': 'dog', 'opt_kwargs': {'learning_rate': 1.0}}, {'opt_name': 'dowg', 'opt_kwargs': {'learning_rate': 1.0}}, {'opt_name': 'madgrad', 'opt_kwargs': {'learning_rate': 1e-2}}, + {'opt_name': 'mars', 'opt_kwargs': {'learning_rate': 3e-3}}, {'opt_name': 'momo', 'opt_kwargs': {'learning_rate': 1e-1}}, {'opt_name': 'momo_adam', 'opt_kwargs': {'learning_rate': 1e-1}}, {'opt_name': 'muon', 'opt_kwargs': {'learning_rate': 1e-2}}, diff --git a/optax/contrib/_mars.py b/optax/contrib/_mars.py new file mode 100644 index 000000000..0fc3e46de --- /dev/null +++ b/optax/contrib/_mars.py @@ -0,0 +1,290 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MARS: Unleashing the Power of Variance Reduction for Training Large Models. + +Implementation of "MARS: Unleashing the Power of Variance Reduction for +Training Large Models" (https://arxiv.org/abs/2411.10438) by Huizhuo Yuan, +Yifeng Liu, Shuang Wu, Xun Zhou, and Quanquan Gu. +""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp +from optax._src import base +from optax._src import combine +from optax._src import numerics +from optax._src import utils +from optax.transforms import _adding +from optax._src import transform +import optax.tree + + +class ScaleByMARSState(NamedTuple): + """State for the MARS variance-reduction preconditioner.""" + + count: jax.Array # shape=(), dtype=jnp.int32 + last_grad: base.Updates # gradient from the previous step, in float32 + mu: base.Updates # first-moment estimate (projected into EMA) + nu: base.Updates # second-moment estimate + + +def scale_by_mars( + b1: jax.typing.ArrayLike = 0.95, + b2: jax.typing.ArrayLike = 0.99, + eps: jax.typing.ArrayLike = 1e-8, + gamma: jax.typing.ArrayLike = 0.025, + clip_threshold: Optional[float] = 1.0, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + r"""Scale updates using the MARS variance-reduction preconditioner. + + See :func:`optax.contrib.mars` for full details. + + Args: + b1: Decay rate for the first moment (momentum) estimates. + b2: Decay rate for the second moment estimates. + eps: Small constant added to the denominator for numerical stability. + gamma: Variance-reduction mixing coefficient. Controls how strongly the + optimizer corrects for the change in gradient direction since the last + step. Setting ``gamma=0`` recovers standard Adam. + clip_threshold: If set, the variance-reduced gradient :math:`c_t` is + rescaled to have at most this L2 norm before the moment updates. This + prevents large corrections early in training (when ``last_grad`` is + near zero) from dominating the update. Set to ``None`` to disable. + mu_dtype: Optional dtype for the first moment buffer. Useful for reducing + memory in mixed-precision training. If ``None``, inferred from params. + + Returns: + A :class:`optax.GradientTransformation`. + """ + # Normalize to float32 so Python-float and JAX-float32 closures compute + # (1 - b2) identically. Without this, the Python path gets 1-0.99=0.01 + # while the JAX float32 path gets 1-float32(0.99)=float32(0.009999...). + # inject_hyperparams always passes strongly-typed float32 values, so this + # cast ensures direct and inject paths are numerically identical. + b1 = jnp.asarray(b1, dtype=jnp.float32) + b2 = jnp.asarray(b2, dtype=jnp.float32) + eps = jnp.asarray(eps, dtype=jnp.float32) + gamma = jnp.asarray(gamma, dtype=jnp.float32) + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params: base.Params) -> ScaleByMARSState: + param_leaves = jax.tree.leaves(params) + if not param_leaves: + # Empty-params guard for tree_map_params compatibility. + empty = params + return ScaleByMARSState( + count=jnp.zeros([], jnp.int32), + last_grad=empty, + mu=empty, + nu=empty, + ) + return ScaleByMARSState( + count=jnp.zeros([], jnp.int32), + last_grad=optax.tree.zeros_like(params), + mu=optax.tree.zeros_like(params, dtype=mu_dtype), + nu=optax.tree.zeros_like(params), + ) + + def update_fn( + updates: base.Updates, + state: ScaleByMARSState, + params: Optional[base.Params] = None, + ) -> tuple[base.Updates, ScaleByMARSState]: + del params + count_inc = numerics.safe_int32_increment(state.count) + + # Variance-reduced gradient: + # c_t = g_t + γ · (β₁/(1-β₁)) · (g_t - g_{t-1}) + # This is the approximate-MARS variant: g_{t-1} is the gradient from the + # previous step on a different mini-batch, not a re-evaluation on the same + # batch. The approximation is cheaper (one grad eval per step) and works + # well in practice. When γ=0, c_t = g_t and MARS is identical to Adam. + correction_scale = gamma * b1 / (1.0 - b1) + c = jax.tree.map( + lambda g, g_prev: ( + g.astype(jnp.float32) + + correction_scale + * (g.astype(jnp.float32) - g_prev.astype(jnp.float32)) + ), + updates, + state.last_grad, + ) + + # Clip the corrected gradient to at most `clip_threshold` in L2 norm. + # Without clipping, the very first step (when last_grad=0) amplifies the + # gradient by a factor of (1 + γβ₁/(1-β₁)), which can be ~1.5× for + # typical γ=0.025, β₁=0.95. For larger γ the amplification is stronger; + # clipping keeps training numerically stable regardless of γ. + if clip_threshold is not None: + c_norm = optax.tree.norm(c) + clip_scale = jnp.minimum( + jnp.ones([], dtype=jnp.float32), + jnp.asarray(clip_threshold, dtype=jnp.float32) / (c_norm + 1e-12), + ) + c = jax.tree.map(lambda ci: ci * clip_scale, c) + + mu_new = jax.tree.map( + lambda m, ci: b1 * m.astype(jnp.float32) + (1.0 - b1) * ci, + state.mu, + c, + ) + nu_new = jax.tree.map( + lambda v, ci: b2 * v.astype(jnp.float32) + (1.0 - b2) * jnp.square(ci), + state.nu, + c, + ) + + mu_hat = jax.tree.map(lambda m: m / (1.0 - b1**count_inc), mu_new) + nu_hat = jax.tree.map(lambda v: v / (1.0 - b2**count_inc), nu_new) + + new_updates = jax.tree.map( + lambda m, v, g: (m / (jnp.sqrt(v) + eps)).astype(g.dtype), + mu_hat, + nu_hat, + updates, + ) + + # Cast moments back to their stored dtypes so dtype is stable across steps. + # Using the stored tensor's dtype (not mu_dtype directly) handles the + # mu_dtype=None case, where mu was initialised with the param dtype. + mu_stored = jax.tree.map( + lambda m_new, m: m_new.astype(m.dtype), mu_new, state.mu + ) + nu_stored = jax.tree.map( + lambda v_new, v: v_new.astype(v.dtype), nu_new, state.nu + ) + + return new_updates, ScaleByMARSState( + count=count_inc, + # Store last_grad in the same dtype as the incoming gradient; the + # float32 promotion happens during the correction computation above. + last_grad=updates, + mu=mu_stored, + nu=nu_stored, + ) + + return base.GradientTransformation(init_fn, update_fn) + + +def mars( + learning_rate: base.ScalarOrSchedule, + b1: jax.typing.ArrayLike = 0.95, + b2: jax.typing.ArrayLike = 0.99, + eps: jax.typing.ArrayLike = 1e-8, + gamma: jax.typing.ArrayLike = 0.025, + clip_threshold: Optional[float] = 1.0, + weight_decay: jax.typing.ArrayLike = 0.0, + weight_decay_mask: Optional[ + Union[Any, Callable[[base.Params], Any]] + ] = None, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + r"""MARS: variance-reduced Adam for large-model training. + + MARS (arXiv:2411.10438) augments Adam with a scaled stochastic recursive + momentum correction that reduces gradient variance across steps. For each + parameter, the raw gradient :math:`g_t` is replaced by a variance-reduced + estimate :math:`c_t` before the Adam moment updates: + + .. math:: + + \begin{align*} + c_t &\leftarrow g_t + \gamma \frac{\beta_1}{1 - \beta_1} + (g_t - g_{t-1}) \\ + \tilde{c}_t &\leftarrow c_t \,/\, \max\!\bigl(1,\, \|c_t\|_2\bigr) \\ + m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \tilde{c}_t \\ + v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \tilde{c}_t^2 \\ + \hat{m}_t &\leftarrow m_t / (1 - \beta_1^t) \\ + \hat{v}_t &\leftarrow v_t / (1 - \beta_2^t) \\ + \Delta_t &\leftarrow \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) + \end{align*} + + This is the *approximate* MARS variant: :math:`g_{t-1}` is the gradient + from the previous step rather than a re-evaluation on the same mini-batch. + This costs one gradient evaluation per step (identical to Adam) while still + capturing most of the variance-reduction benefit. + + Setting :math:`\gamma = 0` recovers standard Adam exactly. + + .. note:: + MARS adds a ``last_grad`` buffer to the optimizer state, increasing memory + by one copy of the parameter tensors beyond standard Adam. For large models + this is the same overhead as adding a second optimizer slot (e.g. the + second moment in Adam). + + .. note:: + The paper reports best results on transformer language model training with + ``learning_rate=3e-3``, ``b1=0.95``, ``b2=0.99``, ``gamma=0.025``, and + ``weight_decay=0.1``. These differ noticeably from typical AdamW defaults; + hyperparameter transfer from AdamW is not straightforward. + + Args: + learning_rate: A global scaling factor, either fixed or a schedule; see + :func:`optax.scale_by_learning_rate`. + b1: Decay rate for the first moment (momentum) estimates. + b2: Decay rate for the second moment estimates. + eps: Small constant added to the denominator for numerical stability. + gamma: Variance-reduction mixing coefficient. Controls how strongly the + optimizer corrects for gradient direction changes between steps. The + paper uses ``gamma=0.025``; larger values give more aggressive variance + reduction but can destabilize training if gradients are noisy. + clip_threshold: Maximum L2 norm of the variance-reduced gradient + :math:`c_t` before moment updates. Prevents the large effective + gradient on the first step (when ``last_grad`` is zero) from causing + an outsized update. Defaults to ``1.0``; set to ``None`` to disable. + weight_decay: Optional :math:`\ell_2` regularization strength. + weight_decay_mask: A tree with the same structure as (or a prefix of) the + params pytree, or a callable that returns such a tree given the params. + Leaves should be booleans indicating which parameters to apply weight + decay to. + mu_dtype: Optional dtype for the first moment buffer; useful for reducing + memory in mixed-precision training. + + Returns: + A :class:`optax.GradientTransformation`. + + Examples: + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def loss(params): + ... return jnp.sum(jnp.square(params['w'] - jnp.ones((4, 4)))) + >>> params = {'w': jnp.zeros((4, 4)), 'b': jnp.zeros(4)} + >>> solver = optax.contrib.mars(learning_rate=3e-3) + >>> state = solver.init(params) + >>> for _ in range(5): + ... grads = jax.grad(loss)(params) + ... updates, state = solver.update(grads, state, params) + ... params = optax.apply_updates(params, updates) + + References: + Yuan et al., `MARS: Unleashing the Power of Variance Reduction for + Training Large Models `_, 2024 + """ + return combine.chain( + scale_by_mars( + b1=b1, + b2=b2, + eps=eps, + gamma=gamma, + clip_threshold=clip_threshold, + mu_dtype=mu_dtype, + ), + _adding.add_decayed_weights(weight_decay, mask=weight_decay_mask), + transform.scale_by_learning_rate(learning_rate), + ) diff --git a/optax/contrib/_mars_test.py b/optax/contrib/_mars_test.py new file mode 100644 index 000000000..94df9a282 --- /dev/null +++ b/optax/contrib/_mars_test.py @@ -0,0 +1,283 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax.contrib._mars.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from optax._src import numerics +from optax._src import transform +from optax._src import update +from optax.contrib import _mars +import optax.tree + + +class ScaleByMARSTest(parameterized.TestCase): + + def test_state_shapes(self): + """last_grad, mu, and nu should have the same shape as params.""" + params = {'w': jnp.zeros((3, 4)), 'b': jnp.zeros((4,))} + tx = _mars.scale_by_mars() + state = tx.init(params) + + # jax.tree.leaves traverses dicts in sorted key order, so check via the + # state fields themselves rather than by assumed list index. + param_shapes = [leaf.shape for leaf in jax.tree.leaves(params)] + for field in ('last_grad', 'mu', 'nu'): + field_shapes = [leaf.shape for leaf in jax.tree.leaves(getattr(state, field))] + self.assertEqual(field_shapes, param_shapes, msg=f'{field} shapes mismatch') + + def test_last_grad_stored_after_update(self): + """After an update step the state should hold the gradient from that step.""" + params = jnp.array([1.0, -2.0, 0.5]) + grads = jnp.array([0.3, -0.1, 0.7]) + tx = _mars.scale_by_mars() + state = tx.init(params) + _, new_state = tx.update(grads, state) + # last_grad should now equal the gradient we just applied. + self.assertTrue( + jnp.allclose( + jax.tree.leaves(new_state.last_grad)[0], + grads.astype(jnp.float32), + ), + msg='last_grad not updated correctly', + ) + + def test_gamma_zero_matches_adam(self): + """With gamma=0 the variance-reduction term vanishes and MARS == Adam.""" + params = jnp.array([-1.0, 2.0, 0.5]) + grads = jnp.array([0.1, -0.2, 0.3]) + b1, b2, eps = 0.9, 0.999, 1e-8 + + mars_tx = _mars.scale_by_mars( + b1=b1, b2=b2, eps=eps, gamma=0.0, clip_threshold=None + ) + # scale_by_mars normalizes to float32 internally; pass float32 to Adam. + adam_tx = transform.scale_by_adam( + b1=jnp.float32(b1), b2=jnp.float32(b2), eps=jnp.float32(eps) + ) + + mars_state = mars_tx.init(params) + adam_state = adam_tx.init(params) + + mars_updates, _ = mars_tx.update(grads, mars_state) + adam_updates, _ = adam_tx.update(grads, adam_state) + + self.assertTrue( + jnp.allclose(mars_updates, adam_updates, atol=1e-6), + msg=f'mars={mars_updates}, adam={adam_updates}', + ) + + def test_clipping_reduces_large_c(self): + """When the corrected gradient norm exceeds clip_threshold it is rescaled.""" + # Use a large gamma so the correction term inflates c well past 1.0. + params = jnp.ones((4,)) + grads = jnp.ones((4,)) * 10.0 # large gradient, last_grad initialised to 0 + + tx = _mars.scale_by_mars(gamma=5.0, clip_threshold=1.0) + state = tx.init(params) + _, new_state = tx.update(grads, state) + + # Retrieve the mu after the first step — it was updated from a clipped c, + # so its L2 norm should be no larger than 1.0 (before bias correction). + mu_leaf = jax.tree.leaves(new_state.mu)[0] + # mu = (1-b1) * c_clipped; since c_clipped is unit norm, |mu| ≈ (1-b1) + b1 = 0.95 + self.assertLessEqual( + float(jnp.sum(jnp.square(mu_leaf))), + (1.0 - b1) ** 2 + 1e-5, + msg='clipping did not reduce the update magnitude', + ) + + def test_no_clipping_when_disabled(self): + """With clip_threshold=None, large c_t values are not clamped.""" + params = jnp.ones((2,)) + grads = jnp.ones((2,)) * 100.0 # very large gradient + + tx_clipped = _mars.scale_by_mars(gamma=1.0, clip_threshold=1.0) + tx_unclipped = _mars.scale_by_mars(gamma=1.0, clip_threshold=None) + state_c = tx_clipped.init(params) + state_u = tx_unclipped.init(params) + + u_clipped, _ = tx_clipped.update(grads, state_c) + u_unclipped, _ = tx_unclipped.update(grads, state_u) + + # Without clipping the update magnitude should be strictly larger. + self.assertGreater( + float(jnp.sum(jnp.abs(u_unclipped))), + float(jnp.sum(jnp.abs(u_clipped))), + msg='disabling clipping should increase update magnitude for large grads', + ) + + def test_variance_reduction_uses_last_grad(self): + """The second-step update should differ from a plain Adam step because + MARS subtracts the previous gradient in the c_t calculation.""" + params = jnp.zeros((3,)) + g1 = jnp.array([1.0, 0.0, 0.0]) + g2 = jnp.array([0.0, 1.0, 0.0]) + + # MARS step 2 with gamma > 0 + mars_tx = _mars.scale_by_mars(gamma=0.5, clip_threshold=None) + mars_state = mars_tx.init(params) + _, mars_state = mars_tx.update(g1, mars_state) + mars_upd, _ = mars_tx.update(g2, mars_state) + + # Adam step 2 (gamma=0 branch of MARS) + adam_tx = _mars.scale_by_mars(gamma=0.0, clip_threshold=None) + adam_state = adam_tx.init(params) + _, adam_state = adam_tx.update(g1, adam_state) + adam_upd, _ = adam_tx.update(g2, adam_state) + + # With gamma > 0 the direction of the update should differ because c_t + # for MARS includes -last_grad = -g1, which has a nonzero first component. + self.assertFalse( + jnp.allclose(mars_upd, adam_upd), + msg='gamma > 0 should cause MARS to differ from Adam on the second step', + ) + + def test_convergence_on_quadratic(self): + """MARS should converge to the minimum of a simple quadratic.""" + initial = jnp.zeros((4, 4), dtype=jnp.float32) + target = jnp.array( + [[1.0, -1.0, 0.5, 0.0], + [0.0, 2.0, -1.0, 0.5], + [-0.5, 1.0, 0.0, -1.0], + [1.0, 0.0, -0.5, 2.0]], dtype=jnp.float32, + ) + obj_fn = lambda p: jnp.sum(numerics.abs_sq(p - target)) + + solver = _mars.mars(learning_rate=3e-3, gamma=0.025) + params = initial + state = solver.init(params) + + @jax.jit + def step(params, state): + grads = jax.grad(obj_fn)(params) + updates, state = solver.update(grads, state, params) + return update.apply_updates(params, updates), state + + for _ in range(3000): + params, state = step(params, state) + + self.assertTrue( + jnp.allclose(params, target, atol=0.05), + msg=f'Max deviation: {jnp.max(jnp.abs(params - target)):.4f}', + ) + + def test_convergence_on_mixed_params(self): + """MARS handles dicts with params of different ranks.""" + initial = {'w': jnp.zeros((3, 3)), 'b': jnp.zeros((3,))} + target = { + 'w': jnp.array([[1.0, -1.0, 0.5], [0.0, 2.0, -1.0], [-0.5, 1.0, 0.0]]), + 'b': jnp.array([1.0, -1.0, 0.5]), + } + obj_fn = lambda p: ( + jnp.sum(numerics.abs_sq(p['w'] - target['w'])) + + jnp.sum(numerics.abs_sq(p['b'] - target['b'])) + ) + + solver = _mars.mars(learning_rate=3e-3, gamma=0.025) + params = initial + state = solver.init(params) + + @jax.jit + def step(params, state): + grads = jax.grad(obj_fn)(params) + updates, state = solver.update(grads, state, params) + return update.apply_updates(params, updates), state + + for _ in range(3000): + params, state = step(params, state) + + for key in target: + self.assertTrue( + jnp.allclose(params[key], target[key], atol=0.05), + msg=f"key={key}, max dev: {jnp.max(jnp.abs(params[key] - target[key])):.4f}", + ) + + def test_jit_no_recompilation(self): + """mars.update should not retrace on the second call.""" + from optax._src import test_utils # pylint: disable=g-import-not-at-top + + params = {'w': jnp.ones((3, 3)), 'b': jnp.ones(3)} + solver = _mars.mars(learning_rate=3e-3) + state = solver.init(params) + grads = jax.tree.map(jnp.ones_like, params) + + step = jax.jit(lambda p, s: solver.update(grads, s, p)) + _, state = step(params, state) + + with test_utils.log_compilations() as logs: + _ = step(params, state) + + self.assertEmpty(logs, 'mars.update recompiled on the second call') + + def test_weight_decay_applied(self): + """Weight decay should cause parameter magnitudes to decrease.""" + initial = jnp.ones((3, 3)) * 5.0 + solver = _mars.mars(learning_rate=3e-3, weight_decay=0.1) + params = initial + state = solver.init(params) + zero_grads = jnp.zeros_like(params) + + @jax.jit + def step(p, s): + u, s = solver.update(zero_grads, s, p) + return update.apply_updates(p, u), s + + for _ in range(100): + params, state = step(params, state) + + self.assertTrue( + jnp.all(jnp.abs(params) < jnp.abs(initial)), + msg='weight decay did not reduce parameter magnitude', + ) + + def test_mu_dtype_reduces_mu_precision(self): + """mu should be stored in mu_dtype when specified.""" + params = jnp.ones((3, 3)) + tx = _mars.scale_by_mars(mu_dtype=jnp.bfloat16) + state = tx.init(params) + + # Run one step so mu is non-zero. + grads = jnp.ones_like(params) + _, state = tx.update(grads, state) + + mu_leaves = jax.tree.leaves(state.mu) + for leaf in mu_leaves: + self.assertEqual(leaf.dtype, jnp.bfloat16) + + @parameterized.parameters( + {'gamma': 0.01, 'b1': 0.9, 'b2': 0.999}, + {'gamma': 0.05, 'b1': 0.95, 'b2': 0.99}, + ) + def test_state_count_increments(self, gamma, b1, b2): + """count should advance by one each update step.""" + params = jnp.zeros((2, 2)) + tx = _mars.scale_by_mars(b1=b1, b2=b2, gamma=gamma) + state = tx.init(params) + self.assertEqual(int(state.count), 0) + + grads = jnp.ones_like(params) + _, state = tx.update(grads, state) + self.assertEqual(int(state.count), 1) + + _, state = tx.update(grads, state) + self.assertEqual(int(state.count), 2) + + +if __name__ == '__main__': + absltest.main() From 2a4d2e9499443a918b4506b120b0020f06d4439e Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Thu, 18 Jun 2026 07:01:09 +0530 Subject: [PATCH 3/6] fix lint errors in _mars_test.py for CI Move test_utils import to file top (C0415), fix dict iteration to use .items() (C0206), and wrap all lines that exceeded 80 chars (E501). --- optax/contrib/_mars_test.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/optax/contrib/_mars_test.py b/optax/contrib/_mars_test.py index 94df9a282..374bccf87 100644 --- a/optax/contrib/_mars_test.py +++ b/optax/contrib/_mars_test.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp from optax._src import numerics +from optax._src import test_utils from optax._src import transform from optax._src import update from optax.contrib import _mars @@ -37,11 +38,15 @@ def test_state_shapes(self): # state fields themselves rather than by assumed list index. param_shapes = [leaf.shape for leaf in jax.tree.leaves(params)] for field in ('last_grad', 'mu', 'nu'): - field_shapes = [leaf.shape for leaf in jax.tree.leaves(getattr(state, field))] - self.assertEqual(field_shapes, param_shapes, msg=f'{field} shapes mismatch') + field_shapes = [ + leaf.shape for leaf in jax.tree.leaves(getattr(state, field)) + ] + self.assertEqual( + field_shapes, param_shapes, msg=f'{field} shapes mismatch' + ) def test_last_grad_stored_after_update(self): - """After an update step the state should hold the gradient from that step.""" + """After one update, state should hold the gradient from that step.""" params = jnp.array([1.0, -2.0, 0.5]) grads = jnp.array([0.3, -0.1, 0.7]) tx = _mars.scale_by_mars() @@ -82,7 +87,7 @@ def test_gamma_zero_matches_adam(self): ) def test_clipping_reduces_large_c(self): - """When the corrected gradient norm exceeds clip_threshold it is rescaled.""" + """When corrected gradient norm exceeds clip_threshold it is rescaled.""" # Use a large gamma so the correction term inflates c well past 1.0. params = jnp.ones((4,)) grads = jnp.ones((4,)) * 10.0 # large gradient, last_grad initialised to 0 @@ -119,7 +124,10 @@ def test_no_clipping_when_disabled(self): self.assertGreater( float(jnp.sum(jnp.abs(u_unclipped))), float(jnp.sum(jnp.abs(u_clipped))), - msg='disabling clipping should increase update magnitude for large grads', + msg=( + 'disabling clipping should increase update magnitude' + ' for large grads' + ), ) def test_variance_reduction_uses_last_grad(self): @@ -145,7 +153,10 @@ def test_variance_reduction_uses_last_grad(self): # for MARS includes -last_grad = -g1, which has a nonzero first component. self.assertFalse( jnp.allclose(mars_upd, adam_upd), - msg='gamma > 0 should cause MARS to differ from Adam on the second step', + msg=( + 'gamma > 0 should cause MARS to differ from Adam' + ' on the second step' + ), ) def test_convergence_on_quadratic(self): @@ -202,16 +213,17 @@ def step(params, state): for _ in range(3000): params, state = step(params, state) - for key in target: + for key, val in target.items(): self.assertTrue( - jnp.allclose(params[key], target[key], atol=0.05), - msg=f"key={key}, max dev: {jnp.max(jnp.abs(params[key] - target[key])):.4f}", + jnp.allclose(params[key], val, atol=0.05), + msg=( + f"key={key}," + f" max dev: {jnp.max(jnp.abs(params[key] - val)):.4f}" + ), ) def test_jit_no_recompilation(self): """mars.update should not retrace on the second call.""" - from optax._src import test_utils # pylint: disable=g-import-not-at-top - params = {'w': jnp.ones((3, 3)), 'b': jnp.ones(3)} solver = _mars.mars(learning_rate=3e-3) state = solver.init(params) From 56af3371e1cca5fafe2ba1536d5dff0e029fec45 Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Thu, 18 Jun 2026 08:59:14 +0530 Subject: [PATCH 4/6] fix lint errors in _soap.py and _soap_test.py for CI Move test_utils and transform imports to file top (C0415), fix dict iteration to use .items() (C0206), and wrap the 81-char docstring line in _soap.py (E501). --- optax/contrib/_soap.py | 4 ++-- optax/contrib/_soap_test.py | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/optax/contrib/_soap.py b/optax/contrib/_soap.py index bdcf8464d..7bbe95383 100644 --- a/optax/contrib/_soap.py +++ b/optax/contrib/_soap.py @@ -283,8 +283,8 @@ def soap( v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \tilde{G}_t^2 \\ \hat{m}_t &\leftarrow m_t / (1 - \beta_1^t) \\ \hat{v}_t &\leftarrow v_t / (1 - \beta_2^t) \\ - \Delta_t &\leftarrow Q_L \bigl(\hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) - \bigr) Q_R^\top + \Delta_t &\leftarrow Q_L \bigl( + \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) \bigr) Q_R^\top \end{align*} Parameters with fewer than 2 dimensions fall back to standard Adam. diff --git a/optax/contrib/_soap_test.py b/optax/contrib/_soap_test.py index 4b52b49d7..767b1ddd3 100644 --- a/optax/contrib/_soap_test.py +++ b/optax/contrib/_soap_test.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp from optax._src import numerics +from optax._src import test_utils +from optax._src import transform from optax._src import update from optax.contrib import _soap import optax.tree @@ -222,17 +224,15 @@ def step(params, state): for _ in range(3000): params, state = step(params, state) - for key in target: + for key, val in target.items(): self.assertTrue( - jnp.allclose(params[key], target[key], atol=0.05), + jnp.allclose(params[key], val, atol=0.05), msg=f"key={key}, max deviation:" - f" {jnp.max(jnp.abs(params[key] - target[key])):.4f}", + f" {jnp.max(jnp.abs(params[key] - val)):.4f}", ) def test_jit_no_recompilation(self): """optimizer.update should not retrace on the second call.""" - from optax._src import test_utils # pylint: disable=g-import-not-at-top - params = {'w': jnp.ones((3, 3)), 'b': jnp.ones(3)} solver = _soap.soap(learning_rate=1e-3) state = solver.init(params) @@ -252,8 +252,6 @@ def test_jit_no_recompilation(self): ) def test_nond_fallback_matches_adam(self, b1, b2): """For 1D params, scale_by_soap should produce the same updates as Adam.""" - from optax._src import transform # pylint: disable=g-import-not-at-top - params = jnp.array([-1.0, 2.0, 0.5]) grads = jnp.array([0.1, -0.2, 0.3]) eps = 1e-8 From 7f9cc1c026fe0498642b9aa98d18aca26a1c5894 Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Wed, 24 Jun 2026 16:45:05 +0530 Subject: [PATCH 5/6] fix pyrefly type errors in _soap.py Change count field type from jax.Array to jax.typing.ArrayLike to match the established pattern in transform.py (ScaleByAdamState). Add pyrefly ignore comments for the GradientTransformation and add_decayed_weights calls, matching the pattern used in alias.py. --- optax/contrib/_soap.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_soap.py b/optax/contrib/_soap.py index 7bbe95383..0c3cac306 100644 --- a/optax/contrib/_soap.py +++ b/optax/contrib/_soap.py @@ -35,7 +35,7 @@ class ScaleBySOAPState(NamedTuple): """State for the SOAP optimizer.""" - count: jax.Array # shape=(), dtype=jnp.int32 + count: jax.typing.ArrayLike # shape=(), dtype=jnp.int32 # Kronecker factors and their eigenbases, stored in float32. # Leaves have shape (m, m) / (n, n) for 2D params, or (0,) for others. left_factor: base.Updates @@ -245,6 +245,7 @@ def _update_single(g, l, r, q_l, q_r, mu, nu): nu=new_nu, ) + # pyrefly: ignore[bad-argument-type] return base.GradientTransformation(init_fn, update_fn) @@ -346,6 +347,7 @@ def soap( precondition_frequency=precondition_frequency, mu_dtype=mu_dtype, ), + # pyrefly: ignore[bad-argument-type] _adding.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) From 09987eafe2fc15e5833c6f35f435e148f7446467 Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Wed, 24 Jun 2026 16:54:45 +0530 Subject: [PATCH 6/6] fix pyrefly type errors in _mars.py Change count field type from jax.Array to jax.typing.ArrayLike to match the established pattern in transform.py. Add pyrefly ignore comments for GradientTransformation and add_decayed_weights calls, matching the same pattern used in _soap.py and alias.py. --- optax/contrib/_mars.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_mars.py b/optax/contrib/_mars.py index 0fc3e46de..8a52c8331 100644 --- a/optax/contrib/_mars.py +++ b/optax/contrib/_mars.py @@ -35,7 +35,7 @@ class ScaleByMARSState(NamedTuple): """State for the MARS variance-reduction preconditioner.""" - count: jax.Array # shape=(), dtype=jnp.int32 + count: jax.typing.ArrayLike # shape=(), dtype=jnp.int32 last_grad: base.Updates # gradient from the previous step, in float32 mu: base.Updates # first-moment estimate (projected into EMA) nu: base.Updates # second-moment estimate @@ -178,6 +178,7 @@ def update_fn( nu=nu_stored, ) + # pyrefly: ignore[bad-argument-type] return base.GradientTransformation(init_fn, update_fn) @@ -285,6 +286,7 @@ def mars( clip_threshold=clip_threshold, mu_dtype=mu_dtype, ), + # pyrefly: ignore[bad-argument-type] _adding.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), )