diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..8b4761930 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -70,6 +70,9 @@ from optax.contrib._schedule_free import schedule_free_eval_params from optax.contrib._schedule_free import schedule_free_sgd from optax.contrib._schedule_free import ScheduleFreeState +from optax.contrib._soap import scale_by_soap +from optax.contrib._soap import ScaleBySOAPState +from optax.contrib._soap import soap from optax.contrib._sophia import hutchinson_estimator_diag_hessian from optax.contrib._sophia import HutchinsonState from optax.contrib._sophia import sophia diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index 510ef3159..ebb42860c 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -66,6 +66,10 @@ 'opt_name': 'sophia', 'opt_kwargs': {'learning_rate': 1e-2} }, + { + 'opt_name': 'soap', + 'opt_kwargs': {'learning_rate': 1e-2}, + }, { 'opt_name': 'galore', 'opt_kwargs': {'learning_rate': 1e-2, 'rank': 8} @@ -358,7 +362,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( # inject_hyperparams. static_args = [] for uninjectable_hparam in ['warmup_steps', 'num_betas', 'clip_value_fn', - 'ns_steps', 'rank', 'update_proj_gap']: + 'ns_steps', 'rank', 'update_proj_gap', + 'precondition_frequency']: if uninjectable_hparam in inspect.signature(factory).parameters.keys(): static_args.append(uninjectable_hparam) static_args = tuple(static_args) diff --git a/optax/contrib/_soap.py b/optax/contrib/_soap.py new file mode 100644 index 000000000..0c3cac306 --- /dev/null +++ b/optax/contrib/_soap.py @@ -0,0 +1,353 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SOAP: Improving and Stabilizing Shampoo using Adam.
+
+Implementation of "SOAP: Improving and Stabilizing Shampoo using Adam"
+(https://arxiv.org/abs/2409.11321) by Nikhil Vyas, Depen Morwani, Rosie Zhao,
+Itai Shapira, David Brandfonbrener, Lucas Janson, and Sham Kakade.
+"""
+
+from typing import Any, Callable, NamedTuple, Optional, Union
+
+import jax
+import jax.numpy as jnp
+from optax._src import base
+from optax._src import combine
+from optax._src import numerics
+from optax._src import transform
+from optax._src import utils
+from optax.transforms import _adding
+import optax.tree
+
+
+class ScaleBySOAPState(NamedTuple):
+  """State for the SOAP optimizer."""
+
+  count: jax.typing.ArrayLike  # shape=(), dtype=jnp.int32
+  # Kronecker factors and their eigenbases, stored in float32.
+  # Leaves have shape (m, m) / (n, n) for 2D params, or (0,) for others.
+  left_factor: base.Updates
+  right_factor: base.Updates
+  left_basis: base.Updates
+  right_basis: base.Updates
+  # Adam moment buffers. For 2D params they live in the coordinates of the
+  # current eigenbases (mu is re-expressed whenever the bases are refreshed);
+  # for non-2D params they live in the original space.
+  mu: base.Updates
+  nu: base.Updates
+
+
+def scale_by_soap(
+    b1: jax.typing.ArrayLike = 0.9,
+    b2: jax.typing.ArrayLike = 0.999,
+    eps: jax.typing.ArrayLike = 1e-8,
+    precondition_frequency: int = 10,
+    mu_dtype: Optional[Any] = None,
+) -> base.GradientTransformation:
+  r"""Scale updates using the SOAP preconditioner.
+
+  See :func:`optax.contrib.soap` for full details.
+
+  Args:
+    b1: Decay rate for the first moment (momentum) estimates.
+    b2: Decay rate for the second moment and Kronecker factor estimates.
+    eps: Small constant added to the denominator for numerical stability.
+    precondition_frequency: Number of steps between eigenbasis recomputations.
+      Lower values track gradient geometry more closely at higher cost per step.
+      Must be a positive Python int (not a JAX-traced value).
+    mu_dtype: Optional dtype for the first moment buffer. Useful for
+      reducing memory in mixed-precision training. If ``None``, the dtype is
+      inferred from the parameters.
+
+  Returns:
+    A :class:`optax.GradientTransformation`.
+
+  Raises:
+    ValueError: If ``precondition_frequency`` is not a positive int.
+  """
+  if not isinstance(precondition_frequency, int) or precondition_frequency < 1:
+    raise ValueError(
+        f'`precondition_frequency` must be a positive int, got'
+        f' {precondition_frequency!r}.'
+    )
+
+  # Normalize to float32 so that Python-float and JAX-float32 closures compute
+  # (1 - b2) identically. Without this, Python-float b2=0.999 gives
+  # 1-b2=0.001 (Python arithmetic) while JAX float32(0.999) gives
+  # 1-b2=0.0009999871 (float32 arithmetic), causing Kronecker factors to differ
+  # and the null-space eigenvectors of R to be numerically arbitrary (any
+  # orthonormal basis of a degenerate subspace is valid, so even tiny matrix
+  # differences produce completely different eigenvectors).
+  b1 = jnp.asarray(b1, dtype=jnp.float32)
+  b2 = jnp.asarray(b2, dtype=jnp.float32)
+  eps = jnp.asarray(eps, dtype=jnp.float32)
+
+  mu_dtype = utils.canonicalize_dtype(mu_dtype)
+
+  def init_fn(params: base.Params) -> ScaleBySOAPState:
+    """Builds zero factors, identity bases and zero moments for `params`."""
+    param_leaves = jax.tree.leaves(params)
+    if not param_leaves:
+      # Called with placeholder params (e.g. from tree_map_params). Return an
+      # empty state that matches the empty tree structure.
+      empty = params
+      return ScaleBySOAPState(
+          count=jnp.zeros([], jnp.int32),
+          left_factor=empty,
+          right_factor=empty,
+          left_basis=empty,
+          right_basis=empty,
+          mu=empty,
+          nu=empty,
+      )
+
+    def _init_factors(p):
+      if p.ndim == 2:
+        m, n = p.shape
+        return (
+            jnp.zeros((m, m), dtype=jnp.float32),
+            jnp.zeros((n, n), dtype=jnp.float32),
+            jnp.eye(m, dtype=jnp.float32),
+            jnp.eye(n, dtype=jnp.float32),
+        )
+      # Non-2D leaves never use the factors; store empty placeholders so that
+      # every state field keeps the same tree structure as the params.
+      return (
+          jnp.zeros((0,), dtype=jnp.float32),
+          jnp.zeros((0,), dtype=jnp.float32),
+          jnp.zeros((0,), dtype=jnp.float32),
+          jnp.zeros((0,), dtype=jnp.float32),
+      )
+
+    # Tree of 4-tuples -> 4-tuple of trees.
+    factor_tuples = jax.tree.map(_init_factors, params)
+    lf, rf, lb, rb = jax.tree.transpose(
+        jax.tree.structure(params),
+        jax.tree.structure((0, 0, 0, 0)),
+        factor_tuples,
+    )
+
+    return ScaleBySOAPState(
+        count=jnp.zeros([], jnp.int32),
+        left_factor=lf,
+        right_factor=rf,
+        left_basis=lb,
+        right_basis=rb,
+        mu=optax.tree.zeros_like(params, dtype=mu_dtype),
+        nu=optax.tree.zeros_like(params),
+    )
+
+  def update_fn(
+      updates: base.Updates,
+      state: ScaleBySOAPState,
+      params: Optional[base.Params] = None,
+  ) -> tuple[base.Updates, ScaleBySOAPState]:
+    """Preconditions `updates` and advances the optimizer state by one step."""
+    del params
+    count_inc = numerics.safe_int32_increment(state.count)
+
+    if not jax.tree.leaves(updates):
+      # Nothing to precondition (for example when every leaf has been masked
+      # out). The transpose below cannot rebuild its 7-tuple from a tree
+      # without leaves, so mirror the empty-tree handling of `init_fn` and
+      # only advance the step count.
+      return updates, state._replace(count=count_inc)
+
+    # The pre-increment count is used, so the bases are refreshed on the very
+    # first call and then every `precondition_frequency` calls.
+    should_update_bases = jnp.equal(
+        jnp.mod(state.count, precondition_frequency), 0
+    )
+
+    def _update_2d(g, l, r, q_l, q_r, mu, nu):
+      g_f32 = g.astype(jnp.float32)
+      mu_f32 = mu.astype(jnp.float32)
+      nu_f32 = nu.astype(jnp.float32)
+
+      l_new = b2 * l + (1.0 - b2) * (g_f32 @ g_f32.T)
+      r_new = b2 * r + (1.0 - b2) * (g_f32.T @ g_f32)
+
+      def _refresh_bases(operands):
+        l_cur, r_cur, mu_cur = operands
+        # Symmetrize to guard against accumulated floating-point asymmetry.
+        _, fresh_q_l = jnp.linalg.eigh((l_cur + l_cur.T) * 0.5)
+        _, fresh_q_r = jnp.linalg.eigh((r_cur + r_cur.T) * 0.5)
+        # `mu` is stored in the coordinates of the old bases. Eigenvector
+        # signs and ordering returned by `eigh` are arbitrary, so the stored
+        # momentum must be mapped back to parameter space and re-projected
+        # onto the fresh bases; otherwise it would be mixed with gradients
+        # expressed in a different coordinate system.
+        mu_original_space = q_l @ mu_cur @ q_r.T
+        mu_rotated = fresh_q_l.T @ mu_original_space @ fresh_q_r
+        return fresh_q_l, fresh_q_r, mu_rotated
+
+      def _keep_bases(operands):
+        return q_l, q_r, operands[2]
+
+      q_l_new, q_r_new, mu_f32 = jax.lax.cond(
+          should_update_bases,
+          _refresh_bases,
+          _keep_bases,
+          (l_new, r_new, mu_f32),
+      )
+
+      g_proj = q_l_new.T @ g_f32 @ q_r_new
+
+      # `nu` is left as is on a refresh: an element-wise second moment has no
+      # exact counterpart under a rotation, so it adapts through its EMA.
+      mu_new = b1 * mu_f32 + (1.0 - b1) * g_proj
+      nu_new = b2 * nu_f32 + (1.0 - b2) * jnp.square(g_proj)
+
+      mu_hat = mu_new / (1.0 - b1**count_inc)
+      nu_hat = nu_new / (1.0 - b2**count_inc)
+      u_proj = mu_hat / (jnp.sqrt(nu_hat) + eps)
+
+      # Rotate the Adam step back to parameter space.
+      u = (q_l_new @ u_proj @ q_r_new.T).astype(g.dtype)
+      return (
+          u,
+          l_new,
+          r_new,
+          q_l_new,
+          q_r_new,
+          mu_new.astype(mu.dtype),
+          nu_new.astype(nu.dtype),
+      )
+
+    def _update_nond(g, mu, nu):
+      # Plain Adam in the original space for every leaf that is not 2D.
+      g_f32 = g.astype(jnp.float32)
+      mu_new = b1 * mu.astype(jnp.float32) + (1.0 - b1) * g_f32
+      nu_new = b2 * nu.astype(jnp.float32) + (1.0 - b2) * jnp.square(g_f32)
+      mu_hat = mu_new / (1.0 - b1**count_inc)
+      nu_hat = nu_new / (1.0 - b2**count_inc)
+      u = (mu_hat / (jnp.sqrt(nu_hat) + eps)).astype(g.dtype)
+      return u, mu_new.astype(mu.dtype), nu_new.astype(nu.dtype)
+
+    def _update_single(g, l, r, q_l, q_r, mu, nu):
+      if g.ndim == 2:
+        return _update_2d(g, l, r, q_l, q_r, mu, nu)
+      u, mu_new, nu_new = _update_nond(g, mu, nu)
+      # Placeholder factors and bases pass through untouched.
+      return u, l, r, q_l, q_r, mu_new, nu_new
+
+    result_tuples = jax.tree.map(
+        _update_single,
+        updates,
+        state.left_factor,
+        state.right_factor,
+        state.left_basis,
+        state.right_basis,
+        state.mu,
+        state.nu,
+    )
+
+    # Tree of 7-tuples -> 7-tuple of trees.
+    new_updates, new_lf, new_rf, new_lb, new_rb, new_mu, new_nu = (
+        jax.tree.transpose(
+            jax.tree.structure(updates),
+            jax.tree.structure((0, 0, 0, 0, 0, 0, 0)),
+            result_tuples,
+        )
+    )
+
+    new_mu = optax.tree.cast(new_mu, mu_dtype)
+
+    return new_updates, ScaleBySOAPState(
+        count=count_inc,
+        left_factor=new_lf,
+        right_factor=new_rf,
+        left_basis=new_lb,
+        right_basis=new_rb,
+        mu=new_mu,
+        nu=new_nu,
+    )
+
+  # pyrefly: ignore[bad-argument-type]
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+def soap(
+    learning_rate: base.ScalarOrSchedule,
+    b1: jax.typing.ArrayLike = 0.9,
+    b2: jax.typing.ArrayLike = 0.999,
+    eps: jax.typing.ArrayLike = 1e-8,
+    weight_decay: jax.typing.ArrayLike = 0.0,
+    weight_decay_mask: Optional[
+        Union[Any, Callable[[base.Params], Any]]
+    ] = None,
+    precondition_frequency: int = 10,
+    mu_dtype: Optional[Any] = None,
+) -> base.GradientTransformation:
+  r"""SOAP: Improving and Stabilizing Shampoo using Adam.
+
+  SOAP combines the full-matrix preconditioning of Shampoo with the adaptive
+  moment estimation of Adam. For each 2D weight matrix :math:`W \in
+  \mathbb{R}^{m \times n}`, it maintains Kronecker-factor matrices whose
+  eigenbases define a rotation of the gradient. Adam's moment buffers are then
+  maintained in this rotated space, so the effective preconditioner is
+  Kronecker-structured but the per-coordinate adaptivity comes from Adam.
+
+  For a 2D parameter with gradient :math:`G_t`:
+
+  .. math::
+
+    \begin{align*}
+      L_t &\leftarrow \beta_2 L_{t-1} + (1 - \beta_2) G_t G_t^\top \\
+      R_t &\leftarrow \beta_2 R_{t-1} + (1 - \beta_2) G_t^\top G_t \\
+      Q_L, Q_R &\leftarrow \text{eigh}(L_t),\, \text{eigh}(R_t)
+        \quad (\text{every } k \text{ steps}) \\
+      m_{t-1} &\leftarrow Q_L^\top \bigl( Q_L^{\text{old}}\, m_{t-1}\,
+        {Q_R^{\text{old}}}^\top \bigr) Q_R
+        \quad (\text{only when the bases were refreshed}) \\
+      \tilde{G}_t &\leftarrow Q_L^\top G_t Q_R \\
+      m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \tilde{G}_t \\
+      v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \tilde{G}_t^2 \\
+      \hat{m}_t &\leftarrow m_t / (1 - \beta_1^t) \\
+      \hat{v}_t &\leftarrow v_t / (1 - \beta_2^t) \\
+      \Delta_t &\leftarrow Q_L \bigl(
+        \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) \bigr) Q_R^\top
+    \end{align*}
+
+  Parameters that are not exactly 2-dimensional (scalars, vectors, and arrays
+  with 3 or more dimensions) fall back to standard Adam.
+
+  .. note::
+    SOAP stores left and right Kronecker factors and their eigenbases for each
+    2D parameter, introducing memory overhead of :math:`O(m^2 + n^2)` per
+    parameter on top of the :math:`O(mn)` Adam moments. For very large weight
+    matrices this can be substantial; consider using it selectively via
+    :func:`optax.masked`.
+
+  .. note::
+    Eigenbasis recomputation via ``eigh`` adds per-step cost every
+    ``precondition_frequency`` steps. The default of 10 balances tracking
+    quality against compute. For large layers, increase this to 50–100.
+
+  Args:
+    learning_rate: A global scaling factor, either fixed or a schedule; see
+      :func:`optax.scale_by_learning_rate`.
+    b1: Decay rate for the first moment (momentum) estimates.
+    b2: Decay rate for the second moment and Kronecker factor estimates.
+    eps: Small constant added to the denominator for numerical stability.
+    weight_decay: Strength of the weight decay. The decay term is added to the
+      preconditioned update, i.e. after the SOAP scaling and before the
+      learning-rate scaling, so it is decoupled from the preconditioner (as in
+      AdamW) rather than being an :math:`\ell_2` penalty on the loss.
+    weight_decay_mask: A tree with the same structure as (or a prefix of) the
+      params pytree, or a callable that returns such a tree given the params.
+      Leaves should be booleans indicating which parameters to apply weight
+      decay to.
+    precondition_frequency: Number of steps between eigenbasis recomputations
+      from the Kronecker factors. Must be a positive Python int.
+    mu_dtype: Optional dtype for the first moment buffer; useful for reducing
+      memory in mixed-precision training.
+
+  Returns:
+    A :class:`optax.GradientTransformation`.
+
+  Examples:
+    >>> import optax
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> def loss(params):
+    ...   return jnp.sum(jnp.square(params['w'] - jnp.ones((4, 4))))
+    >>> params = {'w': jnp.zeros((4, 4)), 'b': jnp.zeros(4)}
+    >>> solver = optax.contrib.soap(learning_rate=1e-2)
+    >>> state = solver.init(params)
+    >>> for _ in range(5):
+    ...   grads = jax.grad(loss)(params)
+    ...   updates, state = solver.update(grads, state, params)
+    ...   params = optax.apply_updates(params, updates)
+
+  References:
+    Vyas et al., `SOAP: Improving and Stabilizing Shampoo using Adam
+    <https://arxiv.org/abs/2409.11321>`_, 2024
+  """
+  return combine.chain(
+      scale_by_soap(
+          b1=b1,
+          b2=b2,
+          eps=eps,
+          precondition_frequency=precondition_frequency,
+          mu_dtype=mu_dtype,
+      ),
+      # pyrefly: ignore[bad-argument-type]
+      _adding.add_decayed_weights(weight_decay, mask=weight_decay_mask),
+      transform.scale_by_learning_rate(learning_rate),
+  )
diff --git a/optax/contrib/_soap_test.py b/optax/contrib/_soap_test.py
new file mode 100644
index 000000000..767b1ddd3
--- /dev/null
+++ b/optax/contrib/_soap_test.py
@@ -0,0 +1,317 @@
+# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optax.contrib._soap."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import jax.numpy as jnp
+from optax._src import numerics
+from optax._src import test_utils
+from optax._src import transform
+from optax._src import update
+from optax.contrib import _soap
+import optax.tree
+
+
+def _parabola_2d(dtype=jnp.float32):
+  """Returns (initial, target, objective) for a quadratic over a 4x4 matrix."""
+  initial = jnp.zeros((4, 4), dtype=dtype)
+  target = jnp.array(
+      [[1.0, 2.0, -1.0, 0.5],
+       [-1.0, 0.0, 2.0, 1.0],
+       [0.5, -0.5, 1.0, -1.0],
+       [0.0, 1.0, -0.5, 2.0]], dtype=dtype,
+  )
+  obj_fn = lambda p: jnp.sum(numerics.abs_sq(p - target))
+  return initial, target, obj_fn
+
+
+def _mixed_params(dtype=jnp.float32):
+  """Dict with one 2D param (SOAP path) and one 1D param (Adam fallback)."""
+  initial = {
+      'w': jnp.zeros((3, 3), dtype=dtype),
+      'b': jnp.zeros((3,), dtype=dtype),
+  }
+  target = {
+      'w': jnp.array([[1.0, -1.0, 0.5], [0.0, 2.0, -1.0], [-0.5, 1.0, 0.0]],
+                     dtype=dtype),
+      'b': jnp.array([1.0, -1.0, 0.5], dtype=dtype),
+  }
+
+  def obj_fn(params):
+    return jnp.sum(numerics.abs_sq(params['w'] - target['w'])) + jnp.sum(
+        numerics.abs_sq(params['b'] - target['b'])
+    )
+
+  return initial, target, obj_fn
+
+
+def _minimize(solver, obj_fn, params, num_steps):
+  """Runs `num_steps` jitted optimizer steps on `obj_fn`, returns the params."""
+  state = solver.init(params)
+
+  @jax.jit
+  def step(params, state):
+    grads = jax.grad(obj_fn)(params)
+    updates, state = solver.update(grads, state, params)
+    return update.apply_updates(params, updates), state
+
+  for _ in range(num_steps):
+    params, state = step(params, state)
+  return params
+
+
+_FACTOR_FIELDS = ('left_factor', 'right_factor', 'left_basis', 'right_basis')
+
+
+class ScaleBySOAPTest(parameterized.TestCase):
+
+  def test_state_shapes_2d_param(self):
+    """Kronecker factors and bases have the expected shapes for 2D params."""
+    m, n = 5, 3
+    params = jnp.zeros((m, n))
+    tx = _soap.scale_by_soap()
+    state = tx.init(params)
+
+    leaves = jax.tree.leaves(state.left_factor)
+    self.assertEqual(leaves[0].shape, (m, m))
+    leaves = jax.tree.leaves(state.right_factor)
+    self.assertEqual(leaves[0].shape, (n, n))
+    leaves = jax.tree.leaves(state.left_basis)
+    self.assertEqual(leaves[0].shape, (m, m))
+    leaves = jax.tree.leaves(state.right_basis)
+    self.assertEqual(leaves[0].shape, (n, n))
+    leaves = jax.tree.leaves(state.mu)
+    self.assertEqual(leaves[0].shape, (m, n))
+    leaves = jax.tree.leaves(state.nu)
+    self.assertEqual(leaves[0].shape, (m, n))
+
+  def test_state_shapes_1d_param(self):
+    """Non-2D params use empty placeholder arrays for the factor fields."""
+    params = jnp.zeros((7,))
+    tx = _soap.scale_by_soap()
+    state = tx.init(params)
+
+    for field_name in _FACTOR_FIELDS:
+      leaves = jax.tree.leaves(getattr(state, field_name))
+      self.assertEqual(leaves[0].shape, (0,))
+
+    leaves = jax.tree.leaves(state.mu)
+    self.assertEqual(leaves[0].shape, (7,))
+
+  def test_state_shapes_mixed_params(self):
+    m, n, d = 4, 3, 5
+    params = {'w': jnp.zeros((m, n)), 'b': jnp.zeros((d,))}
+    tx = _soap.scale_by_soap()
+    state = tx.init(params)
+
+    lf_leaves = jax.tree.leaves(state.left_factor)
+    rf_leaves = jax.tree.leaves(state.right_factor)
+    # jax.tree.leaves traversal order is sorted by key
+    self.assertEqual(lf_leaves[0].shape, (0,))  # 'b' is 1D
+    self.assertEqual(lf_leaves[1].shape, (m, m))  # 'w' is 2D
+    self.assertEqual(rf_leaves[0].shape, (0,))
+    self.assertEqual(rf_leaves[1].shape, (n, n))
+
+  def test_bases_are_orthogonal_after_update(self):
+    """Eigenbases must satisfy Q Q^T ≈ I after a gradient step."""
+    params = jnp.ones((4, 4))
+    tx = _soap.scale_by_soap(precondition_frequency=1)
+    state = tx.init(params)
+
+    grads = jax.random.normal(jax.random.PRNGKey(0), params.shape)
+    _, new_state = tx.update(grads, state)
+
+    q_l = jax.tree.leaves(new_state.left_basis)[0]
+    q_r = jax.tree.leaves(new_state.right_basis)[0]
+    n = q_l.shape[0]
+
+    self.assertEqual(q_l.shape, (n, n))
+    self.assertEqual(q_r.shape, (params.shape[1], params.shape[1]))
+
+    eye_l = q_l.T @ q_l
+    eye_r = q_r.T @ q_r
+    self.assertTrue(jnp.allclose(eye_l, jnp.eye(n), atol=1e-5))
+    self.assertTrue(jnp.allclose(eye_r, jnp.eye(q_r.shape[0]), atol=1e-5))
+
+  def test_kronecker_factors_are_symmetric(self):
+    """Left and right Kronecker factors must remain symmetric."""
+    params = jnp.zeros((5, 3))
+    tx = _soap.scale_by_soap()
+    state = tx.init(params)
+
+    key = jax.random.PRNGKey(42)
+    for _ in range(5):
+      key, subkey = jax.random.split(key)
+      grads = jax.random.normal(subkey, params.shape)
+      _, state = tx.update(grads, state)
+
+    l = jax.tree.leaves(state.left_factor)[0]
+    r = jax.tree.leaves(state.right_factor)[0]
+    self.assertTrue(jnp.allclose(l, l.T, atol=1e-6))
+    self.assertTrue(jnp.allclose(r, r.T, atol=1e-6))
+
+  def test_precondition_frequency_respected(self):
+    """Eigenbases should only change at multiples of precondition_frequency.
+
+    With precondition_frequency=k and count starting at 0:
+      - update call i sees count=i, triggers eigh when i % k == 0.
+      - So eigh runs at calls 0, k, 2k, ... (0-indexed).
+      - bases_at_step[i] holds the basis after call i.
+      - bases_at_step[0..k-1] are all identical (refresh only at call 0).
+      - bases_at_step[k] is fresh (refresh at call k).
+    """
+    k = 5
+    params = jnp.ones((3, 3))
+    tx = _soap.scale_by_soap(precondition_frequency=k)
+    state = tx.init(params)
+
+    key = jax.random.PRNGKey(7)
+    bases_at_step = []
+    for _ in range(k + 2):
+      key, subkey = jax.random.split(key)
+      grads = jax.random.normal(subkey, params.shape)
+      _, state = tx.update(grads, state)
+      q_l = jax.tree.leaves(state.left_basis)[0].copy()
+      bases_at_step.append(q_l)
+
+    # Steps 1..k-1 must share the same basis as step 0 (no refresh in between).
+    for i in range(1, k):
+      self.assertTrue(
+          jnp.allclose(bases_at_step[i], bases_at_step[0]),
+          msg=f'basis unexpectedly changed at step {i}',
+      )
+
+    # Step k triggers a refresh; the new basis should differ from the previous
+    # one (which was computed from a single gradient, making it very unlikely
+    # to match the new one computed from k accumulated steps).
+    self.assertFalse(
+        jnp.allclose(bases_at_step[k], bases_at_step[k - 1]),
+        msg='basis did not change at the expected precondition step',
+    )
+
+  def test_convergence_on_2d_quadratic(self):
+    """SOAP converges to the minimum of a simple quadratic over 2D params."""
+    initial, target, obj_fn = _parabola_2d()
+
+    solver = _soap.soap(learning_rate=1e-2, precondition_frequency=5)
+    params = _minimize(solver, obj_fn, initial, num_steps=2000)
+
+    self.assertTrue(
+        jnp.allclose(params, target, atol=0.05),
+        msg=f'Max deviation: {jnp.max(jnp.abs(params - target)):.4f}',
+    )
+
+  def test_convergence_on_mixed_params(self):
+    """SOAP handles dicts with both 2D and 1D params."""
+    initial, target, obj_fn = _mixed_params()
+
+    solver = _soap.soap(learning_rate=1e-2, precondition_frequency=5)
+    params = _minimize(solver, obj_fn, initial, num_steps=3000)
+
+    for key, val in target.items():
+      self.assertTrue(
+          jnp.allclose(params[key], val, atol=0.05),
+          msg=f"key={key}, max deviation:"
+          f" {jnp.max(jnp.abs(params[key] - val)):.4f}",
+      )
+
+  def test_jit_no_recompilation(self):
+    """optimizer.update should not retrace on the second call."""
+    params = {'w': jnp.ones((3, 3)), 'b': jnp.ones(3)}
+    solver = _soap.soap(learning_rate=1e-3)
+    state = solver.init(params)
+    grads = jax.tree.map(jnp.ones_like, params)
+
+    step = jax.jit(lambda p, s: solver.update(grads, s, p))
+    _, state = step(params, state)
+
+    with test_utils.log_compilations() as logs:
+      _ = step(params, state)
+
+    self.assertEmpty(logs, 'soap.update recompiled on the second call')
+
+  @parameterized.parameters(
+      {'b1': 0.9, 'b2': 0.999},
+      {'b1': 0.95, 'b2': 0.99},
+  )
+  def test_nond_fallback_matches_adam(self, b1, b2):
+    """For 1D params, scale_by_soap should produce the same updates as Adam."""
+    params = jnp.array([-1.0, 2.0, 0.5])
+    grads = jnp.array([0.1, -0.2, 0.3])
+    eps = 1e-8
+
+    soap_tx = _soap.scale_by_soap(b1=b1, b2=b2, eps=eps)
+    # scale_by_soap normalizes b1/b2/eps to float32 internally; pass float32
+    # here so both use identical float32 arithmetic for (1-b2) in the EMA.
+    adam_tx = transform.scale_by_adam(
+        b1=jnp.float32(b1), b2=jnp.float32(b2), eps=jnp.float32(eps)
+    )
+
+    soap_state = soap_tx.init(params)
+    adam_state = adam_tx.init(params)
+
+    soap_updates, _ = soap_tx.update(grads, soap_state)
+    adam_updates, _ = adam_tx.update(grads, adam_state)
+
+    self.assertTrue(
+        jnp.allclose(soap_updates, adam_updates, atol=1e-6),
+        msg=f'soap={soap_updates}, adam={adam_updates}',
+    )
+
+  def test_higher_rank_params_use_adam_fallback(self):
+    """Params with more than 2 dims keep placeholder factors and match Adam."""
+    params = jnp.zeros((2, 3, 4))
+    grads = jax.random.normal(jax.random.PRNGKey(3), params.shape)
+
+    soap_tx = _soap.scale_by_soap()
+    # Same float32 hyper-parameters as scale_by_soap uses internally.
+    adam_tx = transform.scale_by_adam(
+        b1=jnp.float32(0.9), b2=jnp.float32(0.999), eps=jnp.float32(1e-8)
+    )
+
+    soap_state = soap_tx.init(params)
+    for field_name in _FACTOR_FIELDS:
+      leaves = jax.tree.leaves(getattr(soap_state, field_name))
+      self.assertEqual(leaves[0].shape, (0,))
+
+    soap_updates, _ = soap_tx.update(grads, soap_state)
+    adam_updates, _ = adam_tx.update(grads, adam_tx.init(params))
+
+    self.assertEqual(soap_updates.shape, params.shape)
+    self.assertTrue(
+        jnp.allclose(soap_updates, adam_updates, atol=1e-6),
+        msg=f'soap={soap_updates}, adam={adam_updates}',
+    )
+
+  def test_invalid_precondition_frequency_raises(self):
+    with self.assertRaisesRegex(ValueError, 'precondition_frequency'):
+      _soap.scale_by_soap(precondition_frequency=0)
+    with self.assertRaisesRegex(ValueError, 'precondition_frequency'):
+      _soap.scale_by_soap(precondition_frequency=-3)
+    # Non-int values are rejected as well, even when positive.
+    with self.assertRaisesRegex(ValueError, 'precondition_frequency'):
+      _soap.scale_by_soap(precondition_frequency=2.5)
+
+  def test_weight_decay_applied(self):
+    """Verify that weight_decay causes params to shrink toward zero."""
+    initial = jnp.ones((3, 3)) * 5.0
+    solver = _soap.soap(learning_rate=1e-2, weight_decay=0.1)
+    params = initial
+    state = solver.init(params)
+
+    zero_grads = jnp.zeros_like(params)
+
+    @jax.jit
+    def step(p, s):
+      u, s = solver.update(zero_grads, s, p)
+      return update.apply_updates(p, u), s
+
+    for _ in range(100):
+      params, state = step(params, state)
+
+    self.assertTrue(
+        jnp.all(jnp.abs(params) < jnp.abs(initial)),
+        msg='weight decay did not reduce parameter magnitude',
+    )
+
+  def test_mu_dtype_reduces_mu_precision(self):
+    """mu should be stored in mu_dtype when specified."""
+    params = jnp.ones((3, 3))
+    tx = _soap.scale_by_soap(mu_dtype=jnp.bfloat16)
+    state = tx.init(params)
+
+    for leaf in jax.tree.leaves(state.mu):
+      self.assertEqual(leaf.dtype, jnp.bfloat16)
+
+    # The reduced precision must survive an update, and must not leak into the
+    # dtype of the returned updates.
+    grads = jnp.ones_like(params)
+    updates, new_state = tx.update(grads, state)
+
+    for leaf in jax.tree.leaves(new_state.mu):
+      self.assertEqual(leaf.dtype, jnp.bfloat16)
+    self.assertEqual(updates.dtype, grads.dtype)
+
+
+if __name__ == '__main__':
+  absltest.main()