Fix ACProp second moment to use the new first moment#1716
Open
codewithfourtix wants to merge 1 commit into
Open
Fix ACProp second moment to use the new first moment#1716codewithfourtix wants to merge 1 commit into
codewithfourtix wants to merge 1 commit into
Conversation
ACProp's documented second moment is
s_t = b2 * s_{t-1} + (1 - b2) * (g_t - m_t) ** 2 + eps_root,
where m_t is the first moment *after* the step-t update (see the equations in
`optax.contrib.acprop`). The residual was instead computed against `state.mu`,
the previous-step moment m_{t-1}, rather than the freshly-updated `mu` (m_t).
The new moment is the correct one: it matches the docstring formula, the
original PyTorch implementation the docstring follows
(juntang-zhuang/ACProp-Optimizer, which updates exp_avg before computing
grad - exp_avg), and optax's own scale_by_belief sibling.
Add a test pinning the second moment to the documented (g_t - m_t) residual.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
scale_by_acpropcomputes its second-moment residual againststate.mu(the previous-step first moment
m_{t-1}) instead of the freshly-updatedmu(m_t), contradicting the documented update rule.Details
The
optax.contrib.acpropdocstring documents the second moment as:where
m_tis the first moment after the step-t update. The code alreadycomputes that new moment on the line above (
mu = update_moment(...)), but theresidual is then computed as
g - state.mu, i.e.g_t - m_{t-1}.The new moment is the correct one on three independent grounds:
(g_t - m_t);(juntang-zhuang/ACProp-Optimizer) updates
exp_avgfirst and then computesgrad - exp_avg;scale_by_beliefuses the freshly-updatedmu.ACProp's asynchrony concerns which second moment feeds the denominator
(
s_{t-1}), which the code already implements correctly; it does not affectthe residual's mean term.
Example: with
b1=0.9, b2=0.999, eps_root=0, a first step ofg=1.0shouldgive
s_1 = (1 - b2) * (g - m_1)**2 = 0.001 * 0.9**2 = 0.00081, but thecurrent code yields
0.001(usingg - m_0 = 1.0).Changes
scale_by_acprop: use the updatedmuin the prediction-error residual._acprop_test.pypinning the second moment to the documented(g_t - m_t)residual over several steps.