From cd65f389d5e6f98e50d515a10f6109e4a2d16e31 Mon Sep 17 00:00:00 2001 From: James Martens Date: Tue, 23 Jun 2026 16:25:43 -0700 Subject: [PATCH] Adding mask for weight decay in Prodigy-AdamW implementation PiperOrigin-RevId: 936958302 --- optax/contrib/_prodigy.py | 43 ++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/optax/contrib/_prodigy.py b/optax/contrib/_prodigy.py index d7d7a88ad..94482f369 100644 --- a/optax/contrib/_prodigy.py +++ b/optax/contrib/_prodigy.py @@ -19,7 +19,9 @@ Konstantin Mishchenko and Aaron Defazio. A new variant of D-Adapt Adam that adapts the learning rate faster. """ -from typing import NamedTuple, Optional +from collections.abc import Callable +from typing import Any, NamedTuple, Optional, Union + import jax import jax.numpy as jnp from optax._src import base @@ -51,6 +53,8 @@ def prodigy( estim_lr_coef: jax.typing.ArrayLike = 1.0, weight_decay: jax.typing.ArrayLike = 0.0, safeguard_warmup: bool = False, + weight_decay_mask: Optional[ + Union[Any, Callable[[base.Params], Any]]] = None, ) -> base.GradientTransformationExtraArgs: """Learning rate free AdamW with Prodigy. @@ -74,6 +78,11 @@ def prodigy( with add_decayed_weights. safeguard_warmup: Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + weight_decay_mask: A tree with the same structure as the params PyTree + (prefix trees are not supported), or a Callable that returns such a + pytree given the params. The leaves should be booleans, ``True`` for + leaves you want to apply the weight decay to, and ``False`` for those you + want to skip. The Adam gradient transformations apply to all parameters. Returns: A :class:`optax.GradientTransformation` object. 
@@ -153,14 +162,38 @@ def update_fn( denominator = optax.tree.sum(jax.tree.map(jnp.abs, grad_sum)) lr_estimate = estim_lr_coef * numerator_weighted / denominator estim_lr = jnp.maximum(state.estim_lr, lr_estimate) + p_update = jax.tree.map( - # pyrefly: ignore[unsupported-operation] - lambda ea, eas, p: -weight_decay * dlr * p - - dlr * ea / (jnp.sqrt(eas) + estim_lr * eps), + lambda ea, eas: -dlr * ea / (jnp.sqrt(eas) + estim_lr * eps), exp_avg, exp_avg_sq, - params, ) + + # Resolve weight decay mask. + if weight_decay_mask is not None: + # pyrefly: ignore[not-callable] + mask_tree = ( + weight_decay_mask(params) + if callable(weight_decay_mask) + else weight_decay_mask + ) + p_update = jax.tree.map( + lambda u, p, m: jnp.where( + # pyrefly: ignore[unsupported-operation] + m, u - weight_decay * dlr * p, u + ), + p_update, + params, + mask_tree, + ) + else: + p_update = jax.tree.map( + # pyrefly: ignore[unsupported-operation] + lambda u, p: u - weight_decay * dlr * p, + p_update, + params, + ) + new_state = ProdigyState( exp_avg, exp_avg_sq,