fix: memoize optimizer factory functions to avoid JIT recompilation (#353) #1693
Open
Muneerali199 wants to merge 1 commit into
…oogle-deepmind#353)

When GradientTransformation objects are passed as static arguments to jax.jit, JAX recompiles on every call because the closures inside each GradientTransformation have different identities, producing different hashes.

Fix: add @functools.lru_cache(maxsize=None) to all 29 optimizer alias factory functions in alias.py, plus chain() in _combining.py, and identity/set_to_zero/stateless/with_extra_args_support in base.py.

Memoization ensures the same arguments always return the exact same GradientTransformation object, with stable identity-based hashing. JAX sees the same static argument and does not recompile.

Closes google-deepmind#353
Collaborator
Thanks, this looks like an interesting direction! The straightforward application …
Contributor
Author
Good point. I'd handle it by wrapping lru_cache to catch the TypeError raised for unhashable args and fall through to an uncached call: no crash, no id() reuse bug. In the common case (hashable args) the result is cached and stable. Want me to update the PR with this?
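The fallback described in this reply could look roughly like the sketch below. It is not the code from the PR: memoize_if_hashable and make_update are hypothetical names used only for illustration, and the sketch checks hashability up front rather than catching TypeError around the cached call, so that a TypeError raised inside the factory itself is not mistaken for an unhashable argument.

```python
import functools


def memoize_if_hashable(factory):
    """Cache results for hashable arguments; call through uncached otherwise.

    Hypothetical helper, not part of optax.
    """
    cached = functools.lru_cache(maxsize=None)(factory)

    @functools.wraps(factory)
    def wrapper(*args, **kwargs):
        try:
            # lru_cache needs every argument to be hashable; probe that first.
            hash((args, tuple(kwargs.items())))
        except TypeError:
            # Unhashable argument (e.g. a list): skip the cache entirely.
            return factory(*args, **kwargs)
        return cached(*args, **kwargs)

    return wrapper


@memoize_if_hashable
def make_update(learning_rate, schedule=None):
    # Stand-in for an optimizer factory: returns a fresh closure on each call.
    def update(grad):
        return -learning_rate * grad

    return update


hashable_a = make_update(0.1)
hashable_b = make_update(0.1)
unhashable_a = make_update(0.1, schedule=[1, 2])
unhashable_b = make_update(0.1, schedule=[1, 2])

print(hashable_a is hashable_b)      # same cached object
print(unhashable_a is unhashable_b)  # distinct objects, no caching
```

With this shape, hashable configurations get a stable object identity, while unhashable ones behave exactly as they did before memoization.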
Fixes #353

Problem

When GradientTransformation objects are passed as static arguments to jax.jit, JAX recompiles on every call because each call to an optimizer factory creates new closures with different identities, producing different hashes.

Root Cause

GradientTransformation is a NamedTuple whose __hash__ is derived from the identity of its init and update closure fields. Since every call to optax.adam(1e-3) creates new closures, the hash is different each time.

Solution

Add @functools.lru_cache(maxsize=None) to all optimizer alias factory functions. This ensures that identical arguments always return the exact same GradientTransformation object with a stable hash.

Files changed:
- alias.py: optimizer alias factory functions
- _combining.py: chain()
- base.py: identity(), set_to_zero(), stateless(), stateless_with_tree_map(), with_extra_args_support()
- tests: GradientTransformationMemoizationTest with 6 tests

Verification
- Repeated factory calls return is-identical objects for identical arguments
- jax.jit does not recompile when the same optimizer configuration is reused
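The root cause and the fix can be reproduced without JAX. The sketch below is an illustration under stated assumptions, not the PR's code: Transform is a hypothetical stand-in for GradientTransformation, and sgd and cached_sgd are hypothetical factories. It relies on the fact that jax.jit looks up static arguments by __hash__ and __eq__, so two tuples of distinct closures count as different static arguments.

```python
import functools
from typing import Callable, NamedTuple


class Transform(NamedTuple):
    """Hypothetical stand-in for GradientTransformation."""
    init: Callable
    update: Callable


def sgd(learning_rate):
    # Each call defines new closures, so each result has new function objects.
    def init(params):
        return ()

    def update(grads, state):
        return [-learning_rate * g for g in grads], state

    return Transform(init, update)


@functools.lru_cache(maxsize=None)
def cached_sgd(learning_rate):
    # Same arguments return the very same Transform object.
    return sgd(learning_rate)


first, second = sgd(0.1), sgd(0.1)
# Functions compare by identity, so the two tuples are unequal.
fresh_equal = first == second

cached_first, cached_second = cached_sgd(0.1), cached_sgd(0.1)
cached_same = cached_first is cached_second
hash_same = hash(cached_first) == hash(cached_second)

print(fresh_equal, cached_same, hash_same)
```

The unmemoized pair compares unequal even though both were built from the same learning rate, which is what forces a cache miss; the memoized pair is one object, so its hash and equality are stable across calls.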