diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index f2d38abb1..fd3c04661 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -250,15 +250,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( # See also https://github.com/google-deepmind/optax/issues/412. opt_factory = _get_opt(self, opt_name) opt = opt_factory(**opt_kwargs) - if opt_name == 'adafactor': - # Adafactor wrapped in inject_hyperparams currently needs a static - # argument to be specified in order to be jittable. See issue - # https://github.com/google-deepmind/optax/issues/412. - opt_inject = _inject.inject_hyperparams( - opt_factory, static_args=('min_dim_size_to_factor',) - )(**opt_kwargs) - else: - opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs) + opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs) params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] @@ -286,6 +278,20 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( rtol=1e-4, ) + def test_adafactor_inject_hyperparams_jit_factored(self): + """Adafactor with inject_hyperparams must be jittable for factored params. + + Regression test for https://github.com/google-deepmind/optax/issues/412. + Uses large params so that the factored second-moment path is triggered + (requires both dimensions >= min_dim_size_to_factor=128). + """ + optimizer = _inject.inject_hyperparams(alias.adafactor)(learning_rate=0.01) + params = {'w': jnp.ones((200, 200))} + grads = {'w': jnp.ones((200, 200))} + opt_state = jax.jit(optimizer.init)(params) + updates, _ = jax.jit(optimizer.update)(grads, opt_state, params) + self.assertEqual(updates['w'].shape, (200, 200)) + @parameterized.product( params_dtype=('bfloat16', 'float32', 'complex64', None), state_dtype=('bfloat16', 'float32', 'complex64', None), diff --git a/optax/schedules/_inject.py b/optax/schedules/_inject.py index 6688587cf..3b876f6bc 100644 --- a/optax/schedules/_inject.py +++ b/optax/schedules/_inject.py @@ -152,13 +152,21 @@ def wrapped_transform( sched_hps, numeric_hps, other_hps = {}, {}, {} for name, value in bound_arguments.arguments.items(): if name in static_args or isinstance(value, bool): + # booleans and explicitly-declared static args are never traced other_hps[name] = value elif isinstance(value, base.StatefulSchedule): sched_hps[name] = value elif callable(value): # pyrefly: ignore[bad-argument-type] sched_hps[name] = WrappedSchedule(value) - elif isinstance(value, (int, float, jax.Array, np.ndarray)): + elif isinstance(value, int): + # Plain Python ints (e.g. min_dim_size_to_factor, memory_size) are + # used in Python-level control flow and shape decisions inside many + # optimizers. Tracing them as JAX arrays causes ConcretizationTypeError + # under jit. Users who want to pass a schedulable integer should use a + # JAX array (e.g. jnp.int32(10)), which is handled by the branch below. + other_hps[name] = value + elif isinstance(value, (float, jax.Array, np.ndarray)): numeric_hps[name] = value else: other_hps[name] = value