diff --git a/optax/_src/transform.py b/optax/_src/transform.py index ec2e8a2b5..c1313bf85 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -129,6 +129,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, decay, 2) if bias_correction: count_inc = numerics.safe_increment(state.count) @@ -141,6 +142,7 @@ def update_fn(updates, state, params=None): else: scaling = jax.tree.map(lambda n: 1 / (jnp.sqrt(n) + eps), nu_hat) updates = jax.tree.map(lambda s, g: s * g, scaling, updates) + updates = optax.tree.cast_like(updates, grads) if bias_correction: new_state = ScaleByRmsWithCountState(count=count_inc, nu=nu) else: @@ -202,6 +204,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, decay, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, decay, 2) if bias_correction: @@ -226,6 +229,7 @@ def update_fn(updates, state, params=None): nu_hat, ) updates = jax.tree.map(lambda s, g: s * g, scaling, updates) + updates = optax.tree.cast_like(updates, grads) if bias_correction: new_state = ScaleByRStdDevWithCountState(count=count_inc, mu=mu, nu=nu) else: @@ -280,6 +284,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates # keep a reference to the original grads to cast back later mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) @@ -302,6 +307,7 @@ def update_fn(updates, state, params=None): nu_hat, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) mu = optax.tree.cast(mu, mu_dtype) nu = optax.tree.cast_like(nu, state.nu) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) @@ -361,6 +367,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) @@ -382,6 +389,7 @@ def update_fn(updates, state, params=None): nu_max, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) mu = optax.tree.cast(mu, mu_dtype) return updates, ScaleByAmsgradState( count=count_inc, mu=mu, nu=nu, nu_max=nu_max @@ -415,12 +423,14 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates count_inc = numerics.safe_increment(state.count) mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_infinity_moment(updates, state.nu, b2, eps) # Bias correction for mean. No bias correction needed for infinity moment. mu_hat = optax.tree.bias_correction(mu, b1, count_inc) updates = jax.tree.map(lambda m, v: m / v, mu_hat, nu) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -467,6 +477,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates def _comb(g, m): x = (1.0 - b1) * g + b1 * m @@ -484,6 +495,7 @@ def _comb(g, m): ) updates_new = jax.tree.map(_comb, updates, state.mu) + updates_new = optax.tree.cast_like(updates_new, grads) mu = optax.tree.update_moment(updates, state.mu, b2, 1) mu = optax.tree.cast(mu, mu_dtype) count_inc = numerics.safe_increment(state.count) @@ -655,6 +667,7 @@ def init_fn(params): def update_fn(updates, state, params=None): """Based on Algorithm 1 in https://arxiv.org/pdf/2208.06677v4#page=6.""" del params + grads = updates g = updates diff = optax.tree.where( @@ -676,6 +689,7 @@ def update_fn(updates, state, params=None): u = optax.tree.add_scale(m_hat, 1 - b2, v_hat) denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat) u = optax.tree.div(u, denom) + u = optax.tree.cast_like(u, grads) new_state = ScaleByAdanState( m=m, @@ -730,6 +744,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) prediction_error = optax.tree.sub(updates, mu) nu = optax.tree.update_moment_per_elem_norm(prediction_error, state.nu, b2, @@ -751,6 +766,7 @@ def update_fn(updates, state, params=None): nu_hat, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -792,6 +808,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = jax.tree.map( lambda g, v: v - (1 - b2) * jnp.sign(v - abs_sq(g)) * abs_sq(g), @@ -807,6 +824,7 @@ def update_fn(updates, state, params=None): nu_hat, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -861,6 +879,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) @@ -882,6 +901,7 @@ def update_fn(updates, state, params=None): _radam_update(ro, mu_hat, nu_hat), mu_hat, ) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -1318,6 +1338,7 @@ def update_mu(grads, params, mu, nu): ) def update_fn(updates, state, params): + grads = updates count_inc = numerics.safe_increment(state.count) nu = jax.lax.cond(count_inc == 1, init_nu, update_nu, updates, state.nu) @@ -1327,6 +1348,7 @@ def update_fn(updates, state, params): mu = optax.tree.cast(mu, mu_dtype) updates = mu + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByNovogradState(count=count_inc, mu=mu, nu=nu) # pyrefly: ignore[bad-argument-type] diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index f521045ff..2c8807e9e 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -219,6 +219,36 @@ def test_scale_by_polyak_l1_norm(self, tol=1e-10): print(grad, value, updates) self.assertLess(objective(init_params - updates), tol) + @parameterized.named_parameters([ + ('scale_by_adam', transform.scale_by_adam, {}), + ('scale_by_amsgrad', transform.scale_by_amsgrad, {}), + ('scale_by_adamax', transform.scale_by_adamax, {}), + ('scale_by_rms', transform.scale_by_rms, {}), + ('scale_by_stddev', transform.scale_by_stddev, {}), + ('scale_by_belief', transform.scale_by_belief, {}), + ('scale_by_yogi', transform.scale_by_yogi, {}), + ('scale_by_radam', transform.scale_by_radam, {}), + ('scale_by_lion', transform.scale_by_lion, {}), + ('scale_by_adan', transform.scale_by_adan, {}), + ('scale_by_novograd', transform.scale_by_novograd, {}), + ]) + def test_mixed_precision_dtype(self, transform_constr, transform_kwargs): + """Test that transforms output updates with the same dtype as the grads. + + In mixed-precision training params are float32 while grads are bfloat16. + Accumulator states are initialized from params, so without an explicit + cast the outputs get promoted to float32. The cast to param dtype should + happen at the apply_updates step, not inside the transform. + + See https://github.com/google-deepmind/optax/issues/1098. + """ + params = jnp.array([1.0, 2.0], dtype=jnp.float32) + grads = jnp.array([2.0, 4.0], dtype=jnp.bfloat16) + tx = transform_constr(**transform_kwargs) + state = tx.init(params) + updates, _ = tx.update(grads, state, params) + self.assertEqual(updates.dtype, grads.dtype) + def test_rms_match_adam(self): """Test scale_by_rms add_eps_in_sqrt=False matches scale_by_adam(b1=0).""" fun = lambda x: optax.tree.norm(x, squared=True)