From 24fab282d752abfb482ece8e3bf084a8233652ea Mon Sep 17 00:00:00 2001 From: Ali Zulfiqar Date: Fri, 26 Jun 2026 23:59:12 +0500 Subject: [PATCH 1/4] Fix incorrect (1 - b2) coefficient in scale_by_adan The Adan update documented in `optax.adan` scales the gradient-difference term in both the squared accumulator `n` and the update direction by (1 - beta_2). optax parameterizes the moment EMAs via `update_moment` with decay `b2`, so b2 == (1 - beta_2); the two `add_scale` calls must therefore use `b2`, not `1 - b2`. With the default b2=0.92 the previous code applied a weight of 0.08 instead of 0.92, diverging from both the documented equations and the reference algorithm (Xie et al., 2022) from the second step onward. Add a test comparing scale_by_adan against an independent reference implementation of the documented equations over several steps. --- optax/_src/transform.py | 4 ++-- optax/_src/transform_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index ec2e8a2b5..082cb5938 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -665,7 +665,7 @@ def update_fn(updates, state, params=None): m = optax.tree.update_moment(g, state.m, b1, 1) v = optax.tree.update_moment(diff, state.v, b2, 1) - sq = optax.tree.add_scale(g, 1 - b2, diff) + sq = optax.tree.add_scale(g, b2, diff) n = optax.tree.update_moment_per_elem_norm(sq, state.n, b3, 2) t = numerics.safe_increment(state.t) @@ -673,7 +673,7 @@ def update_fn(updates, state, params=None): v_hat = optax.tree.bias_correction(v, b2, t) n_hat = optax.tree.bias_correction(n, b3, t) - u = optax.tree.add_scale(m_hat, 1 - b2, v_hat) + u = optax.tree.add_scale(m_hat, b2, v_hat) denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat) u = optax.tree.div(u, denom) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index f521045ff..084d3e751 100644 --- 
a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -73,6 +73,39 @@ def test_scalers(self, scaler_constr): test_utils.assert_tree_all_finite((params, updates, state)) test_utils.assert_trees_all_equal_shapes(params, updates) + def test_adan_matches_documented_update(self): + # Reference implementation of the equations documented in `optax.adan`, + # where the gradient-difference term in both `n` and the update is scaled + # by (1 - beta_2), which equals `b2` in optax's (1 - beta) parameterization. + b1, b2, b3 = 0.98, 0.92, 0.99 + eps, eps_root = 1e-8, 0.0 + grads = [ + jnp.array([0.5, -1.5, 2.0]), + jnp.array([0.3, -1.0, 2.5]), + jnp.array([-0.2, 0.7, 1.0]), + ] + + tx = transform.scale_by_adan(b1=b1, b2=b2, b3=b3, eps=eps, eps_root=eps_root) + state = tx.init(grads[0]) + + m = jnp.zeros(3) + v = jnp.zeros(3) + n = jnp.zeros(3) + g_prev = jnp.zeros(3) + for step, g in enumerate(grads, start=1): + diff = jnp.zeros(3) if step == 1 else g - g_prev + m = b1 * m + (1 - b1) * g + v = b2 * v + (1 - b2) * diff + n = b3 * n + (1 - b3) * (g + b2 * diff) ** 2 + m_hat = m / (1 - b1**step) + v_hat = v / (1 - b2**step) + n_hat = n / (1 - b3**step) + expected = (m_hat + b2 * v_hat) / (jnp.sqrt(n_hat + eps_root) + eps) + g_prev = g + + updates, state = tx.update(g, state, None) + test_utils.assert_trees_all_close(updates, expected, atol=1e-4, rtol=1e-4) + def test_apply_every(self): # The frequency of the application of sgd k = 4 From 1fd634046adf99e522d9e896c55343fa3144e6be Mon Sep 17 00:00:00 2001 From: Ali Zulfiqar Date: Sat, 27 Jun 2026 00:12:53 +0500 Subject: [PATCH 2/4] Wrap scale_by_adan call in test to satisfy 80-col lint Fixes the ruff E501 / pyink pre-commit failure on the line exceeding the 80-character limit. 
--- optax/_src/transform_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index 084d3e751..f3261478d 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -85,7 +85,9 @@ def test_adan_matches_documented_update(self): jnp.array([-0.2, 0.7, 1.0]), ] - tx = transform.scale_by_adan(b1=b1, b2=b2, b3=b3, eps=eps, eps_root=eps_root) + tx = transform.scale_by_adan( + b1=b1, b2=b2, b3=b3, eps=eps, eps_root=eps_root + ) state = tx.init(grads[0]) m = jnp.zeros(3) From eed58f1982c0ef67f6b4e93d8fd34fc15781d12a Mon Sep 17 00:00:00 2001 From: Ali Zulfiqar Date: Sat, 27 Jun 2026 10:49:16 +0500 Subject: [PATCH 3/4] Update adan docstring example for corrected update The coefficient fix slightly changes the trajectory printed by the runnable example in the `optax.adan` docstring. Update the last two expected objective values (9.68E+00 -> 9.69E+00, 8.76E+00 -> 8.77E+00) to match, fixing the doctest. Reported by @rdyro in review. 
--- optax/_src/alias.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 4b6ff23d2..cf8eb97bc 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -921,8 +921,8 @@ def adan( Objective function: 1.28E+01 Objective function: 1.17E+01 Objective function: 1.07E+01 - Objective function: 9.68E+00 - Objective function: 8.76E+00 + Objective function: 9.69E+00 + Objective function: 8.77E+00 References: Xie et al, `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing From e0bea072bf9523d8b8b488e0406736e457c140c0 Mon Sep 17 00:00:00 2001 From: Ali Zulfiqar Date: Sat, 27 Jun 2026 11:07:58 +0500 Subject: [PATCH 4/4] Adopt paper beta convention for adan per review Keep the documented (1 - b2) form in the update and the squared term, and instead set the constructor defaults to the paper's beta values (b1=0.02, b2=0.08, b3=0.01), flipping the moment and bias-correction decays to (1 - b_i) so the EMAs match the documented equations. Optimizer behavior with the default arguments is unchanged; only the parameterization exposed to users changes to match the paper. Note that callers passing b1, b2 or b3 explicitly in the old convention (e.g. b1=0.98) must now pass 1 - value (e.g. b1=0.02), since the moment and bias-correction decays are now computed as (1 - b_i). Update the test to use the new defaults. Addresses @rdyro's review feedback. 
--- optax/_src/alias.py | 6 +++--- optax/_src/transform.py | 22 +++++++++++----------- optax/_src/transform_test.py | 20 +++++++++----------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index cf8eb97bc..7c8da43a2 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -826,9 +826,9 @@ def adamw( def adan( learning_rate: base.ScalarOrSchedule, - b1: jax.typing.ArrayLike = 0.98, - b2: jax.typing.ArrayLike = 0.92, - b3: jax.typing.ArrayLike = 0.99, + b1: jax.typing.ArrayLike = 0.02, + b2: jax.typing.ArrayLike = 0.08, + b3: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-8, eps_root: jax.typing.ArrayLike = 1e-8, weight_decay: base.ScalarOrSchedule = 0.0, diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 082cb5938..0d46aba76 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -621,9 +621,9 @@ class ScaleByAdanState(NamedTuple): def scale_by_adan( - b1: jax.typing.ArrayLike = 0.98, - b2: jax.typing.ArrayLike = 0.92, - b3: jax.typing.ArrayLike = 0.99, + b1: jax.typing.ArrayLike = 0.02, + b2: jax.typing.ArrayLike = 0.08, + b3: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-8, eps_root: jax.typing.ArrayLike = 0.0, ) -> base.GradientTransformation: @@ -662,18 +662,18 @@ def update_fn(updates, state, params=None): optax.tree.zeros_like(g), optax.tree.sub(g, state.g), ) - m = optax.tree.update_moment(g, state.m, b1, 1) - v = optax.tree.update_moment(diff, state.v, b2, 1) + m = optax.tree.update_moment(g, state.m, 1 - b1, 1) + v = optax.tree.update_moment(diff, state.v, 1 - b2, 1) - sq = optax.tree.add_scale(g, b2, diff) - n = optax.tree.update_moment_per_elem_norm(sq, state.n, b3, 2) + sq = optax.tree.add_scale(g, 1 - b2, diff) + n = optax.tree.update_moment_per_elem_norm(sq, state.n, 1 - b3, 2) t = numerics.safe_increment(state.t) - m_hat = optax.tree.bias_correction(m, b1, t) - v_hat = optax.tree.bias_correction(v, b2, t) - n_hat = 
optax.tree.bias_correction(n, b3, t) + m_hat = optax.tree.bias_correction(m, 1 - b1, t) + v_hat = optax.tree.bias_correction(v, 1 - b2, t) + n_hat = optax.tree.bias_correction(n, 1 - b3, t) - u = optax.tree.add_scale(m_hat, b2, v_hat) + u = optax.tree.add_scale(m_hat, 1 - b2, v_hat) denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat) u = optax.tree.div(u, denom) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index f3261478d..00340f4bf 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -74,10 +74,8 @@ def test_scalers(self, scaler_constr): test_utils.assert_trees_all_equal_shapes(params, updates) def test_adan_matches_documented_update(self): - # Reference implementation of the equations documented in `optax.adan`, - # where the gradient-difference term in both `n` and the update is scaled - # by (1 - beta_2), which equals `b2` in optax's (1 - beta) parameterization. - b1, b2, b3 = 0.98, 0.92, 0.99 + # Reference implementation of the equations documented in `optax.adan`. + b1, b2, b3 = 0.02, 0.08, 0.01 eps, eps_root = 1e-8, 0.0 grads = [ jnp.array([0.5, -1.5, 2.0]), @@ -96,13 +94,13 @@ def test_adan_matches_documented_update(self): g_prev = jnp.zeros(3) for step, g in enumerate(grads, start=1): diff = jnp.zeros(3) if step == 1 else g - g_prev - m = b1 * m + (1 - b1) * g - v = b2 * v + (1 - b2) * diff - n = b3 * n + (1 - b3) * (g + b2 * diff) ** 2 - m_hat = m / (1 - b1**step) - v_hat = v / (1 - b2**step) - n_hat = n / (1 - b3**step) - expected = (m_hat + b2 * v_hat) / (jnp.sqrt(n_hat + eps_root) + eps) + m = (1 - b1) * m + b1 * g + v = (1 - b2) * v + b2 * diff + n = (1 - b3) * n + b3 * (g + (1 - b2) * diff) ** 2 + m_hat = m / (1 - (1 - b1) ** step) + v_hat = v / (1 - (1 - b2) ** step) + n_hat = n / (1 - (1 - b3) ** step) + expected = (m_hat + (1 - b2) * v_hat) / (jnp.sqrt(n_hat + eps_root) + eps) g_prev = g updates, state = tx.update(g, state, None)