fix(transform): cast updates to grad dtype for mixed-precision support#1687
fix(transform): cast updates to grad dtype for mixed-precision support#1687nileshpatil6 wants to merge 1 commit into
Conversation
In mixed-precision training params are float32 while grads are bfloat16. Optimizer accumulator states are initialized from params (float32), so arithmetic with bfloat16 grads promotes the result to float32. The cast back to param dtype should happen at the apply_updates step, not inside the transform. This adds an explicit cast of the output updates to match the input grad dtype in scale_by_adam, scale_by_amsgrad, scale_by_adamax, scale_by_rms, scale_by_stddev, scale_by_belief, scale_by_yogi, scale_by_radam, scale_by_lion, scale_by_adan, and scale_by_novograd. Adds a parameterized test that verifies each affected transform produces bfloat16 updates when given float32 params and bfloat16 grads. Fixes google-deepmind#1098
|
I'm not entirely sure this represents best practice. I'm hesitant to cast gradients to a lower precision without the user doing it explicitly as it might lead to regret. |
|
Good point, and I agree the implicit downcast is the risky part here. My intent was narrower than the patch makes it look: the issue (#1098) is that even when a user is already running in mixed precision (float32 params, bfloat16 grads), the accumulators get initialized from the params dtype, so the returned updates silently come back as float32 instead of matching the incoming grads. The cast was meant to preserve the grad dtype the caller already chose, not to introduce a new downcast. But you are right that cast_like(updates, grads) is too blunt: if grads happen to be lower precision than the user actually wants for the update, this hides that rather than surfacing it. So I would rather not land this as-is. Two alternatives I think are safer:
Do you have a preference? Happy to rework toward (1), which keeps the existing behavior for the normal same-dtype case and just stops the unwanted promotion in the mixed-precision case. |
Fixes #1098
Problem
In mixed-precision training, params are float32 while grads are bfloat16. Optimizer
accumulator states (mu, nu) are initialized from params so they start as float32. When
update is called with bfloat16 grads, JAX type promotion causes the output updates to
be float32 instead of bfloat16. The cast back to param dtype should happen at the
apply_updatesstep, not inside the transform.This is the gap identified in #1098 after #1060 was merged: #1060 ensured updates
match params dtype, but in mixed precision grads and params have different dtypes and
updates should track the grad dtype.
Fix
Each affected
update_fnnow casts the output updates to match the dtype of theincoming gradients. Accumulator states are left in their computed dtype.
Transforms changed:
scale_by_adam,scale_by_amsgrad,scale_by_adamax,scale_by_rms,scale_by_stddev,scale_by_belief,scale_by_yogi,scale_by_radam,scale_by_lion,scale_by_adan,scale_by_novograd.Test
Added
test_mixed_precision_dtypeintransform_test.py. It gives each transformfloat32 params and bfloat16 grads and checks that updates come back bfloat16.
All existing tests pass.