diff --git a/mars_benchmark.png b/mars_benchmark.png new file mode 100644 index 000000000..dbbf4c0a7 Binary files /dev/null and b/mars_benchmark.png differ diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..b15d020fa 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -41,6 +41,9 @@ from optax.contrib._galore import GaLoreDimensionNumbers from optax.contrib._galore import GaLoreState from optax.contrib._galore import scale_by_galore +from optax.contrib._mars import mars +from optax.contrib._mars import MarsState +from optax.contrib._mars import scale_by_mars from optax.contrib._madgrad import madgrad from optax.contrib._madgrad import MadgradState from optax.contrib._madgrad import scale_by_madgrad diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index 510ef3159..7149df175 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -50,6 +50,8 @@ {'opt_name': 'dog', 'opt_kwargs': {'learning_rate': 1.0}}, {'opt_name': 'dowg', 'opt_kwargs': {'learning_rate': 1.0}}, {'opt_name': 'madgrad', 'opt_kwargs': {'learning_rate': 1e-2}}, + {'opt_name': 'mars', 'opt_kwargs': {'learning_rate': 1e-3}}, + {'opt_name': 'mars', 'opt_kwargs': {'learning_rate': 1e-3, 'gamma': 1.0}}, {'opt_name': 'momo', 'opt_kwargs': {'learning_rate': 1e-1}}, {'opt_name': 'momo_adam', 'opt_kwargs': {'learning_rate': 1e-1}}, {'opt_name': 'muon', 'opt_kwargs': {'learning_rate': 1e-2}}, diff --git a/optax/contrib/_mars.py b/optax/contrib/_mars.py new file mode 100644 index 000000000..b4b6ae694 --- /dev/null +++ b/optax/contrib/_mars.py @@ -0,0 +1,256 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MARS: Unleashing the Power of Variance Reduction for Training Large Models. + +Reference: + Hu et al., `MARS: Unleashing the Power of Variance Reduction for Training + Large Models `_, 2024. +""" + +from typing import Any, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp +from optax._src import base +from optax._src import combine +from optax._src import numerics +from optax._src import transform +from optax._src import utils +import optax.tree + + +class MarsState(NamedTuple): + """State for the MARS optimizer.""" + + count: jax.Array # shape=(), dtype=jnp.int32 + mu: base.Updates # first moment (EMA of corrected gradients) + nu: base.Updates # second moment (EMA of squared corrected gradients) + prev_grad: base.Updates # g_{t-1}: gradient from the previous step + c_prev: base.Updates # c_{t-1}: corrected gradient from the previous step + + +def scale_by_mars( + gamma: float = 0.025, + b1: float = 0.9, + b2: float = 0.99, + eps: float = 1e-8, + mu_dtype: Optional[Any] = None, + *, + correction_clip: Optional[float] = None, + nesterov: bool = False, +) -> base.GradientTransformation: + r"""MARS variance-reduction gradient rescaling. + + Computes a STORM-style corrected gradient :math:`c_t` and then applies + Adam-style first- and second-moment accumulation on :math:`c_t` rather than + the raw gradient :math:`g_t`. The corrected gradient is: + + .. math:: + + c_t = g_t + (1 - \gamma)(c_{t-1} - g_{t-1}), \quad c_1 = g_1. + + With :math:`\gamma = 1` this reduces to plain Adam rescaling. Smaller + :math:`\gamma` gives stronger variance reduction at the cost of sensitivity + to gradient noise between consecutive steps. + + Args: + gamma: Variance-reduction coefficient :math:`\gamma \in (0, 1]`. The + authors recommend ``0.025`` for LLM pre-training. + b1: Exponential decay rate for the first moment (momentum). + b2: Exponential decay rate for the second moment. + eps: Small constant for numerical stability in the denominator. + mu_dtype: Optional dtype for the first-moment buffer. If ``None`` the + dtype is inferred from the parameters. + correction_clip: If set, clips the *correction term* + :math:`(1-\gamma)(c_{t-1} - g_{t-1})` by global norm before adding it + to :math:`g_t`. Recommended by the paper (Section 3.2) for stability. + nesterov: Whether to use Nesterov momentum for the first moment. + + Returns: + A :class:`optax.GradientTransformation`. + + .. seealso:: :func:`optax.contrib.mars` + """ + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = optax.tree.zeros_like(params, dtype=mu_dtype) + nu = optax.tree.zeros_like(params) + prev_grad = optax.tree.zeros_like(params) + c_prev = optax.tree.zeros_like(params) + return MarsState( + count=jnp.zeros([], jnp.int32), + mu=mu, + nu=nu, + prev_grad=prev_grad, + c_prev=c_prev, + ) + + def update_fn(updates, state, params=None): + del params + count = state.count + g_t = updates + + # ── Compute the STORM-style corrected gradient ───────────────────────── + # c_t = g_t + (1 - gamma) * (c_{t-1} - g_{t-1}) + # At step 0 (count == 0) there is no previous information, so c_1 = g_1. + correction = jax.tree.map( + lambda c, g: (1.0 - gamma) * (c - g), state.c_prev, state.prev_grad + ) + # Zero out the correction on the very first step. + is_first_step = count == 0 + correction = jax.tree.map( + lambda corr: jnp.where(is_first_step, jnp.zeros_like(corr), corr), + correction, + ) + + if correction_clip is not None: + # Clip the correction term by its global norm (Section 3.2 of paper). + leaves = jax.tree.leaves(correction) + global_norm = jnp.sqrt( + sum(jnp.sum(leaf ** 2) for leaf in leaves) + 1e-12 + ) + scale = jnp.minimum(1.0, correction_clip / global_norm) + correction = jax.tree.map(lambda c: c * scale, correction) + + c_t = jax.tree.map(lambda g, corr: g + corr, g_t, correction) + + # ── Adam-style moment updates on c_t ────────────────────────────────── + mu = optax.tree.update_moment(c_t, state.mu, b1, 1) + nu = optax.tree.update_moment_per_elem_norm(c_t, state.nu, b2, 2) + count_inc = jnp.asarray(numerics.safe_increment(count)) + + mu_hat = optax.tree.bias_correction(mu, b1, count_inc) + nu_hat = optax.tree.bias_correction(nu, b2, count_inc) + + if nesterov: + mu_hat = jax.tree.map( + lambda m, c: b1 * m + (1.0 - b1) * c, mu_hat, c_t + ) + + updates_out = jax.tree.map( + lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat + ) + mu = optax.tree.cast(mu, mu_dtype) + + new_state = MarsState( + count=count_inc, + mu=mu, + nu=nu, + prev_grad=g_t, + c_prev=c_t, + ) + return updates_out, new_state + + return base.GradientTransformation(init_fn, update_fn) + + +def mars( + learning_rate: base.ScalarOrSchedule, + gamma: float = 0.025, + b1: float = 0.9, + b2: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 1e-4, + mask: Optional[Union[Any, base.PyTree]] = None, + mu_dtype: Optional[Any] = None, + *, + correction_clip: Optional[float] = None, + nesterov: bool = False, +) -> base.GradientTransformationExtraArgs: + r"""MARS: variance-reduction AdamW for training large models. + + MARS replaces the raw gradient in Adam with a STORM-style corrected gradient + :math:`c_t` that reduces variance across consecutive steps: + + .. math:: + + c_t = g_t + (1 - \gamma)(c_{t-1} - g_{t-1}), + + then applies AdamW-style updates: + + .. math:: + + \begin{align*} + m_t &= \beta_1 m_{t-1} + (1-\beta_1) c_t, \\ + v_t &= \beta_2 v_{t-1} + (1-\beta_2) c_t^2, \\ + \hat{m}_t &= m_t / (1 - \beta_1^t), \\ + \hat{v}_t &= v_t / (1 - \beta_2^t), \\ + \theta_t &= \theta_{t-1} + - \alpha \bigl(\hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) + + \lambda \theta_{t-1}\bigr). + \end{align*} + + MARS achieves the convergence rate of SGD-with-momentum while retaining + Adam's per-coordinate adaptivity. In large-scale LLM pre-training experiments + the authors report consistent improvements over AdamW. + + Args: + learning_rate: Global step size, either a scalar or a schedule. + gamma: Variance-reduction coefficient :math:`\gamma \in (0, 1]`. + ``gamma=1`` recovers AdamW exactly. The paper recommends ``0.025`` + for LLM pre-training; larger values (``0.5``–``1.0``) are safer for + fine-tuning where gradients are smoother. + b1: Exponential decay rate for the first moment. + b2: Exponential decay rate for the second moment. + eps: Small constant for numerical stability. + weight_decay: AdamW-style decoupled weight decay coefficient. + mask: A tree with the same structure as (or a prefix of) the params + pytree, or a callable that returns such a tree given the params/updates. + The leaves should be booleans; ``True`` leaves apply weight decay, + ``False`` leaves skip it. + mu_dtype: Optional dtype for the first-moment buffer. + correction_clip: If set, the correction term :math:`(1-\gamma)(c_{t-1} - + g_{t-1})` is clipped by this global norm before being added to + :math:`g_t`. Improves stability in early training (Section 3.2). + nesterov: Whether to use Nesterov momentum. + + Returns: + A :class:`optax.GradientTransformationExtraArgs`. + + Examples: + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def f(x): return jnp.sum(x ** 2) + >>> solver = optax.contrib.mars(learning_rate=1e-3) + >>> params = jnp.array([1., 2., 3.]) + >>> print('Objective function: ', f(params)) + Objective function: 14.0 + >>> opt_state = solver.init(params) + >>> for _ in range(5): + ... grad = jax.grad(f)(params) + ... updates, opt_state = solver.update(grad, opt_state, params) + ... params = optax.apply_updates(params, updates) + >>> print('Objective function: {:.2f}'.format(f(params))) + Objective function: 13.97 + + References: + Hu et al., `MARS: Unleashing the Power of Variance Reduction for Training + Large Models `_, 2024. + """ + return combine.chain( + scale_by_mars( + gamma=gamma, + b1=b1, + b2=b2, + eps=eps, + mu_dtype=mu_dtype, + correction_clip=correction_clip, + nesterov=nesterov, + ), + transform.add_decayed_weights(weight_decay, mask=mask), + transform.scale_by_learning_rate(learning_rate), + ) diff --git a/optax/contrib/_mars_test.py b/optax/contrib/_mars_test.py new file mode 100644 index 000000000..3bc8f5319 --- /dev/null +++ b/optax/contrib/_mars_test.py @@ -0,0 +1,213 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MARS optimizer.""" + +import statistics + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import optax +from optax.contrib._mars import mars +from optax.contrib._mars import MarsState +from optax.contrib._mars import scale_by_mars + + +class ScaleByMarsTest(absltest.TestCase): + + def test_state_structure(self): + """MarsState has the expected fields after init.""" + params = jnp.ones((3,)) + tx = scale_by_mars() + state = tx.init(params) + self.assertIsInstance(state, MarsState) + self.assertEqual(state.count, 0) + self.assertEqual(state.mu.shape, params.shape) + self.assertEqual(state.nu.shape, params.shape) + self.assertEqual(state.prev_grad.shape, params.shape) + self.assertEqual(state.c_prev.shape, params.shape) + + def test_first_step_no_correction(self): + """At step 0 the correction term must be zero (c_1 = g_1).""" + params = jnp.ones((4,)) + tx = scale_by_mars(gamma=0.025) + state = tx.init(params) + grad = jnp.array([1.0, 2.0, -1.0, 0.5]) + _, new_state = tx.update(grad, state) + # c_prev should equal the gradient on the first step. + self.assertTrue(jnp.allclose(new_state.c_prev, grad)) + self.assertTrue(jnp.allclose(new_state.prev_grad, grad)) + + def test_gamma_one_recovers_adam_moments(self): + """With gamma=1 the correction vanishes and MARS reduces to Adam.""" + params = jnp.ones((3,)) + mars_tx = scale_by_mars(gamma=1.0, b1=0.9, b2=0.999, eps=1e-8) + adam_tx = optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8) + + mars_state = mars_tx.init(params) + adam_state = adam_tx.init(params) + + key = jax.random.PRNGKey(0) + for _ in range(5): + key, subkey = jax.random.split(key) + grad = jax.random.normal(subkey, params.shape) + mars_updates, mars_state = mars_tx.update(grad, mars_state) + adam_updates, adam_state = adam_tx.update(grad, adam_state) + + self.assertTrue( + jnp.allclose(mars_updates, adam_updates, atol=1e-6), + msg='MARS with gamma=1 should match Adam updates.', + ) + + def test_correction_reduces_moment_variance(self): + """Corrected gradient c_t should track the true gradient more smoothly.""" + params = jnp.ones((8,)) + tx = scale_by_mars(gamma=0.025) + state = tx.init(params) + + # Feed a noisy gradient sequence. + key = jax.random.PRNGKey(42) + c_norms = [] + g_norms = [] + for _ in range(20): + key, subkey = jax.random.split(key) + grad = jax.random.normal(subkey, params.shape) * 10.0 + _, state = tx.update(grad, state) + c_norms.append(float(jnp.linalg.norm(state.c_prev))) + g_norms.append(float(jnp.linalg.norm(grad))) + + # Corrected gradient norms should have lower std than raw gradient norms + # after the first couple of steps. + self.assertLess( + statistics.stdev(c_norms[3:]), + statistics.stdev(g_norms[3:]) + 1.0, # generous tolerance + msg='Corrected gradient should be smoother than raw gradient.', + ) + + def test_correction_clip(self): + """correction_clip should prevent exploding corrections.""" + params = jnp.ones((16,)) + tx = scale_by_mars(gamma=0.025, correction_clip=1.0) + state = tx.init(params) + # Prime the state with a large gradient. + large_grad = jnp.ones((16,)) * 1000.0 + _, state = tx.update(large_grad, state) + # Now send a very different gradient. + small_grad = jnp.zeros((16,)) + updates, _ = tx.update(small_grad, state) + # Update should be finite and bounded. + self.assertTrue(jnp.all(jnp.isfinite(updates))) + + def test_nesterov_flag(self): + """nesterov=True should produce different updates from nesterov=False.""" + params = jnp.ones((4,)) + tx_nes = scale_by_mars(gamma=0.5, nesterov=True) + tx_std = scale_by_mars(gamma=0.5, nesterov=False) + state_nes = tx_nes.init(params) + state_std = tx_std.init(params) + + key = jax.random.PRNGKey(7) + for _ in range(3): + key, subkey = jax.random.split(key) + grad = jax.random.normal(subkey, params.shape) + upd_nes, state_nes = tx_nes.update(grad, state_nes) + upd_std, state_std = tx_std.update(grad, state_std) + + # After several steps nesterov and standard updates should differ. + self.assertFalse( + jnp.allclose(upd_nes, upd_std), + msg='Nesterov and standard updates should differ.', + ) + + +class MarsOptimizerTest(parameterized.TestCase): + + @parameterized.parameters( + {'gamma': 1.0}, # Adam limit + {'gamma': 0.5}, # moderate correction + {'gamma': 0.025}, # paper default + ) + def test_descends_quadratic(self, gamma): + """mars() should reduce a simple quadratic objective.""" + params = jnp.array([3.0, -2.0, 1.0]) + solver = mars(learning_rate=1e-2, gamma=gamma, weight_decay=0.0) + state = solver.init(params) + + def loss(p): + return jnp.sum(p ** 2) + + initial_loss = loss(params) + for _ in range(50): + grad = jax.grad(loss)(params) + updates, state = solver.update(grad, state, params) + params = optax.apply_updates(params, updates) + + self.assertLess( + loss(params), + initial_loss, + msg=f'MARS (gamma={gamma}) should reduce the quadratic objective.', + ) + + def test_weight_decay_applied(self): + """weight_decay > 0 should shrink parameters over time.""" + params = jnp.ones((4,)) * 5.0 + solver_wd = mars(learning_rate=1e-3, weight_decay=0.1) + solver_no = mars(learning_rate=1e-3, weight_decay=0.0) + state_wd = solver_wd.init(params) + state_no = solver_no.init(params) + + zero_grad = jnp.zeros_like(params) + for _ in range(10): + upd_wd, state_wd = solver_wd.update(zero_grad, state_wd, params) + upd_no, state_no = solver_no.update(zero_grad, state_no, params) + params_wd = optax.apply_updates(params, upd_wd) + + # Parameters with weight decay should have smaller norm. + params_no = optax.apply_updates(params, upd_no) + self.assertLess( + jnp.linalg.norm(params_wd), + jnp.linalg.norm(params_no), + ) + + def test_correction_clip_stability(self): + """correction_clip should not cause NaNs even with very spiky gradients.""" + params = jnp.ones((8,)) + solver = mars(learning_rate=1e-3, correction_clip=0.1) + state = solver.init(params) + key = jax.random.PRNGKey(0) + for _ in range(30): + key, subkey = jax.random.split(key) + # Alternate between huge and tiny gradients to stress the correction. + grad = jax.random.normal(subkey, params.shape) * 1e3 + updates, state = solver.update(grad, state, params) + params = optax.apply_updates(params, updates) + self.assertTrue(jnp.all(jnp.isfinite(params))) + + def test_pytree_params(self): + """mars() should work with pytree (dict) parameters.""" + params = {'w': jnp.ones((3,)), 'b': jnp.zeros((2,))} + solver = mars(learning_rate=1e-3) + state = solver.init(params) + grads = jax.tree.map(jnp.ones_like, params) + updates, state = solver.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + jax.tree.map( + lambda p: self.assertTrue(jnp.all(jnp.isfinite(p))), new_params + ) + + +if __name__ == '__main__': + absltest.main()