Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optax/_src/lbfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def fun(x):
opt = alias.lbfgs()
sol_arr, _ = _run_opt(opt, fun, init_array, maxiter=3)
sol_tree, _ = _run_opt(opt, fun, init_tree, maxiter=3)
sol_tree = jnp.stack((sol_tree[0], sol_tree[1]))
sol_tree = jnp.asarray(sol_tree)
test_utils.assert_trees_all_close(
sol_arr, sol_tree, rtol=5 * 1e-5, atol=5 * 1e-5
)
Expand Down
6 changes: 2 additions & 4 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ def zakharov(x):
class BacktrackingLinesearchTest(parameterized.TestCase):

def _check_decrease_conditions(
self, fun, init_params, descent_dir, final_params, final_state, opt_args
self, fun, init_params, descent_dir, final_params, final_lr, opt_args
):
"""Check decrease conditions."""
init_value, init_grad = jax.value_and_grad(fun)(init_params)
final_value = fun(final_params)
final_lr = final_state[0]

slope = optax.tree.vdot(descent_dir, init_grad)
slope_rtol, atol, rtol = (
Expand Down Expand Up @@ -176,12 +175,11 @@ def test_linesearch(
params = update.apply_updates(init_params, updates)

self._check_decrease_conditions(
# pyrefly: ignore[bad-index]
fn,
init_params,
descent_dir,
params,
state[-1],
optax.tree.get(state, 'learning_rate'),
opt_args,
)

Expand Down
10 changes: 4 additions & 6 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _comb(g, m):
elif mode == 'refined':
# Keep small values linear, saturate to sign for large values.
return jnp.where(jnp.abs(x) < 1.0, x, jnp.sign(x))
else: # pytype: disable=unreachable
else:
raise ValueError(
f'Unknown lion mode: {mode}. '
'It needs to be one of ["hard", "smooth", "refined"].'
Expand Down Expand Up @@ -1552,14 +1552,14 @@ class ScaleByLBFGSState(NamedTuple):
updates: base.Params
diff_params_memory: base.ArrayTree
diff_updates_memory: base.ArrayTree
weights_memory: jax.typing.ArrayLike
weights_memory: jax.Array


def _precondition_by_lbfgs(
updates: base.Updates,
diff_params_memory: base.ArrayTree,
diff_updates_memory: base.ArrayTree,
weights_memory: jax.typing.ArrayLike,
weights_memory: jax.Array,
identity_scale: jax.typing.ArrayLike, # float
memory_idx: jax.typing.ArrayLike, # int
) -> base.Updates:
Expand Down Expand Up @@ -1596,14 +1596,13 @@ def _precondition_by_lbfgs(
, 1999
"""
rhos = weights_memory
memory_size = weights_memory.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501
memory_size = weights_memory.shape[0]
indices = (memory_idx + jnp.arange(memory_size)) % memory_size

def right_product(vec, idx):
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
# pyrefly: ignore[bad-index]
alpha = rhos[idx] * optax.tree.real(optax.tree.vdot(dwi, vec))
vec_new = optax.tree.add_scale(vec, -alpha, dui)
vec_new = optax.tree.cast_like(vec_new, vec)
Expand All @@ -1620,7 +1619,6 @@ def left_product(vec, idx_alpha):
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
# pyrefly: ignore[bad-index]
beta = rhos[idx] * optax.tree.real(optax.tree.vdot(dui, vec))
vec_new = optax.tree.add_scale(vec, alpha - beta, dwi)
vec_new = optax.tree.cast_like(vec_new, vec)
Expand Down
3 changes: 2 additions & 1 deletion optax/contrib/_common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import functools
import inspect
from typing import Any, cast

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -352,7 +353,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
factory = getattr(contrib, wrapper_name)
factory = functools.partial(factory, base_opt)
hparams = wrapper_kwargs
opt = factory(**hparams) # pyrefly: ignore[bad-unpacking]
opt = factory(**cast(dict[str, Any], hparams))

# Add here the hyperparameters that cannot be injected with
# inject_hyperparams.
Expand Down
31 changes: 12 additions & 19 deletions optax/contrib/_galore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp

from optax import tree
from optax._src import transform
from optax._src import update
from optax.contrib import _galore


def _tree_sum_squares(tree):
leaves = jax.tree.leaves(tree)
def _tree_sum_squares(pytree):
leaves = jax.tree.leaves(pytree)
return sum(jnp.sum(jnp.square(x.astype(jnp.float32))) for x in leaves)


Expand Down Expand Up @@ -332,22 +331,17 @@ def test_3d_memory_reduction_with_dimension_numbers(self):
weight_dimension_numbers=dim_nums,
)
state = opt.init(params)
# pyrefly: ignore[bad-index]
assert isinstance(state[0], _galore.GaLoreState)
# pytype: disable=annotation-type-mismatch
galore_state: _galore.GaLoreState = state[0]
# pytype: enable=annotation-type-mismatch
base_state = galore_state.base_optimizer_state

# Reshaped: (256, 512), m < n → right projection
# Moments should be (m, rank) = (256, 16)
# pytype: disable=attribute-error
self.assertEqual(base_state.mu["w"].shape, (embed_dim, rank))
self.assertEqual(base_state.nu["w"].shape, (embed_dim, rank))
mu = tree.get(state, "mu")
nu = tree.get(state, "nu")
self.assertEqual(mu["w"].shape, (embed_dim, rank))
self.assertEqual(nu["w"].shape, (embed_dim, rank))

# Verify memory savings
full_size = embed_dim * num_heads * head_dim
moment_size = base_state.mu["w"].size
moment_size = mu["w"].size
self.assertLess(moment_size, full_size)
# pytype: enable=attribute-error

Expand Down Expand Up @@ -437,15 +431,14 @@ def test_dimension_numbers_right_projection(self):
weight_dimension_numbers=dim_nums,
)
state = opt.init(params)
galore_state = state[0] # pyrefly: ignore[bad-index]
# pytype: disable=attribute-error
base_state = galore_state.base_optimizer_state

# m=32 < n=512 → right projection
# Projector: (n, rank) = (512, 8)
# Moments: (m, rank) = (32, 8)
self.assertEqual(galore_state.projector["w"].shape, (512, rank))
self.assertEqual(base_state.mu["w"].shape, (32, rank))
projector = tree.get(state, "projector")
mu = tree.get(state, "mu")
self.assertEqual(projector["w"].shape, (512, rank))
self.assertEqual(mu["w"].shape, (32, rank))
# pytype: enable=attribute-error

def test_single_dimension_number_applied_to_all(self):
Expand Down
18 changes: 9 additions & 9 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ class MuonDimensionNumbers(NamedTuple):
def _normalize_axes(x: jax.Array, dim_nums: MuonDimensionNumbers) -> tuple[
    tuple[int, ...], tuple[int, ...]]:
  """Normalize axes in dimension numbers to two tuples of non-negative ints.

  Args:
    x: Array whose `ndim` is used to wrap negative axis indices.
    dim_nums: Dimension numbers whose `reduction_axis` and `output_axis`
      fields are each either a single int or an iterable of ints.

  Returns:
    A `(reduction_axes, output_axes)` pair of tuples in which every axis has
    been taken modulo `x.ndim`.
  """
  # Normalize local copies rather than rebuilding `dim_nums` with `_replace`:
  # after the `isinstance` branch the local is always iterable, so the
  # comprehension needs no type-checker suppression comment.
  reduction_axis = dim_nums.reduction_axis
  if isinstance(reduction_axis, int):
    reduction_axis = (reduction_axis,)
  reduction_axes = tuple(ax % x.ndim for ax in reduction_axis)

  output_axis = dim_nums.output_axis
  if isinstance(output_axis, int):
    output_axis = (output_axis,)
  output_axes = tuple(ax % x.ndim for ax in output_axis)
  return reduction_axes, output_axes


Expand Down
9 changes: 4 additions & 5 deletions optax/contrib/_muon_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,15 +368,14 @@ def test_muon_orthogonalization_modes(self, preconditioning, shape):
preconditioning=preconditioning)
state = opt.init(params)
updates, _ = opt.update(params, state, params=params)
w_update = jax.tree.leaves(updates)[0]

# Check shape preservation
# pyrefly: ignore[bad-index, missing-attribute]
self.assertEqual(updates['w'].shape, shape)
self.assertEqual(w_update.shape, shape)

# Check Near-Orthogonality (Spectral Norm Constraint)
if shape[0] == shape[1]:
# pyrefly: ignore[bad-index, no-matching-overload]
s = jnp.linalg.svd(updates['w'], compute_uv=False)
s = jnp.linalg.svd(w_update, compute_uv=False)
max_s = jnp.max(s)
min_s = jnp.min(s)
self.assertLess(max_s, 2.0, msg=f'Max singular value {max_s} too high')
Expand Down Expand Up @@ -412,7 +411,7 @@ def test_orthogonality(self, preconditioning):
params = {'w': jnp.eye(8) * 2.0}
opt = _muon.muon(learning_rate=0.1, preconditioning=preconditioning)
updates, _ = opt.update(params, opt.init(params), params)
w_update = updates['w'] # pyrefly: ignore[bad-index]
w_update = jax.tree.leaves(updates)[0]

for leaf in jax.tree_util.tree_leaves(updates):
self.assertFalse(jnp.isnan(leaf).any(), 'Found NaN values in updates')
Expand Down
30 changes: 20 additions & 10 deletions optax/contrib/_reduce_on_plateau_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from optax import tree
from optax._src import test_utils
from optax.contrib import _reduce_on_plateau

Expand Down Expand Up @@ -64,8 +65,10 @@ def test_learning_rate_reduced_after_cooldown_period_is_over(
)

# Check that learning rate is reduced
# pyrefly: ignore[not-iterable]
scale, best_value, plateau_count, cooldown_count, *_ = state
scale = tree.get(state, 'scale')
best_value = tree.get(state, 'best_value')
plateau_count = tree.get(state, 'plateau_count')
cooldown_count = tree.get(state, 'cooldown_count')
test_utils.assert_trees_all_close(scale, 0.1)
test_utils.assert_trees_all_close(best_value, 1.0)
test_utils.assert_trees_all_close(plateau_count, 0)
Expand All @@ -78,8 +81,10 @@ def test_learning_rate_reduced_after_cooldown_period_is_over(
)

# Check that cooldown_count is decremented
# pyrefly: ignore[not-iterable]
scale, best_value, plateau_count, cooldown_count, *_ = state
scale = tree.get(state, 'scale')
best_value = tree.get(state, 'best_value')
plateau_count = tree.get(state, 'plateau_count')
cooldown_count = tree.get(state, 'cooldown_count')
test_utils.assert_trees_all_close(scale, 0.1)
test_utils.assert_trees_all_close(best_value, 1.0)
test_utils.assert_trees_all_close(plateau_count, 0)
Expand Down Expand Up @@ -108,8 +113,9 @@ def test_learning_rate_is_not_reduced(self, enable_x64):
)

# Check that plateau_count resets
# pyrefly: ignore[not-iterable]
scale, best_value, plateau_count, *_ = new_state
scale = tree.get(new_state, 'scale')
best_value = tree.get(new_state, 'best_value')
plateau_count = tree.get(new_state, 'plateau_count')
test_utils.assert_trees_all_close(plateau_count, 0)
test_utils.assert_trees_all_close(scale, 0.1)
test_utils.assert_trees_all_close(best_value, 0.1)
Expand Down Expand Up @@ -138,8 +144,10 @@ def test_learning_rate_not_reduced_during_cooldown(self, enable_x64):

# Check that learning rate is not reduced and
# plateau_count is not incremented
# pyrefly: ignore[not-iterable]
scale, best_value, plateau_count, cooldown_count, *_ = new_state
scale = tree.get(new_state, 'scale')
best_value = tree.get(new_state, 'best_value')
plateau_count = tree.get(new_state, 'plateau_count')
cooldown_count = tree.get(new_state, 'cooldown_count')
test_utils.assert_trees_all_close(scale, 0.1)
test_utils.assert_trees_all_close(best_value, 1.0)
test_utils.assert_trees_all_close(plateau_count, 0)
Expand Down Expand Up @@ -174,8 +182,10 @@ def test_learning_rate_not_reduced_after_end_scale_is_reached(
)

# Check that learning rate is not reduced
# pyrefly: ignore[not-iterable]
scale, best_value, plateau_count, cooldown_count, *_ = state
scale = tree.get(state, 'scale')
best_value = tree.get(state, 'best_value')
plateau_count = tree.get(state, 'plateau_count')
cooldown_count = tree.get(state, 'cooldown_count')
test_utils.assert_trees_all_close(scale, 0.01)
test_utils.assert_trees_all_close(best_value, 0.1)
test_utils.assert_trees_all_close(plateau_count, 0)
Expand Down
Loading
Loading