From dc2fb1646d17e694023c27ffa4097d73906457cb Mon Sep 17 00:00:00 2001 From: Muneerali199 Date: Tue, 9 Jun 2026 23:52:48 +0530 Subject: [PATCH] fix: memoize optimizer factory functions to avoid JIT recompilation (#353) When GradientTransformation objects are passed as static arguments to jax.jit, JAX recompiles on every call because the closures inside each GradientTransformation have different identities, producing different hashes. Fix: add @functools.lru_cache(maxsize=None) to all 29 optimizer alias factory functions in alias.py, plus chain() in _combining.py, and identity/set_to_zero/stateless/stateless_with_tree_map/with_extra_args_support in base.py. Memoization ensures the same arguments always return the exact same GradientTransformation object, with stable identity-based hashing. JAX sees the same static argument and does not recompile. Closes #353 --- optax/_src/alias.py | 29 +++++++++++++ optax/_src/alias_test.py | 74 ++++++++++++++++++++++++++++++++++ optax/_src/base.py | 6 +++ optax/transforms/_combining.py | 2 + 4 files changed, 111 insertions(+) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 4b6ff23d2..fce978f85 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -33,6 +33,7 @@ MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]] +@functools.lru_cache(maxsize=None) def adabelief( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -157,6 +158,7 @@ def adabelief( return combine.chain(*chain_args) +@functools.lru_cache(maxsize=None) def adadelta( learning_rate: Optional[base.ScalarOrSchedule] = None, rho: jax.typing.ArrayLike = 0.9, @@ -239,6 +241,7 @@ def adadelta( ) +@functools.lru_cache(maxsize=None) def adafactor( learning_rate: Optional[base.ScalarOrSchedule] = None, min_dim_size_to_factor: int = 128, @@ -344,6 +347,7 @@ def adafactor( return combine.chain(*tx) +@functools.lru_cache(maxsize=None) def adagrad( learning_rate: base.ScalarOrSchedule, initial_accumulator_value: jax.typing.ArrayLike = 0.1, @@ -429,6 +433,7 
@@ def adagrad( ) +@functools.lru_cache(maxsize=None) def adam( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -613,6 +618,7 @@ def adam( """ +@functools.lru_cache(maxsize=None) def adamw( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -824,6 +830,7 @@ def adamw( ) +@functools.lru_cache(maxsize=None) def adan( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.98, @@ -942,6 +949,7 @@ def adan( ) +@functools.lru_cache(maxsize=None) def lion( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -1038,6 +1046,7 @@ def lion( ) +@functools.lru_cache(maxsize=None) def amsgrad( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -1113,6 +1122,7 @@ def amsgrad( ) +@functools.lru_cache(maxsize=None) def fromage( learning_rate: base.ScalarOrSchedule, min_norm: jax.typing.ArrayLike = 1e-6 ) -> base.GradientTransformationExtraArgs: @@ -1176,6 +1186,7 @@ def fromage( ) +@functools.lru_cache(maxsize=None) def lars( learning_rate: base.ScalarOrSchedule, weight_decay: base.ScalarOrSchedule = 0.0, @@ -1249,6 +1260,7 @@ def lars( ) +@functools.lru_cache(maxsize=None) def lamb( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -1318,6 +1330,7 @@ def lamb( ) +@functools.lru_cache(maxsize=None) def noisy_sgd( learning_rate: base.ScalarOrSchedule, eta: jax.typing.ArrayLike = 0.01, @@ -1393,6 +1406,7 @@ def noisy_sgd( ) +@functools.lru_cache(maxsize=None) def sign_sgd( learning_rate: base.ScalarOrSchedule, ) -> base.GradientTransformationExtraArgs: @@ -1457,6 +1471,7 @@ def sign_sgd( ) +@functools.lru_cache(maxsize=None) def signum( learning_rate: base.ScalarOrSchedule, beta: jax.typing.ArrayLike = 0.9, @@ -1497,6 +1512,7 @@ def signum( ) +@functools.lru_cache(maxsize=None) def novograd( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -1572,6 +1588,7 @@ def novograd( ) +@functools.lru_cache(maxsize=None) def 
optimistic_gradient_descent( learning_rate: base.ScalarOrSchedule, alpha: base.ScalarOrSchedule = 1.0, @@ -1640,6 +1657,7 @@ def optimistic_gradient_descent( ) +@functools.lru_cache(maxsize=None) def optimistic_adam( learning_rate: jax.typing.ArrayLike, optimism: Optional[jax.typing.ArrayLike] = None, @@ -1764,6 +1782,7 @@ def optimistic_adam( ) +@functools.lru_cache(maxsize=None) def optimistic_adam_v2( learning_rate: base.ScalarOrSchedule, *, @@ -1886,6 +1905,7 @@ def optimistic_adam_v2( ) +@functools.lru_cache(maxsize=None) def radam( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -1957,6 +1977,7 @@ def radam( ) +@functools.lru_cache(maxsize=None) def rmsprop( learning_rate: base.ScalarOrSchedule, decay: jax.typing.ArrayLike = 0.9, @@ -2071,6 +2092,7 @@ def rmsprop( ) +@functools.lru_cache(maxsize=None) def sgd( learning_rate: base.ScalarOrSchedule, momentum: Optional[jax.typing.ArrayLike] = None, @@ -2161,6 +2183,7 @@ def sgd( ) +@functools.lru_cache(maxsize=None) def sm3( learning_rate: jax.typing.ArrayLike, momentum: jax.typing.ArrayLike = 0.9 ) -> base.GradientTransformationExtraArgs: @@ -2271,6 +2294,7 @@ def sm3( ) +@functools.lru_cache(maxsize=None) def yogi( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -2330,6 +2354,7 @@ def yogi( ) +@functools.lru_cache(maxsize=None) def adamax( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -2415,6 +2440,7 @@ def adamax( ) +@functools.lru_cache(maxsize=None) def adamaxw( learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, @@ -2490,6 +2516,7 @@ def adamaxw( ) +@functools.lru_cache(maxsize=None) def rprop( learning_rate: jax.typing.ArrayLike, eta_minus: jax.typing.ArrayLike = 0.5, @@ -2561,6 +2588,7 @@ def rprop( ) +@functools.lru_cache(maxsize=None) def polyak_sgd( max_learning_rate: jax.typing.ArrayLike = 1.0, scaling: base.ScalarOrSchedule = 1.0, @@ -2652,6 +2680,7 @@ def polyak_sgd( ) +@functools.lru_cache(maxsize=None) 
def lbfgs( learning_rate: Optional[base.ScalarOrSchedule] = None, memory_size: int = 10, diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index f2d38abb1..d782264e7 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -427,5 +427,79 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype): test_utils.assert_trees_all_equal(updates, jnp.zeros_like(grads)) +class GradientTransformationMemoizationTest(absltest.TestCase): + """Tests that optimizer factory functions are memoized for stable hashing.""" + + def test_same_args_return_same_object(self): + opt1 = alias.adam(1e-3) + opt2 = alias.adam(1e-3) + self.assertIs(opt1, opt2) + + def test_different_args_return_different_objects(self): + opt1 = alias.adam(1e-3) + opt2 = alias.adam(1e-4) + self.assertIsNot(opt1, opt2) + + def test_equal_optimizers_have_equal_hash(self): + opt1 = alias.adam(1e-3) + opt2 = alias.adam(1e-3) + self.assertEqual(hash(opt1), hash(opt2)) + self.assertEqual(opt1, opt2) + + def test_all_optimizer_aliases_memoize(self): + """Spot-check that all major optimizer aliases are memoized.""" + cases = [ + (alias.adam, [1e-3], {}), + (alias.adamw, [1e-3], {}), + (alias.sgd, [1e-3], {}), + (alias.adagrad, [1e-3], {}), + (alias.rmsprop, [1e-3], {}), + (alias.lion, [1e-3], {}), + (alias.adamax, [1e-3], {}), + (alias.radam, [1e-3], {}), + (alias.lamb, [1e-3], {}), + (alias.amsgrad, [1e-3], {}), + ] + for opt_fn, args, kwargs in cases: + with self.subTest(opt_fn.__name__): + r1 = opt_fn(*args, **kwargs) + r2 = opt_fn(*args, **kwargs) + self.assertIs(r1, r2, + f'{opt_fn.__name__} is not memoized') + + def test_chain_is_memoized(self): + """Test that chain() returns the same object for the same transforms.""" + from optax._src import combine + c1 = combine.chain( + alias.adam(1e-3), + alias.scale(-1.0), + alias.set_to_zero(), + ) + c2 = combine.chain( + alias.adam(1e-3), + alias.scale(-1.0), + alias.set_to_zero(), + ) + self.assertIs(c1, c2) + + def 
test_memoized_functions_dont_recompile_jit(self): + """Test that equal optimizers don't trigger JAX recompilation.""" + counters = {'traces': 0} + + @functools.partial(jax.jit, static_argnames=('opt',)) + def train_step(opt, opt_state): + counters['traces'] += 1 + return opt_state + + opt_state = alias.adam(1e-2).init({'x': jnp.zeros((100, 100))}) + opt_state = jax.block_until_ready(opt_state) + + for _ in range(3): + opt_state = train_step(alias.adam(1e-2), opt_state) + opt_state = jax.block_until_ready(opt_state) + + self.assertEqual(counters['traces'], 1) + + if __name__ == '__main__': absltest.main() diff --git a/optax/_src/base.py b/optax/_src/base.py index 5c32bb04b..2743cdeab 100644 --- a/optax/_src/base.py +++ b/optax/_src/base.py @@ -15,6 +15,7 @@ """Base interfaces and datatypes.""" from collections.abc import Callable +import functools from typing import (Any, Iterable, Mapping, NamedTuple, Optional, Protocol, Sequence, Union, runtime_checkable) @@ -226,6 +227,7 @@ def init_empty_state(params: Params) -> EmptyState: return EmptyState() +@functools.lru_cache(maxsize=None) def identity() -> GradientTransformation: """Stateless identity transformation that leaves input gradients untouched. @@ -246,6 +248,7 @@ def update_fn(updates, state, params=None): return GradientTransformation(init_empty_state, update_fn) +@functools.lru_cache(maxsize=None) def set_to_zero() -> GradientTransformation: """Stateless transformation that maps input gradients to zero. 
@@ -273,6 +276,7 @@ def update_fn(updates, state, params=None): return GradientTransformation(init_empty_state, update_fn) +@functools.lru_cache(maxsize=None) def stateless( f: Callable[[Updates, Optional[Params]], Updates], ) -> GradientTransformation: @@ -296,6 +300,7 @@ def update_fn(updates, state, params=None): return GradientTransformation(init_empty_state, update_fn) +@functools.lru_cache(maxsize=None) def stateless_with_tree_map( f: Callable[[jax.typing.ArrayLike, Optional[jax.typing.ArrayLike]], jax.typing.ArrayLike], @@ -326,6 +331,7 @@ def update_fn(updates, state, params=None): return GradientTransformation(init_empty_state, update_fn) +@functools.lru_cache(maxsize=None) def with_extra_args_support( tx: GradientTransformation, ) -> GradientTransformationExtraArgs: diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py index 1c462c58d..5d3684763 100644 --- a/optax/transforms/_combining.py +++ b/optax/transforms/_combining.py @@ -16,6 +16,7 @@ import collections from collections.abc import Callable, Hashable, Mapping +import functools from typing import NamedTuple, Union import jax @@ -23,6 +24,7 @@ from optax._src import wrappers +@functools.lru_cache(maxsize=None) def chain( *args: base.GradientTransformation, ) -> base.GradientTransformationExtraArgs: