fix: memoize optimizer factory functions to avoid JIT recompilation (#353) #1693

Open
Muneerali199 wants to merge 1 commit into google-deepmind:main from Muneerali199:fix/memoize-optimizers-for-jit-hash

Conversation

@Muneerali199

Contributor

Fixes #353

Problem

When GradientTransformation objects are passed as static arguments to jax.jit, JAX recompiles on every call because each call to an optimizer factory creates new closures with different identities, producing different hashes.

Root Cause

GradientTransformation is a NamedTuple whose __hash__ is derived from the identity of its init and update closure fields. Since every call to optax.adam(1e-3) creates new closures, the hash is different each time.
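
The root cause can be reproduced without JAX or optax. The following is a minimal sketch, assuming a stand-in NamedTuple (FakeTransformation) and a hypothetical factory (make_sgd); neither name is optax code.

```python
from typing import Callable, NamedTuple


class FakeTransformation(NamedTuple):
    """Stand-in for optax.GradientTransformation (illustration only)."""
    init: Callable
    update: Callable


def make_sgd(learning_rate):
    """Hypothetical factory: builds fresh closures on every call."""
    def init(params):
        return ()

    def update(grads, state, params=None):
        return [-learning_rate * g for g in grads], state

    return FakeTransformation(init, update)


a = make_sgd(1e-3)
b = make_sgd(1e-3)

# Plain functions compare and hash by identity, and a tuple's equality and
# hash are built from its fields, so two separately built objects differ.
print(a.init is b.init)  # False
print(a == b)            # False
```

Because static arguments to jax.jit are compared by hash and equality, two such objects with identical settings would be treated as different static values.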

Solution

Add @functools.lru_cache(maxsize=None) to all optimizer alias factory functions. This ensures that identical arguments always return the exact same GradientTransformation object with a stable hash.

Files changed:

  • optax/_src/alias.py: Added memoization to all 29 optimizer factory functions
  • optax/transforms/_combining.py: Added to chain()
  • optax/_src/base.py: Added to identity(), set_to_zero(), stateless(), stateless_with_tree_map(), with_extra_args_support()
  • optax/_src/alias_test.py: Added GradientTransformationMemoizationTest with 6 tests
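
The memoization described under Solution can be sketched in plain Python. This is an illustration only: FakeTransformation and make_sgd are hypothetical stand-ins, not the decorated optax functions from this PR.

```python
import functools
from typing import Callable, NamedTuple


class FakeTransformation(NamedTuple):
    """Stand-in for optax.GradientTransformation (illustration only)."""
    init: Callable
    update: Callable


@functools.lru_cache(maxsize=None)
def make_sgd(learning_rate):
    """Hypothetical factory; the cache is keyed on its (hashable) arguments."""
    def init(params):
        return ()

    def update(grads, state, params=None):
        return [-learning_rate * g for g in grads], state

    return FakeTransformation(init, update)


a = make_sgd(1e-3)
b = make_sgd(1e-3)  # cache hit: the very same object comes back
c = make_sgd(1e-2)  # different argument: a new object is built

print(a is b)                      # True
print(a is c)                      # False
print(make_sgd.cache_info().hits)  # 1
```

With maxsize=None the cache never evicts, so every distinct configuration stays alive for the lifetime of the process.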

Verification

  • All memoized factory functions return the same object (checked with is) for identical arguments
  • Different arguments still produce different objects
  • jax.jit does not recompile when the same optimizer configuration is reused

…oogle-deepmind#353)

When GradientTransformation objects are passed as static arguments to
jax.jit, JAX recompiles on every call because the closures inside each
GradientTransformation have different identities, producing different
hashes.

Fix: add @functools.lru_cache(maxsize=None) to all 29 optimizer alias
factory functions in alias.py, plus chain() in _combining.py, and
identity/set_to_zero/stateless/with_extra_args_support in base.py.

Memoization ensures the same arguments always return the exact same
GradientTransformation object, with stable identity-based hashing. JAX
sees the same static argument and does not recompile.

Closes google-deepmind#353
@rdyro

rdyro commented Jun 9, 2026

Collaborator

Thanks, this looks like an interesting direction!

A straightforward application of lru_cache is probably going to break on any dynamic data that is not hashable. What are you thinking as far as solving that problem?

@Muneerali199

Contributor Author

> Thanks, this looks like an interesting direction!
>
> A straightforward application of lru_cache is probably going to break on any dynamic data that is not hashable. What are you thinking as far as solving that problem?

Good point. I'd handle it by wrapping lru_cache so that a TypeError from unhashable args falls through cleanly to an uncached call: no crash, no id() reuse bug. In the common case (hashable args) the result is cached and stable. Want me to update the PR with this?
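
One way the proposed fallback could look is sketched below. This is a guess at the approach under stated assumptions, not code from the PR: memoize_if_hashable, FakeTransformation, and make_sgd are hypothetical names. The sketch probes hashability before touching the cache rather than catching TypeError around the cached call, so that a TypeError raised inside the factory itself is not mistaken for an unhashable argument.

```python
import functools
from typing import Callable, NamedTuple


class FakeTransformation(NamedTuple):
    """Stand-in for optax.GradientTransformation (illustration only)."""
    init: Callable
    update: Callable


def memoize_if_hashable(fn):
    """Cache fn's results, bypassing the cache for unhashable arguments."""
    cached = functools.lru_cache(maxsize=None)(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            # Hashing the argument tuple raises TypeError if any positional
            # or keyword value is unhashable (for example a list).
            hash((args, tuple(kwargs.items())))
        except TypeError:
            return fn(*args, **kwargs)  # uncached fallback, no crash
        return cached(*args, **kwargs)

    return wrapper


@memoize_if_hashable
def make_sgd(learning_rate):
    """Hypothetical factory used only to exercise the wrapper."""
    def init(params):
        return ()

    def update(grads, state, params=None):
        return grads, state

    return FakeTransformation(init, update)


print(make_sgd(1e-3) is make_sgd(1e-3))      # True
print(make_sgd([1e-3]) is make_sgd([1e-3]))  # False
```

Under this sketch, calls with unhashable arguments still return a new object each time, so they would keep their identity-based hash and would not benefit from the stable-hash behavior.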

Development

Successfully merging this pull request may close these issues.

Add __hash__ function to avoid unnecessary recompilations
