Skip to content

Fix incorrect (1 - b2) coefficient in scale_by_adan#1715

Open
codewithfourtix wants to merge 4 commits into
google-deepmind:mainfrom
codewithfourtix:fix-adan-beta2-coefficient
Open

Fix incorrect (1 - b2) coefficient in scale_by_adan#1715
codewithfourtix wants to merge 4 commits into
google-deepmind:mainfrom
codewithfourtix:fix-adan-beta2-coefficient

Conversation

@codewithfourtix

Copy link
Copy Markdown

Summary

scale_by_adan scales the gradient-difference term by 1 - b2 in two
places, but the algorithm documented in optax.adan scales it by
(1 - beta_2), which equals b2 in optax's parameterization. With the
default b2=0.92 the optimizer applied a weight of 0.08 instead of
0.92, diverging from the documented equations from the second step onward.

Details

optax parameterizes Adan's moment EMAs via update_moment(..., decay=b2),
i.e. v_t = b2 * v_{t-1} + (1 - b2) * (g_t - g_{t-1}). Matching this to the
documented equation v_t = (1 - beta_2) v_{t-1} + beta_2 (g_t - g_{t-1})
forces beta_2 = 1 - b2. Therefore the documented coefficient (1 - beta_2)
that appears in:

  • the squared accumulator n_t = ... (g_t + (1 - beta_2)(g_t - g_{t-1}))^2
  • the update u_t = m_t + (1 - beta_2) v_t

equals b2, not 1 - b2. This also matches Algorithm 1 of
Xie et al., 2022 (https://arxiv.org/abs/2208.06677), where the term is
(1 - beta_2).

The first step is unaffected (the gradient difference is zero), which is why
the existing "loss decreases" tests did not catch this.

Changes

  • scale_by_adan: use b2 instead of 1 - b2 for the gradient-difference
    term in both the n accumulator and the update direction.
  • Add test_adan_matches_documented_update, which checks scale_by_adan
    against an independent reference implementation of the documented equations
    over several steps.

Compatibility

This changes the numerical behavior of optax.adan to match its documented
algorithm and the reference paper. Existing code using optax.adan will now
follow the intended Adan update.

The Adan update documented in `optax.adan` scales the gradient-difference
term in both the squared accumulator `n` and the update direction by
(1 - beta_2). optax parameterizes the moment EMAs via `update_moment` with
decay `b2`, so b2 == (1 - beta_2); the two `add_scale` calls must therefore
use `b2`, not `1 - b2`. With the default b2=0.92 the previous code applied a
weight of 0.08 instead of 0.92, diverging from both the documented equations
and the reference algorithm (Xie et al., 2022) from the second step onward.

Add a test comparing scale_by_adan against an independent reference
implementation of the documented equations over several steps.
Fixes the ruff E501 / pyink pre-commit failure on the line exceeding the
80-character limit.

@rdyro rdyro left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, you seem to have caught a genuine problem with the optax implementation!

Comment thread optax/_src/transform.py
m = optax.tree.update_moment(g, state.m, b1, 1)
v = optax.tree.update_moment(diff, state.v, b2, 1)

sq = optax.tree.add_scale(g, 1 - b2, diff)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer optax stick closer to the math and we keep 1 - b2, but update the constructor defaults.

Comment thread optax/_src/transform_test.py Outdated
# Reference implementation of the equations documented in `optax.adan`,
# where the gradient-difference term in both `n` and the update is scaled
# by (1 - beta_2), which equals `b2` in optax's (1 - beta) parameterization.
b1, b2, b3 = 0.98, 0.92, 0.99

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above, use 0.02, etc.

@rdyro

rdyro commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator

You'll need to update the docstring in adan docstring in alias.py too:

Document: api/generated/optax.adan
----------------------------------
**********************************************************************
File "../optax/_src/alias.py", line ?, in default
Failed example:
    for _ in range(5):
     grad = jax.grad(f)(params)
     updates, opt_state = solver.update(grad, opt_state, params)
     params = optax.apply_updates(params, updates)
     print('Objective function: {:.2E}'.format(f(params)))
Expected:
    Objective function: 1.28E+01
    Objective function: 1.17E+01
    Objective function: 1.07E+01
    Objective function: 9.68E+00
    Objective function: 8.76E+00
Got:
    Objective function: 1.28E+01
    Objective function: 1.17E+01
    Objective function: 1.07E+01
    Objective function: 9.69E+00
    Objective function: 8.77E+00

The coefficient fix slightly changes the trajectory printed by the runnable
example in the `optax.adan` docstring. Update the last two expected objective
values (9.68E+00 -> 9.69E+00, 8.76E+00 -> 8.77E+00) to match, fixing the
doctest. Reported by @rdyro in review.
Keep the documented (1 - b2) form in the update and the squared term, and
instead set the constructor defaults to the paper's beta values
(b1=0.02, b2=0.08, b3=0.01), flipping the moment and bias-correction decays
to (1 - b_i) so the EMAs match the documented equations. The optimizer
behavior is unchanged; only the parameterization exposed to users matches
the paper. Update the test to use the new defaults.

Addresses @rdyro's review feedback.
@codewithfourtix

Copy link
Copy Markdown
Author

@rdyro Done, kept the 1 - b2 form and switched the defaults to the paper's betas (b1=0.02, b2=0.08, b3=0.01), flipping the moment/bias-correction decays to 1 - b_i so the EMAs stay consistent. Behavior is unchanged, the parameterization now matches the paper. Updated the test to the new defaults and the docstring example. Let me know if you'd like anything else.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants