Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def update_fn(updates, state, params=None):
# Populate weight_dim_nums if it's a callable. Use updates instead of
# actual params since only shapes matter and params may not be provided.
resolved_weight_dim_nums = weight_dimension_numbers(updates)
elif _masking._mask_callable(weight_dimension_numbers):
resolved_weight_dim_nums = jax.tree.map(
lambda fn: fn(updates), weight_dimension_numbers
)
else:
resolved_weight_dim_nums = weight_dimension_numbers

Expand Down Expand Up @@ -475,11 +479,14 @@ def init_fn(params):

def update_fn(updates, state, params=None):
del params
# TODO(rdyro): extend to _masking._mask_callable
if callable(weight_dimension_numbers):
# Populate weight_dim_nums if it's a callable. Use updates instead of
# actual params since only shapes matter and params may not be provided.
resolved_weight_dim_nums = weight_dimension_numbers(updates)
elif _masking._mask_callable(weight_dimension_numbers):
resolved_weight_dim_nums = jax.tree.map(
lambda fn: fn(updates), weight_dimension_numbers
)
else:
resolved_weight_dim_nums = weight_dimension_numbers

Expand Down
35 changes: 35 additions & 0 deletions optax/contrib/_muon_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,41 @@ def weight_dim_nums_fn(params): # pylint: disable=function-redefined
state = opt.init(params)
_, _ = opt.update(params, state, params=params)

@parameterized.named_parameters(
('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten')
)
def test_mask_callable_weight_dim_nums(self, preconditioning):
"""Test weight_dimension_numbers as a tree of callables."""
params = {'w1': jnp.ones((10, 10)), 'w2': jnp.ones((2, 10))}

# A tree where each leaf is a callable returning MuonDimensionNumbers.
weight_dim_nums_tree = {
'w1': lambda updates: _muon.MuonDimensionNumbers(0, 1),
'w2': lambda updates: None,
}
opt = _muon.muon(
learning_rate=1e-3,
preconditioning=preconditioning,
muon_weight_dimension_numbers=weight_dim_nums_tree,
)
state = opt.init(params)
updates, _ = opt.update(params, state, params=params)

# Compare against a single callable that returns the same tree.
def weight_dim_nums_fn(updates):
return {'w1': _muon.MuonDimensionNumbers(0, 1), 'w2': None}

opt_ref = _muon.muon(
learning_rate=1e-3,
preconditioning=preconditioning,
muon_weight_dimension_numbers=weight_dim_nums_fn,
)
state_ref = opt_ref.init(params)
updates_ref, _ = opt_ref.update(params, state_ref, params=params)

test_utils.assert_trees_all_close(updates, updates_ref, rtol=1e-8,
atol=1e-8)

@parameterized.named_parameters(
('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten')
)
Expand Down
Loading