From 6f93ad30473c6ea5bdbef07537691ecdb5dd989d Mon Sep 17 00:00:00 2001 From: nileshpatil6 Date: Sun, 31 May 2026 00:09:11 +0530 Subject: [PATCH] fix(transform): cast updates to grad dtype for mixed-precision support In mixed-precision training params are float32 while grads are bfloat16. Optimizer accumulator states are initialized from params (float32), so arithmetic with bfloat16 grads promotes the result to float32. The cast back to param dtype should happen at the apply_updates step, not inside the transform. This adds an explicit cast of the output updates to match the input grad dtype in scale_by_adam, scale_by_amsgrad, scale_by_adamax, scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi, scale_by_radam, scale_by_lion, scale_by_adan, and scale_by_novograd. Adds a parameterized test that verifies each affected transform produces bfloat16 updates when given float32 params and bfloat16 grads. Fixes #1098 --- optax/_src/transform.py | 22 ++++++++++++++++++++++ optax/_src/transform_test.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index ec2e8a2b5..c1313bf85 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -129,6 +129,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, decay, 2) if bias_correction: count_inc = numerics.safe_increment(state.count) @@ -141,6 +142,7 @@ def update_fn(updates, state, params=None): else: scaling = jax.tree.map(lambda n: 1 / (jnp.sqrt(n) + eps), nu_hat) updates = jax.tree.map(lambda s, g: s * g, scaling, updates) + updates = optax.tree.cast_like(updates, grads) if bias_correction: new_state = ScaleByRmsWithCountState(count=count_inc, nu=nu) else: @@ -202,6 +204,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, decay, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, decay, 2) if bias_correction: @@ -226,6 +229,7 @@ def update_fn(updates, state, params=None): nu_hat, ) updates = jax.tree.map(lambda s, g: s * g, scaling, updates) + updates = optax.tree.cast_like(updates, grads) if bias_correction: new_state = ScaleByRStdDevWithCountState(count=count_inc, mu=mu, nu=nu) else: @@ -280,6 +284,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates # keep a reference to the original grads to cast back later mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) @@ -302,6 +307,7 @@ def update_fn(updates, state, params=None): nu_hat, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) mu = optax.tree.cast(mu, mu_dtype) nu = optax.tree.cast_like(nu, state.nu) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) @@ -361,6 +367,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) @@ -382,6 +389,7 @@ def update_fn(updates, state, params=None): nu_max, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) mu = optax.tree.cast(mu, mu_dtype) return updates, ScaleByAmsgradState( count=count_inc, mu=mu, nu=nu, nu_max=nu_max @@ -415,12 +423,14 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates count_inc = numerics.safe_increment(state.count) mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_infinity_moment(updates, state.nu, b2, eps) # Bias correction for mean. No bias correction needed for infinity moment. mu_hat = optax.tree.bias_correction(mu, b1, count_inc) updates = jax.tree.map(lambda m, v: m / v, mu_hat, nu) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -467,6 +477,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates def _comb(g, m): x = (1.0 - b1) * g + b1 * m @@ -484,6 +495,7 @@ def _comb(g, m): ) updates_new = jax.tree.map(_comb, updates, state.mu) + updates_new = optax.tree.cast_like(updates_new, grads) mu = optax.tree.update_moment(updates, state.mu, b2, 1) mu = optax.tree.cast(mu, mu_dtype) count_inc = numerics.safe_increment(state.count) @@ -655,6 +667,7 @@ def init_fn(params): def update_fn(updates, state, params=None): """Based on Algorithm 1 in https://arxiv.org/pdf/2208.06677v4#page=6.""" del params + grads = updates g = updates diff = optax.tree.where( @@ -676,6 +689,7 @@ def update_fn(updates, state, params=None): u = optax.tree.add_scale(m_hat, 1 - b2, v_hat) denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat) u = optax.tree.div(u, denom) + u = optax.tree.cast_like(u, grads) new_state = ScaleByAdanState( m=m, @@ -730,6 +744,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) prediction_error = optax.tree.sub(updates, mu) nu = optax.tree.update_moment_per_elem_norm(prediction_error, state.nu, b2, @@ -751,6 +766,7 @@ def update_fn(updates, state, params=None): nu_hat, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -792,6 +808,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = jax.tree.map( lambda g, v: v - (1 - b2) * jnp.sign(v - abs_sq(g)) * abs_sq(g), @@ -807,6 +824,7 @@ def update_fn(updates, state, params=None): nu_hat, is_leaf=lambda x: x is None, ) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -861,6 +879,7 @@ def init_fn(params): def update_fn(updates, state, params=None): del params + grads = updates mu = optax.tree.update_moment(updates, state.mu, b1, 1) nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = numerics.safe_increment(state.count) @@ -882,6 +901,7 @@ def update_fn(updates, state, params=None): _radam_update(ro, mu_hat, nu_hat), mu_hat, ) + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) return base.GradientTransformation(init_fn, update_fn) @@ -1318,6 +1338,7 @@ def update_mu(grads, params, mu, nu): ) def update_fn(updates, state, params): + grads = updates count_inc = numerics.safe_increment(state.count) nu = jax.lax.cond(count_inc == 1, init_nu, update_nu, updates, state.nu) @@ -1327,6 +1348,7 @@ def update_fn(updates, state, params): mu = optax.tree.cast(mu, mu_dtype) updates = mu + updates = optax.tree.cast_like(updates, grads) return updates, ScaleByNovogradState(count=count_inc, mu=mu, nu=nu) # pyrefly: ignore[bad-argument-type] diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index f521045ff..2c8807e9e 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -219,6 +219,36 @@ def test_scale_by_polyak_l1_norm(self, tol=1e-10): print(grad, value, updates) self.assertLess(objective(init_params - updates), tol) + @parameterized.named_parameters([ + ('scale_by_adam', transform.scale_by_adam, {}), + ('scale_by_amsgrad', transform.scale_by_amsgrad, {}), + ('scale_by_adamax', transform.scale_by_adamax, {}), + ('scale_by_rms', transform.scale_by_rms, {}), + ('scale_by_stddev', transform.scale_by_stddev, {}), + ('scale_by_belief', transform.scale_by_belief, {}), + ('scale_by_yogi', transform.scale_by_yogi, {}), + ('scale_by_radam', transform.scale_by_radam, {}), + ('scale_by_lion', transform.scale_by_lion, {}), + ('scale_by_adan', transform.scale_by_adan, {}), + ('scale_by_novograd', transform.scale_by_novograd, {}), + ]) + def test_mixed_precision_dtype(self, transform_constr, transform_kwargs): + """Test that transforms output updates with the same dtype as the grads. + + In mixed-precision training params are float32 while grads are bfloat16. + Accumulator states are initialized from params, so without an explicit + cast the outputs get promoted to float32. The cast to param dtype should + happen at the apply_updates step, not inside the transform. + + See https://github.com/google-deepmind/optax/issues/1098. + """ + params = jnp.array([1.0, 2.0], dtype=jnp.float32) + grads = jnp.array([2.0, 4.0], dtype=jnp.bfloat16) + tx = transform_constr(**transform_kwargs) + state = tx.init(params) + updates, _ = tx.update(grads, state, params) + self.assertEqual(updates.dtype, grads.dtype) + def test_rms_match_adam(self): """Test scale_by_rms add_eps_in_sqrt=False matches scale_by_adam(b1=0).""" fun = lambda x: optax.tree.norm(x, squared=True)