feat(contrib): add cautious optimizer wrapper #1690
Open
Sumu004 wants to merge 5 commits into
Implements the Cautious optimizer from Liang et al. (2024): https://arxiv.org/abs/2411.16085

`optax.contrib.cautious(base_optimizer)` wraps any optimizer and, on every step, zeroes the coordinates of the proposed update that are not descent-aligned with the current gradient (they would locally increase the loss), then rescales the surviving coordinates per parameter tensor so the mean update magnitude is preserved:

    phi_t = 1[u_t * g_t < 0]                          # descent-aligned mask
    u_t  <- phi_t * u_t * n / (sum(phi_t) + eps)

Here n is the number of coordinates in the parameter tensor. The mask condition `u_t * g_t < 0` is the paper's alignment condition `(-u_t) * g_t > 0` re-expressed in Optax's additive update convention (Optax adds the update; the paper subtracts it).

This single-line modification provably preserves the Hamiltonian / Lyapunov descent of the base optimizer: the cautious update always satisfies <u_t, g_t> <= 0, so it never points uphill, whereas a momentum-based optimizer can. The authors report up to 1.47x sample-efficiency gains wrapping AdamW for LLM/ViT pre-training.

Implemented as a wrapper (like schedule_free) because the mask needs both the raw gradient and the base optimizer's proposed update.

New public API:
- optax.contrib.cautious -- the wrapper
- optax.contrib.CautiousState -- NamedTuple holding base state

Features:
- eps=1e-8 default makes it reduce exactly to the base optimizer when all coordinates agree; eps=1.0 matches the paper's damping
- Full pytree support, jit-compatible, GradientTransformationExtraArgs
- 13 unit tests incl. the descent-guarantee property and exact mask/rescale checks
- Registered as a wrapper in _common_test.py and documented in docs/api/contrib.rst
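The mask-and-rescale rule above can be sketched for a single parameter tensor as follows. This is a minimal illustration in plain `jax.numpy`, not the code in `optax/contrib/_cautious.py`; the helper name `mask_and_rescale` and the example values are assumptions made for the sketch.

```python
import jax.numpy as jnp


def mask_and_rescale(update, grad, eps=1e-8):
    """Hypothetical helper: cautious masking for one parameter tensor."""
    # Optax adds the update to the parameters, so a coordinate moves downhill
    # when update * grad < 0.
    mask = (update * grad < 0).astype(update.dtype)
    # Rescale the surviving coordinates by n / (sum(mask) + eps).
    scale = mask.size / (jnp.sum(mask) + eps)
    return update * mask * scale


update = jnp.array([-0.1, 0.2, -0.3, 0.4])
grad = jnp.array([1.0, 1.0, -1.0, -1.0])

# The two middle coordinates have update * grad > 0, so they are zeroed;
# the two survivors are scaled by 4 / 2.
masked = mask_and_rescale(update, grad)

# When every coordinate is descent-aligned, eps=1e-8 leaves the update
# essentially unchanged, while eps=1.0 shrinks it by n / (n + 1).
aligned = -grad
same = mask_and_rescale(aligned, grad, eps=1e-8)
damped = mask_and_rescale(aligned, grad, eps=1.0)
```

The last two calls illustrate the `eps` trade-off described under Features: a tiny `eps` only guards against division by zero, whereas `eps=1.0` damps every step slightly.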
- Rename update_fn first param grads -> updates to match the protocol's positional parameter name
- Annotate params as Optional[base.Params]
- Suppress the NamedTuple-state variance false-positive with # pyrefly: ignore[bad-argument-type], matching schedule_free / mechanize
- Replace chex assertions with optax._src.test_utils equivalents; chex is not available in the contrib CI test environment (caused an ImportError on the jax=0.5.3 job)
- Change the cautious common-test entry from sgd(1e-2) to sgd(1e-3, momentum=0.9): plain sgd diverges on Rosenbrock at higher lrs, and momentum also exercises the cautious mask (momentum can disagree with the current gradient). Verified convergence on all four shared targets (parabola, rosenbrock, matrix-parabola, mixed-tensor) within 30k steps
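The remark that momentum can disagree with the current gradient can be checked with a small experiment. The sketch below uses a hand-rolled heavy-ball update on a two-dimensional quadratic; the objective, step size, and step count are illustrative assumptions, not the shared test targets or the PR's benchmark, and `mask_and_rescale` and `count_uphill` are hypothetical helpers.

```python
import jax.numpy as jnp


def mask_and_rescale(u, g, eps=1e-8):
    """Hypothetical helper: keep descent-aligned coordinates, then rescale."""
    mask = (u * g < 0).astype(u.dtype)
    return u * mask * (mask.size / (jnp.sum(mask) + eps))


def count_uphill(cautious, steps=100, lr=0.05, momentum=0.9):
    """Counts steps whose applied update has <update, grad> > 0."""
    scales = jnp.array([1.0, 10.0])  # ill-conditioned quadratic (assumed)
    x = jnp.array([1.0, 1.0])
    trace = jnp.zeros_like(x)
    uphill = 0
    for _ in range(steps):
        g = scales * x  # gradient of 0.5 * sum(scales * x**2)
        trace = momentum * trace + g  # heavy-ball momentum buffer
        u = -lr * trace  # additive update, as in Optax
        if cautious:
            u = mask_and_rescale(u, g)
        uphill += int(jnp.vdot(u, g) > 0)
        x = x + u
    return uphill


base_uphill = count_uphill(cautious=False)
cautious_uphill = count_uphill(cautious=True)
```

With these settings the plain momentum run overshoots along the steep coordinate within the first few steps, so `base_uphill` is positive, while the masked run never applies an update with a positive inner product against the gradient.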
…nightly

JAX nightly changes complex64 accumulation order, causing a 1.3e-7 relative error in tree_vdot for complex inputs, just above the default rtol=1e-7. Loosen to rtol=1e-6, matching the same pre-existing fix on the MARS PR (google-deepmind#1689).
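The tolerance arithmetic can be illustrated with the standard library. This sketch uses `math.isclose` as a stand-in for the test helper, whose exact tolerance formula is not shown here, so treat it only as an illustration of why a 1.3e-7 relative error fails at 1e-7 and passes at 1e-6.

```python
import math

reference = 1.0
observed = reference * (1.0 + 1.3e-7)  # relative error reported above

# Relative tolerance only; abs_tol=0.0 keeps the comparison purely relative.
passes_tight = math.isclose(observed, reference, rel_tol=1e-7, abs_tol=0.0)
passes_loose = math.isclose(observed, reference, rel_tol=1e-6, abs_tol=0.0)
```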
Summary

Adds `optax.contrib.cautious`, implementing the Cautious Optimizer (Liang et al., 2024).

The mechanism
Let `u_t` be the base optimizer's proposed update (Optax's additive convention `params <- params + u_t`) and `g_t` the gradient. The cautious mask keeps only descent-aligned coordinates and rescales per parameter tensor:

    phi_t = 1[u_t * g_t < 0]                          # descent-aligned mask
    u_t  <- phi_t * u_t * n / (sum(phi_t) + eps)

The condition `u_t * g_t < 0` is the paper's alignment condition `(-u_t) * g_t > 0` re-expressed in Optax's additive convention (Optax adds the update, the paper subtracts it).

Descent guarantee: the cautious update always satisfies `<u_t, g_t> <= 0`, so it never points uphill, whereas a momentum-based optimizer can overshoot and do so. This preserves the base optimizer's Hamiltonian/Lyapunov descent.

Changes
- `optax/contrib/_cautious.py`: `cautious` wrapper + `CautiousState`
- `optax/contrib/_cautious_test.py`
- `optax/contrib/__init__.py`: `cautious`, `CautiousState`
- `optax/contrib/_common_test.py`: `cautious` in the shared wrapper test suite
- `docs/api/contrib.rst`

Design notes
- Implemented as a wrapper (like `schedule_free`) rather than a chainable `scale_by_*` transform, because the mask needs both the raw gradient and the base optimizer's proposed update, information that is lost once you're downstream in a `chain`.
- `eps=1e-8` (default) makes the wrapper reduce exactly to the base optimizer when every coordinate is descent-aligned with the gradient (the mean-preserving normalization). The original paper uses `eps=1`; pass `eps=1.0` to match it. Both are unit-tested.

Empirical analysis
I ran a benchmark on an ill-conditioned, noisy least-squares task (κ≈50, heavy gradient noise) that stresses momentum overshoot, averaged over 3 seeds:
Descent guarantee (uphill steps where `<update, grad> > 0`, over 400 steps):

The bottom-left panel shows the base momentum optimizer repeatedly crossing into the uphill region (`<update, grad> > 0`), while the cautious variant stays strictly in the descent region, exactly the theoretical property the paper proves. Mean mask density was ~0.6–0.65 (about a third of coordinates masked per step).

Test plan
- `pytest optax/contrib/_cautious_test.py`: 13 passed
- `pytest --doctest-modules optax/contrib/_cautious.py`: doctest passes
- `ruff check`: clean
- `optax/contrib/_common_test.py` shared suite (collects 593 tests)

References
Liang et al., Cautious Optimizers: Improving Training with One Line of Code, 2024.