Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
# See also https://github.com/google-deepmind/optax/issues/412.
opt_factory = _get_opt(self, opt_name)
opt = opt_factory(**opt_kwargs)
if opt_name == 'adafactor':
# Adafactor wrapped in inject_hyperparams currently needs a static
# argument to be specified in order to be jittable. See issue
# https://github.com/google-deepmind/optax/issues/412.
opt_inject = _inject.inject_hyperparams(
opt_factory, static_args=('min_dim_size_to_factor',)
)(**opt_kwargs)
else:
opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs)
opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs)

params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))]
grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))]
Expand Down Expand Up @@ -286,6 +278,20 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
rtol=1e-4,
)

def test_adafactor_inject_hyperparams_jit_factored(self):
"""Adafactor with inject_hyperparams must be jittable for factored params.

Regression test for https://github.com/google-deepmind/optax/issues/412.
Uses large params so that the factored second-moment path is triggered
(requires both dimensions >= min_dim_size_to_factor=128).
"""
optimizer = _inject.inject_hyperparams(alias.adafactor)(learning_rate=0.01)
params = {'w': jnp.ones((200, 200))}
grads = {'w': jnp.ones((200, 200))}
opt_state = jax.jit(optimizer.init)(params)
updates, _ = jax.jit(optimizer.update)(grads, opt_state, params)
self.assertEqual(updates['w'].shape, (200, 200))

@parameterized.product(
params_dtype=('bfloat16', 'float32', 'complex64', None),
state_dtype=('bfloat16', 'float32', 'complex64', None),
Expand Down
10 changes: 9 additions & 1 deletion optax/schedules/_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,21 @@ def wrapped_transform(
sched_hps, numeric_hps, other_hps = {}, {}, {}
for name, value in bound_arguments.arguments.items():
if name in static_args or isinstance(value, bool):
# booleans and explicitly-declared static args are never traced
other_hps[name] = value
elif isinstance(value, base.StatefulSchedule):
sched_hps[name] = value
elif callable(value):
# pyrefly: ignore[bad-argument-type]
sched_hps[name] = WrappedSchedule(value)
elif isinstance(value, (int, float, jax.Array, np.ndarray)):
elif isinstance(value, int):
# Plain Python ints (e.g. min_dim_size_to_factor, memory_size) are
# used in Python-level control flow and shape decisions inside many
# optimizers. Tracing them as JAX arrays causes ConcretizationTypeError
# under jit. Users who want to pass a schedulable integer should use a
# JAX array (e.g. jnp.int32(10)), which is handled by the branch below.
other_hps[name] = value
elif isinstance(value, (float, jax.Array, np.ndarray)):
numeric_hps[name] = value
else:
other_hps[name] = value
Expand Down
Loading