Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions optax/_src/factorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def _update(grad, v_row, v_col, v, param, step):
new_v_col = decay_rate_t * v_col + (1.0 - decay_rate_t) * jnp.mean(
grad_sqr, axis=d1
)
new_v_row = new_v_row.astype(dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
new_v_col = new_v_col.astype(dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
new_v_row = jnp.astype(new_v_row, dtype)
new_v_col = jnp.astype(new_v_col, dtype)
reduced_d1 = d1 - 1 if d1 > d0 else d1
row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
row_factor = (new_v_row / row_col_mean) ** -0.5
Expand All @@ -198,7 +198,7 @@ def _update(grad, v_row, v_col, v, param, step):
else:
grad_sqr = numerics.abs_sq(grad) + epsilon
new_v = decay_rate_t * v + (1.0 - decay_rate_t) * grad_sqr
new_v = new_v.astype(dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
new_v = jnp.astype(new_v, dtype)
update = grad * (new_v) ** -0.5

return _UpdateResult(update, new_v_row, new_v_col, new_v)
Expand Down
2 changes: 1 addition & 1 deletion optax/_src/lbfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def fun(x):
opt = alias.lbfgs()
sol_arr, _ = _run_opt(opt, fun, init_array, maxiter=3)
sol_tree, _ = _run_opt(opt, fun, init_tree, maxiter=3)
sol_tree = jnp.stack((sol_tree[0], sol_tree[1]))
sol_tree = jnp.asarray(sol_tree)
test_utils.assert_trees_all_close(
sol_arr, sol_tree, rtol=5 * 1e-5, atol=5 * 1e-5
)
Expand Down
6 changes: 2 additions & 4 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ def zakharov(x):
class BacktrackingLinesearchTest(parameterized.TestCase):

def _check_decrease_conditions(
self, fun, init_params, descent_dir, final_params, final_state, opt_args
self, fun, init_params, descent_dir, final_params, final_lr, opt_args
):
"""Check decrease conditions."""
init_value, init_grad = jax.value_and_grad(fun)(init_params)
final_value = fun(final_params)
final_lr = final_state[0]

slope = optax.tree.vdot(descent_dir, init_grad)
slope_rtol, atol, rtol = (
Expand Down Expand Up @@ -176,12 +175,11 @@ def test_linesearch(
params = update.apply_updates(init_params, updates)

self._check_decrease_conditions(
# pyrefly: ignore[bad-index]
fn,
init_params,
descent_dir,
params,
state[-1],
optax.tree.get(state, 'learning_rate'),
opt_args,
)

Expand Down
10 changes: 4 additions & 6 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _comb(g, m):
elif mode == 'refined':
# Keep small values linear, saturate to sign for large values.
return jnp.where(jnp.abs(x) < 1.0, x, jnp.sign(x))
else: # pytype: disable=unreachable
else:
raise ValueError(
f'Unknown lion mode: {mode}. '
'It needs to be one of ["hard", "smooth", "refined"].'
Expand Down Expand Up @@ -1552,14 +1552,14 @@ class ScaleByLBFGSState(NamedTuple):
updates: base.Params
diff_params_memory: base.ArrayTree
diff_updates_memory: base.ArrayTree
weights_memory: jax.typing.ArrayLike
weights_memory: jax.Array


def _precondition_by_lbfgs(
updates: base.Updates,
diff_params_memory: base.ArrayTree,
diff_updates_memory: base.ArrayTree,
weights_memory: jax.typing.ArrayLike,
weights_memory: jax.Array,
identity_scale: jax.typing.ArrayLike, # float
memory_idx: jax.typing.ArrayLike, # int
) -> base.Updates:
Expand Down Expand Up @@ -1596,14 +1596,13 @@ def _precondition_by_lbfgs(
, 1999
"""
rhos = weights_memory
memory_size = weights_memory.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501
memory_size = weights_memory.shape[0]
indices = (memory_idx + jnp.arange(memory_size)) % memory_size

def right_product(vec, idx):
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
# pyrefly: ignore[bad-index]
alpha = rhos[idx] * optax.tree.real(optax.tree.vdot(dwi, vec))
vec_new = optax.tree.add_scale(vec, -alpha, dui)
vec_new = optax.tree.cast_like(vec_new, vec)
Expand All @@ -1620,7 +1619,6 @@ def left_product(vec, idx_alpha):
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
# pyrefly: ignore[bad-index]
beta = rhos[idx] * optax.tree.real(optax.tree.vdot(dui, vec))
vec_new = optax.tree.add_scale(vec, alpha - beta, dwi)
vec_new = optax.tree.cast_like(vec_new, vec)
Expand Down
8 changes: 4 additions & 4 deletions optax/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array:
NxDxD tensor, with the same contents as `a` but with the diagonal
changed to `new_diags`.
"""
a_dim, new_diags_dim = len(a.shape), len(new_diags.shape) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
a_dim, new_diags_dim = jnp.ndim(a), jnp.ndim(new_diags)
if a_dim != 3:
raise ValueError(f'Expected `a` to be a 3D tensor, got {a_dim}D instead')
if new_diags_dim != 2:
raise ValueError(
f'Expected `new_diags` to be a 2D array, got {new_diags_dim}D instead'
)
n, d, d1 = a.shape
n_diags, d_diags = new_diags.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501
n, d, d1 = jnp.shape(a)
n_diags, d_diags = jnp.shape(new_diags)
if d != d1:
raise ValueError(
f'Shape mismatch: expected `a.shape` to be {(n, d, d)}, '
Expand All @@ -113,7 +113,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array:
indices3 = indices2

# Use numpy array setting
a = a.at[indices1, indices2, indices3].set(new_diags.flatten()) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
a = a.at[indices1, indices2, indices3].set(jnp.ravel(new_diags))
return a


Expand Down
3 changes: 2 additions & 1 deletion optax/contrib/_common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import functools
import inspect
from typing import Any, cast

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -352,7 +353,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
factory = getattr(contrib, wrapper_name)
factory = functools.partial(factory, base_opt)
hparams = wrapper_kwargs
opt = factory(**hparams) # pyrefly: ignore[bad-unpacking]
opt = factory(**cast(dict[str, Any], hparams))

# Add here the hyperparameters that cannot be injected with
# inject_hyperparams.
Expand Down
5 changes: 3 additions & 2 deletions optax/contrib/_dadapt_adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def update_fn(
numerator_weighted = state.numerator_weighted
count_inc = numerics.safe_increment(count)
bc = ((1 - beta2**count_inc) ** 0.5) / (1 - beta1**count_inc)
dlr = state.estim_lr * sched * bc
dlr = dlr.astype(numerator_weighted.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
dlr = jnp.astype(
state.estim_lr * sched * bc, jax.dtypes.result_type(numerator_weighted)
)
s_weighted = jax.tree.map(
lambda sk, eas: sk / (jnp.sqrt(eas) + eps), grad_sum, state.exp_avg_sq
)
Expand Down
31 changes: 12 additions & 19 deletions optax/contrib/_galore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp

from optax import tree
from optax._src import transform
from optax._src import update
from optax.contrib import _galore


def _tree_sum_squares(tree):
leaves = jax.tree.leaves(tree)
def _tree_sum_squares(pytree):
leaves = jax.tree.leaves(pytree)
return sum(jnp.sum(jnp.square(x.astype(jnp.float32))) for x in leaves)


Expand Down Expand Up @@ -332,22 +331,17 @@ def test_3d_memory_reduction_with_dimension_numbers(self):
weight_dimension_numbers=dim_nums,
)
state = opt.init(params)
# pyrefly: ignore[bad-index]
assert isinstance(state[0], _galore.GaLoreState)
# pytype: disable=annotation-type-mismatch
galore_state: _galore.GaLoreState = state[0]
# pytype: enable=annotation-type-mismatch
base_state = galore_state.base_optimizer_state

# Reshaped: (256, 512), m < n → right projection
# Moments should be (m, rank) = (256, 16)
# pytype: disable=attribute-error
self.assertEqual(base_state.mu["w"].shape, (embed_dim, rank))
self.assertEqual(base_state.nu["w"].shape, (embed_dim, rank))
mu = tree.get(state, "mu")
nu = tree.get(state, "nu")
self.assertEqual(mu["w"].shape, (embed_dim, rank))
self.assertEqual(nu["w"].shape, (embed_dim, rank))

# Verify memory savings
full_size = embed_dim * num_heads * head_dim
moment_size = base_state.mu["w"].size
moment_size = mu["w"].size
self.assertLess(moment_size, full_size)
# pytype: enable=attribute-error

Expand Down Expand Up @@ -437,15 +431,14 @@ def test_dimension_numbers_right_projection(self):
weight_dimension_numbers=dim_nums,
)
state = opt.init(params)
galore_state = state[0] # pyrefly: ignore[bad-index]
# pytype: disable=attribute-error
base_state = galore_state.base_optimizer_state

# m=32 < n=512 → right projection
# Projector: (n, rank) = (512, 8)
# Moments: (m, rank) = (32, 8)
self.assertEqual(galore_state.projector["w"].shape, (512, rank))
self.assertEqual(base_state.mu["w"].shape, (32, rank))
projector = tree.get(state, "projector")
mu = tree.get(state, "mu")
self.assertEqual(projector["w"].shape, (512, rank))
self.assertEqual(mu["w"].shape, (32, rank))
# pytype: enable=attribute-error

def test_single_dimension_number_applied_to_all(self):
Expand Down
2 changes: 1 addition & 1 deletion optax/contrib/_mechanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def add_weight_decay(gi, pi):
clipped_h = jax.lax.clamp(-state.m, jnp.ones_like(state.m) * h, state.m)
betas = jnp.array(
[1.0 - 0.1**betai for betai in range(1, num_betas + 1)],
dtype=state.s.dtype, # pytype: disable=attribute-error # jax-arraylike
dtype=jax.dtypes.result_type(state.s),
)

m = jnp.maximum(betas * state.m, jnp.abs(h) + eps)
Expand Down
8 changes: 4 additions & 4 deletions optax/contrib/_momo.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def update_fn(
# initialize at first gradient, and loss
bt = jnp.where(count == 0, 0.0, beta)
barf = bt * state.barf + (1 - bt) * jnp.asarray(
value, dtype=state.barf.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501
value, dtype=jax.dtypes.result_type(state.barf)
)
exp_avg = jax.tree.map(
lambda ea, g: bt * ea + (1 - bt) * g, state.exp_avg, updates
Expand Down Expand Up @@ -297,7 +297,7 @@ def update_fn(
count = state.count
count_inc = numerics.safe_increment(count)
barf = b1 * state.barf + (1 - b1) * jnp.asarray(
value, dtype=state.barf.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501
value, dtype=jax.dtypes.result_type(state.barf)
)
exp_avg = jax.tree.map(
lambda ea, g: b1 * ea + (1 - b1) * g, state.exp_avg, updates
Expand All @@ -307,7 +307,7 @@ def update_fn(
state.exp_avg_sq,
updates,
)
bc2 = jnp.asarray(1 - b2**count_inc, dtype=barf.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
bc2 = jnp.asarray(1 - b2**count_inc, dtype=jax.dtypes.result_type(barf))
precond = jax.tree.map(lambda eas: eps + jnp.sqrt(eas / bc2), exp_avg_sq)
exp_avg_weighted = jax.tree.map(
lambda ea, prec: ea / prec, exp_avg, precond
Expand All @@ -316,7 +316,7 @@ def update_fn(
gamma = b1 * state.gamma + (1 - b1) * optax.tree.vdot(updates, params)
iprod = optax.tree.vdot(exp_avg, params)
alpha = learning_rate(count) if callable(learning_rate) else learning_rate
bc1 = jnp.asarray(1 - b1**count_inc, dtype=barf.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
bc1 = jnp.asarray(1 - b1**count_inc, dtype=jax.dtypes.result_type(barf))
# Reset lower bound
if adapt_lower_bound:
cap = (1 + alpha * weight_decay) * (barf - gamma) + iprod
Expand Down
18 changes: 9 additions & 9 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ class MuonDimensionNumbers(NamedTuple):
def _normalize_axes(x: jax.Array, dim_nums: MuonDimensionNumbers) -> tuple[
tuple[int, ...], tuple[int, ...]]:
"""Normalize axes in dimension numbers to two tuples of non-negative ints."""
if isinstance(dim_nums.reduction_axis, int):
dim_nums = dim_nums._replace(reduction_axis=(dim_nums.reduction_axis,))
# pyrefly: ignore[not-iterable]
reduction_axes = tuple(ax % x.ndim for ax in dim_nums.reduction_axis)

if isinstance(dim_nums.output_axis, int):
dim_nums = dim_nums._replace(output_axis=(dim_nums.output_axis,))
# pyrefly: ignore[not-iterable]
output_axes = tuple(ax % x.ndim for ax in dim_nums.output_axis)
reduction_axis = dim_nums.reduction_axis
if isinstance(reduction_axis, int):
reduction_axis = (reduction_axis,)
reduction_axes = tuple(ax % x.ndim for ax in reduction_axis)

output_axis = dim_nums.output_axis
if isinstance(output_axis, int):
output_axis = (output_axis,)
output_axes = tuple(ax % x.ndim for ax in output_axis)
return reduction_axes, output_axes


Expand Down
9 changes: 4 additions & 5 deletions optax/contrib/_muon_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,15 +368,14 @@ def test_muon_orthogonalization_modes(self, preconditioning, shape):
preconditioning=preconditioning)
state = opt.init(params)
updates, _ = opt.update(params, state, params=params)
w_update = jax.tree.leaves(updates)[0]

# Check shape preservation
# pyrefly: ignore[bad-index, missing-attribute]
self.assertEqual(updates['w'].shape, shape)
self.assertEqual(w_update.shape, shape)

# Check Near-Orthogonality (Spectral Norm Constraint)
if shape[0] == shape[1]:
# pyrefly: ignore[bad-index, no-matching-overload]
s = jnp.linalg.svd(updates['w'], compute_uv=False)
s = jnp.linalg.svd(w_update, compute_uv=False)
max_s = jnp.max(s)
min_s = jnp.min(s)
self.assertLess(max_s, 2.0, msg=f'Max singular value {max_s} too high')
Expand Down Expand Up @@ -412,7 +411,7 @@ def test_orthogonality(self, preconditioning):
params = {'w': jnp.eye(8) * 2.0}
opt = _muon.muon(learning_rate=0.1, preconditioning=preconditioning)
updates, _ = opt.update(params, opt.init(params), params)
w_update = updates['w'] # pyrefly: ignore[bad-index]
w_update = jax.tree.leaves(updates)[0]

for leaf in jax.tree_util.tree_leaves(updates):
self.assertFalse(jnp.isnan(leaf).any(), 'Found NaN values in updates')
Expand Down
4 changes: 3 additions & 1 deletion optax/contrib/_prodigy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def update_fn(
estim_lr = state.estim_lr
numerator_weighted = state.numerator_weighted
bc = ((1 - beta2**count_inc) ** 0.5) / (1 - beta1**count_inc)
dlr = jnp.asarray(estim_lr * sched * bc, dtype=estim_lr.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
dlr = jnp.asarray(
estim_lr * sched * bc, dtype=jax.dtypes.result_type(estim_lr)
)
dg = jax.tree.map(lambda g: estim_lr * g, updates)
param_diff = jax.tree.map(lambda p0, p: p0 - p, params0, params)
numerator_acum = optax.tree.vdot(updates, param_diff)
Expand Down
18 changes: 13 additions & 5 deletions optax/contrib/_reduce_on_plateau.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def _update_scale(state):
avg_value < (1 - rtol) * state.best_value - atol, 1, 0
)
new_best_value: jax.Array = jnp.where(
has_improved, avg_value.astype(state.best_value.dtype), state.best_value
has_improved,
jnp.astype(avg_value, jax.dtypes.result_type(state.best_value)),
state.best_value,
)
curr_plateau_count = jnp.where(
has_improved, 0, numerics.safe_increment(state.plateau_count)
Expand Down Expand Up @@ -154,11 +156,13 @@ def not_in_cooldown():
)
new_state = ReduceLROnPlateauState(
plateau_count=new_plateau_count,
best_value=new_best_value.astype(state.best_value.dtype),
best_value=jnp.astype(
new_best_value, jax.dtypes.result_type(state.best_value)
),
scale=new_scale,
cooldown_count=new_cooldown_count,
count=jnp.asarray(0, dtype=jnp.int32),
avg_value=jnp.asarray(0.0, avg_value.dtype),
avg_value=jnp.asarray(0.0, jax.dtypes.result_type(avg_value)),
)
return new_state

Expand All @@ -175,10 +179,14 @@ def update_fn(
count = state.count
new_count = numerics.safe_increment(count)
new_avg_value = (
count * state.avg_value + jnp.astype(value, state.avg_value.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
count * state.avg_value
+ jnp.astype(value, jax.dtypes.result_type(state.avg_value))
) / new_count
new_state = state._replace(
avg_value=new_avg_value.astype(state.avg_value.dtype), count=new_count # pytype: disable=attribute-error # jax-arraylike # noqa: E501
avg_value=jnp.astype(
new_avg_value, jax.dtypes.result_type(state.avg_value)
),
count=new_count,
)

new_state = jax.lax.cond(
Expand Down
Loading
Loading