Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions optax/contrib/_prodigy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
Konstantin Mishchenko and Aaron Defazio. A new variant of D-Adapt Adam that
adapts the learning rate faster.
"""
from typing import NamedTuple, Optional
from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Union

import jax
import jax.numpy as jnp
from optax._src import base
Expand Down Expand Up @@ -51,6 +53,8 @@ def prodigy(
estim_lr_coef: jax.typing.ArrayLike = 1.0,
weight_decay: jax.typing.ArrayLike = 0.0,
safeguard_warmup: bool = False,
weight_decay_mask: Optional[
Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformationExtraArgs:
"""Learning rate free AdamW with Prodigy.

Expand All @@ -74,6 +78,11 @@ def prodigy(
with add_decayed_weights.
safeguard_warmup: Remove lr from the denominator of D estimate to avoid
issues during warm-up stage. Off by default.
weight_decay_mask: A tree with same structure as (or a prefix of) the params
PyTree, or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, ``True`` for leaves/subtrees you want to
apply the weight decay to, and ``False`` for those you want to skip. Note
that the Adam gradient transformations are applied to all parameters.

Returns:
A :class:`optax.GradientTransformation` object.
Expand Down Expand Up @@ -153,14 +162,38 @@ def update_fn(
denominator = optax.tree.sum(jax.tree.map(jnp.abs, grad_sum))
lr_estimate = estim_lr_coef * numerator_weighted / denominator
estim_lr = jnp.maximum(state.estim_lr, lr_estimate)

p_update = jax.tree.map(
# pyrefly: ignore[unsupported-operation]
lambda ea, eas, p: -weight_decay * dlr * p
- dlr * ea / (jnp.sqrt(eas) + estim_lr * eps),
lambda ea, eas: -dlr * ea / (jnp.sqrt(eas) + estim_lr * eps),
exp_avg,
exp_avg_sq,
params,
)

# Resolve weight decay mask.
if weight_decay_mask is not None:
# pyrefly: ignore[not-callable]
mask_tree = (
weight_decay_mask(params)
if callable(weight_decay_mask)
else weight_decay_mask
)
p_update = jax.tree.map(
lambda u, p, m: jnp.where(
# pyrefly: ignore[unsupported-operation]
m, u - weight_decay * dlr * p, u
),
p_update,
params,
mask_tree,
)
else:
p_update = jax.tree.map(
# pyrefly: ignore[unsupported-operation]
lambda u, p: u - weight_decay * dlr * p,
p_update,
params,
)

new_state = ProdigyState(
exp_avg,
exp_avg_sq,
Expand Down
Loading