Fix incorrect (1 - b2) coefficient in scale_by_adan #1715
codewithfourtix wants to merge 4 commits into
The Adan update documented in `optax.adan` scales the gradient-difference term in both the squared accumulator `n` and the update direction by (1 - beta_2). optax parameterizes the moment EMAs via `update_moment` with decay `b2`, so b2 == (1 - beta_2); the two `add_scale` calls must therefore use `b2`, not `1 - b2`. With the default b2=0.92 the previous code applied a weight of 0.08 instead of 0.92, diverging from both the documented equations and the reference algorithm (Xie et al., 2022) from the second step onward. Add a test comparing scale_by_adan against an independent reference implementation of the documented equations over several steps.
Fixes the ruff E501 / pyink pre-commit failure on the line exceeding the 80-character limit.
rdyro left a comment
Thanks, you seem to have caught a genuine problem with the optax implementation!
m = optax.tree.update_moment(g, state.m, b1, 1)
v = optax.tree.update_moment(diff, state.v, b2, 1)

sq = optax.tree.add_scale(g, 1 - b2, diff)
I'd prefer optax stick closer to the math and we keep 1 - b2, but update the constructor defaults.
# Reference implementation of the equations documented in `optax.adan`,
# where the gradient-difference term in both `n` and the update is scaled
# by (1 - beta_2), which equals `b2` in optax's (1 - beta) parameterization.
b1, b2, b3 = 0.98, 0.92, 0.99
See comment above, use 0.02, etc.
You'll need to update the adan docstring in alias.py too:
The coefficient fix slightly changes the trajectory printed by the runnable example in the `optax.adan` docstring. Update the last two expected objective values (9.68E+00 -> 9.69E+00, 8.76E+00 -> 8.77E+00) to match, fixing the doctest. Reported by @rdyro in review.
Keep the documented (1 - b2) form in the update and the squared term, and instead set the constructor defaults to the paper's beta values (b1=0.02, b2=0.08, b3=0.01), flipping the moment and bias-correction decays to (1 - b_i) so the EMAs match the documented equations. The optimizer behavior is unchanged; only the parameterization exposed to users changes, so that it matches the paper. Update the test to use the new defaults. Addresses @rdyro's review feedback.
@rdyro Done, kept the 1 - b2 form and switched the defaults to the paper's betas (b1=0.02, b2=0.08, b3=0.01), flipping the moment/bias-correction decays to 1 - b_i so the EMAs stay consistent. Behavior is unchanged, the parameterization now matches the paper. Updated the test to the new defaults and the docstring example. Let me know if you'd like anything else.
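The claim that flipping the decays leaves behavior unchanged can be checked with a small pure-Python sketch. `ema` is a hypothetical scalar stand-in for `optax.tree.update_moment` with order 1, assumed here to compute `decay * moment + (1 - decay) * update`; the input values are made up for illustration and this is not the code from the PR.

```python
def ema(update, moment, decay):
    """Scalar stand-in (an assumption) for update_moment with order=1."""
    return decay * moment + (1 - decay) * update


def run(updates, decay):
    """Fold a sequence of updates into a single moment, starting from zero."""
    moment = 0.0
    for u in updates:
        moment = ema(u, moment, decay)
    return moment


updates = [1.0, 0.5, 0.25, -0.1]  # illustrative values only

# Old parameterization: the constructor argument is used directly as the decay.
old_b2 = 0.92
old_moment = run(updates, decay=old_b2)

# New parameterization: the argument is the paper's beta; the decay is 1 - beta.
new_b2 = 0.08
new_moment = run(updates, decay=1 - new_b2)

# Weight on the gradient-difference term under each convention.
old_coef = old_b2      # `b2` form from the first version of this PR
new_coef = 1 - new_b2  # `1 - b2` form kept after review

print(old_moment, new_moment)  # agree up to floating-point rounding
print(old_coef, new_coef)
```

Under these assumptions the two conventions describe the same EMA and the same weight on the gradient-difference term; only the meaning of the user-facing argument differs.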
Summary
`scale_by_adan` scales the gradient-difference term by `1 - b2` in two places, but the algorithm documented in `optax.adan` scales it by `(1 - beta_2)`, which equals `b2` in optax's parameterization. With the default `b2=0.92` the optimizer applied a weight of `0.08` instead of `0.92`, diverging from the documented equations from the second step onward.
Details
optax parameterizes Adan's moment EMAs via `update_moment(..., decay=b2)`, i.e. `v_t = b2 * v_{t-1} + (1 - b2) * (g_t - g_{t-1})`. Matching this to the documented equation `v_t = (1 - beta_2) v_{t-1} + beta_2 (g_t - g_{t-1})` forces `beta_2 = 1 - b2`. Therefore the documented coefficient `(1 - beta_2)` that appears in:
n_t = ... (g_t + (1 - beta_2)(g_t - g_{t-1}))^2
u_t = m_t + (1 - beta_2) v_t
equals `b2`, not `1 - b2`. This also matches Algorithm 1 of Xie et al., 2022 (https://arxiv.org/abs/2208.06677), where the term is `(1 - beta_2)`.
The first step is unaffected (the gradient difference is zero), which is why the existing "loss decreases" tests did not catch this.
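To illustrate why the first step hides the discrepancy, here is a scalar sketch of the documented recurrences. It is not the optax implementation: it omits bias correction, the epsilon terms, and the final normalization, it assumes `m` and `n` follow the same EMA form as the documented `v_t` equation, and the gradients are made up. `adan_trace` and `diff_coef` are hypothetical names.

```python
def adan_trace(grads, diff_coef, beta1=0.02, beta2=0.08, beta3=0.01):
    """Return (u_t, n_t) per step for a simplified scalar Adan recurrence.

    diff_coef is the weight on the gradient-difference term; the documented
    value is (1 - beta2). Bias correction and normalization are omitted.
    """
    m = v = n = 0.0
    g_prev = None
    trace = []
    for g in grads:
        # On the first step there is no previous gradient, so the difference is zero.
        diff = 0.0 if g_prev is None else g - g_prev
        m = (1 - beta1) * m + beta1 * g
        v = (1 - beta2) * v + beta2 * diff
        n = (1 - beta3) * n + beta3 * (g + diff_coef * diff) ** 2
        trace.append((m + diff_coef * v, n))
        g_prev = g
    return trace


grads = [1.0, 0.5, 0.25]  # illustrative gradients only

documented = adan_trace(grads, diff_coef=1 - 0.08)  # weight 0.92
previous = adan_trace(grads, diff_coef=0.08)        # weight applied before the fix

print(documented[0] == previous[0])  # first step: identical
print(documented[1] == previous[1])  # second step: the traces diverge
```

A test that only checks the first step, or only checks that the loss decreases, cannot distinguish the two coefficients in this sketch; comparing several steps against a reference does.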
Changes
- `scale_by_adan`: use `b2` instead of `1 - b2` for the gradient-difference term in both the `n` accumulator and the update direction.
- Add `test_adan_matches_documented_update`, which checks `scale_by_adan` against an independent reference implementation of the documented equations over several steps.
Compatibility
This changes the numerical behavior of `optax.adan` to match its documented algorithm and the reference paper. Existing code using `optax.adan` will now follow the intended Adan update.