Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, decay, 2)
if bias_correction:
count_inc = numerics.safe_increment(state.count)
Expand All @@ -141,6 +142,7 @@ def update_fn(updates, state, params=None):
else:
scaling = jax.tree.map(lambda n: 1 / (jnp.sqrt(n) + eps), nu_hat)
updates = jax.tree.map(lambda s, g: s * g, scaling, updates)
updates = optax.tree.cast_like(updates, grads)
if bias_correction:
new_state = ScaleByRmsWithCountState(count=count_inc, nu=nu)
else:
Expand Down Expand Up @@ -202,6 +204,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates
mu = optax.tree.update_moment(updates, state.mu, decay, 1)
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, decay, 2)
if bias_correction:
Expand All @@ -226,6 +229,7 @@ def update_fn(updates, state, params=None):
nu_hat,
)
updates = jax.tree.map(lambda s, g: s * g, scaling, updates)
updates = optax.tree.cast_like(updates, grads)
if bias_correction:
new_state = ScaleByRStdDevWithCountState(count=count_inc, mu=mu, nu=nu)
else:
Expand Down Expand Up @@ -280,6 +284,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates # keep a reference to the original grads to cast back later
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = numerics.safe_increment(state.count)
Expand All @@ -302,6 +307,7 @@ def update_fn(updates, state, params=None):
nu_hat,
is_leaf=lambda x: x is None,
)
updates = optax.tree.cast_like(updates, grads)
mu = optax.tree.cast(mu, mu_dtype)
nu = optax.tree.cast_like(nu, state.nu)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
Expand Down Expand Up @@ -361,6 +367,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = numerics.safe_increment(state.count)
Expand All @@ -382,6 +389,7 @@ def update_fn(updates, state, params=None):
nu_max,
is_leaf=lambda x: x is None,
)
updates = optax.tree.cast_like(updates, grads)
mu = optax.tree.cast(mu, mu_dtype)
return updates, ScaleByAmsgradState(
count=count_inc, mu=mu, nu=nu, nu_max=nu_max
Expand Down Expand Up @@ -415,12 +423,14 @@ def init_fn(params):

def update_fn(updates, state, params=None):
  """Scale the incoming gradients by a bias-corrected mean over an infinity moment.

  Args:
    updates: tree of incoming gradients for this step.
    state: previous state; its ``count``, ``mu`` and ``nu`` fields are read.
    params: unused, discarded immediately.

  Returns:
    A ``(updates, new_state)`` pair. The updates are passed through
    ``optax.tree.cast_like`` against the incoming gradients, so they are
    expected to come back in the gradients' dtype rather than the
    accumulators' dtype.
  """
  del params
  # Hold on to the untouched gradients: they are the dtype reference for the
  # final cast.
  incoming = updates
  step = numerics.safe_increment(state.count)
  first_moment = optax.tree.update_moment(incoming, state.mu, b1, 1)
  inf_moment = optax.tree.update_infinity_moment(incoming, state.nu, b2, eps)
  # Only the mean is bias-corrected; the infinity moment is used as is.
  corrected_mean = optax.tree.bias_correction(first_moment, b1, step)
  scaled = jax.tree.map(lambda m, v: m / v, corrected_mean, inf_moment)
  scaled = optax.tree.cast_like(scaled, incoming)
  return scaled, ScaleByAdamState(count=step, mu=first_moment, nu=inf_moment)

return base.GradientTransformation(init_fn, update_fn)
Expand Down Expand Up @@ -467,6 +477,7 @@ def init_fn(params):
def update_fn(updates, state, params=None):

del params
grads = updates

def _comb(g, m):
x = (1.0 - b1) * g + b1 * m
Expand All @@ -484,6 +495,7 @@ def _comb(g, m):
)

updates_new = jax.tree.map(_comb, updates, state.mu)
updates_new = optax.tree.cast_like(updates_new, grads)
mu = optax.tree.update_moment(updates, state.mu, b2, 1)
mu = optax.tree.cast(mu, mu_dtype)
count_inc = numerics.safe_increment(state.count)
Expand Down Expand Up @@ -655,6 +667,7 @@ def init_fn(params):
def update_fn(updates, state, params=None):
"""Based on Algorithm 1 in https://arxiv.org/pdf/2208.06677v4#page=6."""
del params
grads = updates
g = updates

diff = optax.tree.where(
Expand All @@ -676,6 +689,7 @@ def update_fn(updates, state, params=None):
u = optax.tree.add_scale(m_hat, 1 - b2, v_hat)
denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat)
u = optax.tree.div(u, denom)
u = optax.tree.cast_like(u, grads)

new_state = ScaleByAdanState(
m=m,
Expand Down Expand Up @@ -730,6 +744,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
prediction_error = optax.tree.sub(updates, mu)
nu = optax.tree.update_moment_per_elem_norm(prediction_error, state.nu, b2,
Expand All @@ -751,6 +766,7 @@ def update_fn(updates, state, params=None):
nu_hat,
is_leaf=lambda x: x is None,
)
updates = optax.tree.cast_like(updates, grads)
return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu)

return base.GradientTransformation(init_fn, update_fn)
Expand Down Expand Up @@ -792,6 +808,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
nu = jax.tree.map(
lambda g, v: v - (1 - b2) * jnp.sign(v - abs_sq(g)) * abs_sq(g),
Expand All @@ -807,6 +824,7 @@ def update_fn(updates, state, params=None):
nu_hat,
is_leaf=lambda x: x is None,
)
updates = optax.tree.cast_like(updates, grads)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

return base.GradientTransformation(init_fn, update_fn)
Expand Down Expand Up @@ -861,6 +879,7 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
grads = updates
mu = optax.tree.update_moment(updates, state.mu, b1, 1)
nu = optax.tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = numerics.safe_increment(state.count)
Expand All @@ -882,6 +901,7 @@ def update_fn(updates, state, params=None):
_radam_update(ro, mu_hat, nu_hat),
mu_hat,
)
updates = optax.tree.cast_like(updates, grads)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

return base.GradientTransformation(init_fn, update_fn)
Expand Down Expand Up @@ -1318,6 +1338,7 @@ def update_mu(grads, params, mu, nu):
)

def update_fn(updates, state, params):
grads = updates
count_inc = numerics.safe_increment(state.count)

nu = jax.lax.cond(count_inc == 1, init_nu, update_nu, updates, state.nu)
Expand All @@ -1327,6 +1348,7 @@ def update_fn(updates, state, params):

mu = optax.tree.cast(mu, mu_dtype)
updates = mu
updates = optax.tree.cast_like(updates, grads)
return updates, ScaleByNovogradState(count=count_inc, mu=mu, nu=nu)

# pyrefly: ignore[bad-argument-type]
Expand Down
30 changes: 30 additions & 0 deletions optax/_src/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,36 @@ def test_scale_by_polyak_l1_norm(self, tol=1e-10):
print(grad, value, updates)
self.assertLess(objective(init_params - updates), tol)

@parameterized.named_parameters([
    (name, getattr(transform, name), {})
    for name in (
        'scale_by_adam',
        'scale_by_amsgrad',
        'scale_by_adamax',
        'scale_by_rms',
        'scale_by_stddev',
        'scale_by_belief',
        'scale_by_yogi',
        'scale_by_radam',
        'scale_by_lion',
        'scale_by_adan',
        'scale_by_novograd',
    )
])
def test_mixed_precision_dtype(self, transform_constr, transform_kwargs):
  """Check that each transform returns updates in the dtype of the grads.

  The setup mimics mixed-precision training: float32 params paired with
  bfloat16 grads. The transform is initialized from the params and then fed
  the lower-precision grads; the resulting updates must keep the grads'
  dtype instead of being promoted to the params' dtype. Casting to the
  param dtype is meant to happen when updates are applied, not inside the
  transform.

  See https://github.com/google-deepmind/optax/issues/1098.
  """
  params = jnp.array([1.0, 2.0], dtype=jnp.float32)
  grads = jnp.array([2.0, 4.0], dtype=jnp.bfloat16)
  optimizer = transform_constr(**transform_kwargs)
  opt_state = optimizer.init(params)
  # Only the updates matter here; the new optimizer state is discarded.
  updates, _ = optimizer.update(grads, opt_state, params)
  self.assertEqual(updates.dtype, grads.dtype)

def test_rms_match_adam(self):
"""Test scale_by_rms add_eps_in_sqrt=False matches scale_by_adam(b1=0)."""
fun = lambda x: optax.tree.norm(x, squared=True)
Expand Down
Loading