Skip to content

feat(contrib): add cautious optimizer wrapper#1690

Open
Sumu004 wants to merge 5 commits into
google-deepmind:mainfrom
Sumu004:feat/cautious-optimizer
Open

feat(contrib): add cautious optimizer wrapper#1690
Sumu004 wants to merge 5 commits into
google-deepmind:mainfrom
Sumu004:feat/cautious-optimizer

Conversation

@Sumu004

@Sumu004 Sumu004 commented Jun 5, 2026

Copy link
Copy Markdown

Summary

  • Add optax.contrib.cautious, implementing the Cautious Optimizer (Liang et al., 2024)
  • A wrapper around any base optimizer that masks the coordinates of the update which would move against the current gradient (locally increasing the loss), then rescales the survivors so mean update magnitude is preserved
  • One-line modification with a provable descent guarantee; the authors report up to 1.47× sample efficiency wrapping AdamW for LLM/ViT pre-training

Note: this is distinct from #1611 (Cautious Weight Decay, Chen et al. 2025), which masks the weight-decay term. This PR masks the optimizer update itself (Liang et al. 2024).

The mechanism

Let u_t be the base optimizer's proposed update (Optax's additive convention params <- params + u_t) and g_t the gradient. The cautious mask keeps only descent-aligned coordinates and rescales per parameter tensor:

phi_t = 1[u_t * g_t < 0]                       # keep where update opposes gradient
u_t  <- phi_t * u_t * n / (sum(phi_t) + eps)   # mean-preserving rescale

The condition u_t * g_t < 0 is the paper's alignment condition (-u_t) * g_t > 0 re-expressed in Optax's additive convention (Optax adds the update, the paper subtracts it).

Descent guarantee: the cautious update always satisfies <u_t, g_t> <= 0, so it never points uphill — whereas a momentum-based optimizer can overshoot and do so. This preserves the base optimizer's Hamiltonian/Lyapunov descent.

Changes

File Description
optax/contrib/_cautious.py New cautious wrapper + CautiousState
optax/contrib/_cautious_test.py 13 unit tests (descent guarantee, exact mask/rescale, jit, pytree, wraps adam/adamw/lion/sgd)
optax/contrib/__init__.py Export cautious, CautiousState
optax/contrib/_common_test.py Register cautious in the shared wrapper test suite
docs/api/contrib.rst Document the new API

Design notes

  • Implemented as a wrapper (like schedule_free) rather than a chainable scale_by_* transform, because the mask needs both the raw gradient and the base optimizer's proposed update — information that is lost once you're downstream in a chain.
  • eps=1e-8 (default) makes the wrapper reduce exactly to the base optimizer when every coordinate agrees with the gradient (the mean-preserving normalization). The original paper uses eps=1; pass eps=1.0 to match it. Both are unit-tested.

Empirical analysis

I ran a benchmark on an ill-conditioned, noisy least-squares task (κ≈50, heavy gradient noise) that stresses momentum overshoot, averaged over 3 seeds:

Cautious benchmark

Descent guarantee (uphill steps where <update, grad> > 0, over 400 steps):

Optimizer Uphill steps
SGD+momentum 33
Cautious-SGD+momentum 0
AdamW 6
Cautious-AdamW 0

The bottom-left panel shows the base momentum optimizer repeatedly crossing into the uphill region (<update,grad> > 0), while the cautious variant stays strictly in the descent region — exactly the theoretical property the paper proves. Mean mask density was ~0.6–0.65 (about a third of coordinates masked per step).

Test plan

  • pytest optax/contrib/_cautious_test.py — 13 passed
  • pytest --doctest-modules optax/contrib/_cautious.py — doctest passes
  • ruff check — clean
  • Registered in optax/contrib/_common_test.py shared suite (collects 593 tests)

References

Liang et al., Cautious Optimizers: Improving Training with One Line of Code, 2024.

Sumu004 added 5 commits June 5, 2026 16:22
Implements the Cautious optimizer from Liang et al. (2024):
  https://arxiv.org/abs/2411.16085

`optax.contrib.cautious(base_optimizer)` wraps any optimizer and, on
every step, zeroes the coordinates of the proposed update that would
move against the current gradient (locally increasing the loss), then
rescales the surviving coordinates per parameter tensor so the mean
update magnitude is preserved:

  phi_t = 1[u_t * g_t < 0]                  # descent-aligned mask
  u_t  <- phi_t * u_t * n / (sum(phi_t)+eps)

The mask condition `u_t * g_t < 0` is the paper's alignment condition
`(-u_t) * g_t > 0` re-expressed in Optax's additive update convention
(Optax adds the update; the paper subtracts it).

This single-line modification provably preserves the Hamiltonian /
Lyapunov descent of the base optimizer: the cautious update always
satisfies <u_t, g_t> <= 0, so it never points uphill, whereas a
momentum-based optimizer can. The authors report up to 1.47x
sample-efficiency gains wrapping AdamW for LLM/ViT pre-training.

Implemented as a wrapper (like schedule_free) because the mask needs
both the raw gradient and the base optimizer's proposed update.

New public API:
  - optax.contrib.cautious       -- the wrapper
  - optax.contrib.CautiousState  -- NamedTuple holding base state

Features:
  - eps=1e-8 default makes it reduce exactly to the base optimizer
    when all coordinates agree; eps=1.0 matches the paper's damping
  - Full pytree support, jit-compatible, GradientTransformationExtraArgs
  - 13 unit tests incl. the descent-guarantee property and exact
    mask/rescale checks
  - Registered as a wrapper in _common_test.py and documented in
    docs/api/contrib.rst
- Rename update_fn first param grads -> updates to match the protocol's
  positional parameter name
- Annotate params as Optional[base.Params]
- Suppress the NamedTuple-state variance false-positive with
  # pyrefly: ignore[bad-argument-type], matching schedule_free / mechanize
- Replace chex assertions with optax._src.test_utils equivalents; chex is
  not available in the contrib CI test environment (caused an ImportError
  on the jax=0.5.3 job)
- Change the cautious common-test entry from sgd(1e-2) to
  sgd(1e-3, momentum=0.9): plain sgd diverges on Rosenbrock at higher lrs,
  and momentum also exercises the cautious mask (momentum can disagree with
  the current gradient). Verified convergence on all four shared targets
  (parabola, rosenbrock, matrix-parabola, mixed-tensor) within 30k steps
…nightly

JAX nightly changes complex64 accumulation order, causing a 1.3e-7
relative error in tree_vdot for complex inputs — just above the default
rtol=1e-7. Loosen to rtol=1e-6, matching the same pre-existing fix on
the MARS PR (google-deepmind#1689).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant