Skip to content

fix(transform): cast updates to grad dtype for mixed-precision support#1687

Open
nileshpatil6 wants to merge 1 commit into
google-deepmind:mainfrom
nileshpatil6:fix/optimizer-dtype-tracks-grads
Open

fix(transform): cast updates to grad dtype for mixed-precision support#1687
nileshpatil6 wants to merge 1 commit into
google-deepmind:mainfrom
nileshpatil6:fix/optimizer-dtype-tracks-grads

Conversation

@nileshpatil6

Copy link
Copy Markdown

Fixes #1098

Problem

In mixed-precision training, params are float32 while grads are bfloat16. Optimizer
accumulator states (mu, nu) are initialized from params so they start as float32. When
update is called with bfloat16 grads, JAX type promotion causes the output updates to
be float32 instead of bfloat16. The cast back to param dtype should happen at the
apply_updates step, not inside the transform.

This is the gap identified in #1098 after #1060 was merged: #1060 ensured updates
match params dtype, but in mixed precision grads and params have different dtypes and
updates should track the grad dtype.

Fix

Each affected update_fn now casts the output updates to match the dtype of the
incoming gradients. Accumulator states are left in their computed dtype.

Transforms changed: scale_by_adam, scale_by_amsgrad, scale_by_adamax,
scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi,
scale_by_radam, scale_by_lion, scale_by_adan, scale_by_novograd.

Test

Added test_mixed_precision_dtype in transform_test.py. It gives each transform
float32 params and bfloat16 grads and checks that updates come back bfloat16.

All existing tests pass.

In mixed-precision training params are float32 while grads are bfloat16.
Optimizer accumulator states are initialized from params (float32), so
arithmetic with bfloat16 grads promotes the result to float32. The
cast back to param dtype should happen at the apply_updates step, not
inside the transform.

This adds an explicit cast of the output updates to match the input
grad dtype in scale_by_adam, scale_by_amsgrad, scale_by_adamax,
scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi,
scale_by_radam, scale_by_lion, scale_by_adan, and scale_by_novograd.

Adds a parameterized test that verifies each affected transform
produces bfloat16 updates when given float32 params and bfloat16 grads.

Fixes google-deepmind#1098
@rdyro

rdyro commented Jun 1, 2026

Copy link
Copy Markdown
Collaborator

I'm not entirely sure this represents best practice. I'm hesitant to cast gradients to a lower precision without the user doing it explicitly as it might lead to regret.

@nileshpatil6

Copy link
Copy Markdown
Author

Good point, and I agree the implicit downcast is the risky part here. My intent was narrower than the patch makes it look: the issue (#1098) is that even when a user is already running in mixed precision (float32 params, bfloat16 grads), the accumulators get initialized from the params dtype, so the returned updates silently come back as float32 instead of matching the incoming grads. The cast was meant to preserve the grad dtype the caller already chose, not to introduce a new downcast.

But you are right that cast_like(updates, grads) is too blunt: if grads happen to be lower precision than the user actually wants for the update, this hides that rather than surfacing it. So I would rather not land this as-is.

Two alternatives I think are safer:

  1. Only preserve dtype when it would otherwise be lost, i.e. initialize the moment accumulators from the grads dtype (mu/nu in scale_by_adam etc.) so the arithmetic never promotes in the first place, instead of casting at the end.
  2. Leave the default untouched and make this opt-in via an explicit dtype argument on the affected transforms.

Do you have a preference? Happy to rework toward (1), which keeps the existing behavior for the normal same-dtype case and just stops the unwanted promotion in the mixed-precision case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Updates dtype do not need to match params dtype but only grads dtype a priori

2 participants