Fix ACProp second moment to use the new first moment #1716

Open
codewithfourtix wants to merge 1 commit into google-deepmind:main from codewithfourtix:fix-acprop-residual-moment

@codewithfourtix

Summary

scale_by_acprop computes its second-moment residual against state.mu
(the previous-step first moment m_{t-1}) instead of the freshly-updated
mu (m_t), contradicting the documented update rule.

Details

The optax.contrib.acprop docstring documents the second moment as:

s_t = b2 * s_{t-1} + (1 - b2) * (g_t - m_t) ** 2 + eps_root

where m_t is the first moment after the step-t update. The code already
computes that new moment on the line above (mu = update_moment(...)), but the
residual is then computed as g - state.mu, i.e. g_t - m_{t-1}.
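The size of the discrepancy follows directly from the first-moment update. The derivation below is a sketch that assumes the first moment is the plain exponential moving average m_t = b1 * m_{t-1} + (1 - b1) * g_t with no bias correction applied before the residual is taken:

```latex
\begin{aligned}
m_t &= b_1\, m_{t-1} + (1 - b_1)\, g_t \\
g_t - m_t &= g_t - b_1\, m_{t-1} - (1 - b_1)\, g_t = b_1\,(g_t - m_{t-1}) \\
(g_t - m_t)^2 &= b_1^2\,(g_t - m_{t-1})^2
\end{aligned}
```

Under that assumption, the squared residual computed against m_{t-1} is larger than the documented one by a factor of 1 / b1**2 at every step (about 1.23 for b1 = 0.9).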

The new moment is the correct one on three independent grounds:

  • the docstring's own equation uses (g_t - m_t);
  • the original PyTorch implementation the docstring follows
    (juntang-zhuang/ACProp-Optimizer) updates exp_avg first and then computes
    grad - exp_avg;
  • optax's own AdaBelief sibling scale_by_belief uses the freshly-updated mu.

ACProp's asynchrony concerns which second moment feeds the denominator
(s_{t-1}), which the code already implements correctly; it does not affect
the residual's mean term.
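The separation between the two roles can be sketched with a scalar step. This is a simplified illustration, not the optax implementation: the function name `acprop_step_sketch` is hypothetical, the numerator is assumed to be the raw gradient, and bias correction and any special handling of the first step are omitted.

```python
import math


def acprop_step_sketch(g, m_prev, s_prev, b1=0.9, b2=0.999,
                       eps=1e-16, eps_root=0.0):
    """One simplified scalar ACProp-style step (hypothetical helper).

    Returns (update, m_new, s_new). Assumes s_prev > 0 so the
    denominator is well defined.
    """
    # Asynchronous part: the denominator uses the previous second moment.
    update = g / (math.sqrt(s_prev) + eps)
    # The first moment is updated before the residual is formed.
    m_new = b1 * m_prev + (1.0 - b1) * g
    # The residual is taken against the freshly updated first moment m_t.
    s_new = b2 * s_prev + (1.0 - b2) * (g - m_new) ** 2 + eps_root
    return update, m_new, s_new


update, m_new, s_new = acprop_step_sketch(g=1.0, m_prev=0.0, s_prev=0.25)
```

In this sketch the choice of s_{t-1} for the denominator and the choice of m_t for the residual are independent lines, which is why changing the residual leaves the asynchronous scaling untouched.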

Example: with b1=0.9, b2=0.999, eps_root=0 and zero-initialized moments, a
first step of g=1.0 gives m_1 = (1 - b1) * g = 0.1, so the documented rule
yields s_1 = (1 - b2) * (g - m_1)**2 = 0.001 * 0.9**2 = 0.00081, but the
current code yields 0.001 (using g - m_0 = 1.0).
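The arithmetic of this first step can be reproduced with a few lines of plain Python. This is a scalar sketch under the assumption of zero-initialized moments; the helper name `second_moment` is hypothetical and is not part of optax.

```python
def second_moment(g, mean, s_prev=0.0, b2=0.999, eps_root=0.0):
    """Second-moment update for a scalar, given the mean used in the residual."""
    return b2 * s_prev + (1.0 - b2) * (g - mean) ** 2 + eps_root


b1 = 0.9
g = 1.0
m0 = 0.0                           # first moment before the step
m1 = b1 * m0 + (1.0 - b1) * g      # first moment after the step

s1_documented = second_moment(g, m1)  # residual g - m_1
s1_current = second_moment(g, m0)     # residual g - m_0

print(f"documented: {s1_documented:.5f}, current: {s1_current:.5f}")
```

The two values differ by the factor b1**2 = 0.81, which matches the ratio 0.00081 / 0.001 in the example.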

Changes

  • scale_by_acprop: use the updated mu in the prediction-error residual.
  • Add _acprop_test.py pinning the second moment to the documented
    (g_t - m_t) residual over several steps.

ACProp's documented second moment is
  s_t = b2 * s_{t-1} + (1 - b2) * (g_t - m_t) ** 2 + eps_root,
where m_t is the first moment *after* the step-t update (see the equations in
`optax.contrib.acprop`). The residual was instead computed against `state.mu`,
the previous-step moment m_{t-1}, rather than the freshly-updated `mu` (m_t).

The new moment is the correct one: it matches the docstring formula, the
original PyTorch implementation the docstring follows
(juntang-zhuang/ACProp-Optimizer, which updates exp_avg before computing
grad - exp_avg), and optax's own scale_by_belief sibling.

Add a test pinning the second moment to the documented (g_t - m_t) residual.