diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 4b6ff23d2..7c8da43a2 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -826,9 +826,9 @@ def adamw( def adan( learning_rate: base.ScalarOrSchedule, - b1: jax.typing.ArrayLike = 0.98, - b2: jax.typing.ArrayLike = 0.92, - b3: jax.typing.ArrayLike = 0.99, + b1: jax.typing.ArrayLike = 0.02, + b2: jax.typing.ArrayLike = 0.08, + b3: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-8, eps_root: jax.typing.ArrayLike = 1e-8, weight_decay: base.ScalarOrSchedule = 0.0, @@ -921,8 +921,8 @@ def adan( Objective function: 1.28E+01 Objective function: 1.17E+01 Objective function: 1.07E+01 - Objective function: 9.68E+00 - Objective function: 8.76E+00 + Objective function: 9.69E+00 + Objective function: 8.77E+00 References: Xie et al, `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing diff --git a/optax/_src/transform.py b/optax/_src/transform.py index ec2e8a2b5..0d46aba76 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -621,9 +621,9 @@ class ScaleByAdanState(NamedTuple): def scale_by_adan( - b1: jax.typing.ArrayLike = 0.98, - b2: jax.typing.ArrayLike = 0.92, - b3: jax.typing.ArrayLike = 0.99, + b1: jax.typing.ArrayLike = 0.02, + b2: jax.typing.ArrayLike = 0.08, + b3: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-8, eps_root: jax.typing.ArrayLike = 0.0, ) -> base.GradientTransformation: @@ -662,16 +662,16 @@ def update_fn(updates, state, params=None): optax.tree.zeros_like(g), optax.tree.sub(g, state.g), ) - m = optax.tree.update_moment(g, state.m, b1, 1) - v = optax.tree.update_moment(diff, state.v, b2, 1) + m = optax.tree.update_moment(g, state.m, 1 - b1, 1) + v = optax.tree.update_moment(diff, state.v, 1 - b2, 1) sq = optax.tree.add_scale(g, 1 - b2, diff) - n = optax.tree.update_moment_per_elem_norm(sq, state.n, b3, 2) + n = optax.tree.update_moment_per_elem_norm(sq, state.n, 1 - b3, 2) t = numerics.safe_increment(state.t) - m_hat = optax.tree.bias_correction(m, b1, t) - v_hat = optax.tree.bias_correction(v, b2, t) - n_hat = optax.tree.bias_correction(n, b3, t) + m_hat = optax.tree.bias_correction(m, 1 - b1, t) + v_hat = optax.tree.bias_correction(v, 1 - b2, t) + n_hat = optax.tree.bias_correction(n, 1 - b3, t) u = optax.tree.add_scale(m_hat, 1 - b2, v_hat) denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index f521045ff..00340f4bf 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -73,6 +73,39 @@ def test_scalers(self, scaler_constr): test_utils.assert_tree_all_finite((params, updates, state)) test_utils.assert_trees_all_equal_shapes(params, updates) + def test_adan_matches_documented_update(self): + # Reference implementation of the equations documented in `optax.adan`. + b1, b2, b3 = 0.02, 0.08, 0.01 + eps, eps_root = 1e-8, 0.0 + grads = [ + jnp.array([0.5, -1.5, 2.0]), + jnp.array([0.3, -1.0, 2.5]), + jnp.array([-0.2, 0.7, 1.0]), + ] + + tx = transform.scale_by_adan( + b1=b1, b2=b2, b3=b3, eps=eps, eps_root=eps_root + ) + state = tx.init(grads[0]) + + m = jnp.zeros(3) + v = jnp.zeros(3) + n = jnp.zeros(3) + g_prev = jnp.zeros(3) + for step, g in enumerate(grads, start=1): + diff = jnp.zeros(3) if step == 1 else g - g_prev + m = (1 - b1) * m + b1 * g + v = (1 - b2) * v + b2 * diff + n = (1 - b3) * n + b3 * (g + (1 - b2) * diff) ** 2 + m_hat = m / (1 - (1 - b1) ** step) + v_hat = v / (1 - (1 - b2) ** step) + n_hat = n / (1 - (1 - b3) ** step) + expected = (m_hat + (1 - b2) * v_hat) / (jnp.sqrt(n_hat + eps_root) + eps) + g_prev = g + + updates, state = tx.update(g, state, None) + test_utils.assert_trees_all_close(updates, expected, atol=1e-4, rtol=1e-4) + def test_apply_every(self): # The frequency of the application of sgd k = 4