From 2f2dc5b11e4ae54efeace806bb7c70d9b4b01e52 Mon Sep 17 00:00:00 2001 From: Dylan Tirandaz Date: Mon, 23 Mar 2026 01:17:42 -0500 Subject: [PATCH] Extend muon to support mask_callable weight dimension numbers Add support for weight_dimension_numbers trees where each leaf is a callable, resolved via _masking._mask_callable. This extends both scale_by_muon and scale_by_shape to handle this case by mapping each callable leaf over updates. Resolves TODO(rdyro) at _muon.py:478. --- optax/contrib/_muon.py | 9 ++++++++- optax/contrib/_muon_test.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index e5563a479..acad2f54c 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -182,6 +182,10 @@ def update_fn(updates, state, params=None): # Populate weight_dim_nums if it's a callable. Use updates instead of # actual params since only shapes matter and params may not be provided. resolved_weight_dim_nums = weight_dimension_numbers(updates) + elif _masking._mask_callable(weight_dimension_numbers): + resolved_weight_dim_nums = jax.tree.map( + lambda fn: fn(updates), weight_dimension_numbers + ) else: resolved_weight_dim_nums = weight_dimension_numbers @@ -475,11 +479,14 @@ def init_fn(params): def update_fn(updates, state, params=None): del params - # TODO(rdyro): extend to _masking._mask_callable if callable(weight_dimension_numbers): # Populate weight_dim_nums if it's a callable. Use updates instead of # actual params since only shapes matter and params may not be provided. resolved_weight_dim_nums = weight_dimension_numbers(updates) + elif _masking._mask_callable(weight_dimension_numbers): + resolved_weight_dim_nums = jax.tree.map( + lambda fn: fn(updates), weight_dimension_numbers + ) else: resolved_weight_dim_nums = weight_dimension_numbers diff --git a/optax/contrib/_muon_test.py b/optax/contrib/_muon_test.py index 02c02f78d..ae05d371a 100644 --- a/optax/contrib/_muon_test.py +++ b/optax/contrib/_muon_test.py @@ -143,6 +143,41 @@ def weight_dim_nums_fn(params): # pylint: disable=function-redefined state = opt.init(params) _, _ = opt.update(params, state, params=params) + @parameterized.named_parameters( + ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') + ) + def test_mask_callable_weight_dim_nums(self, preconditioning): + """Test weight_dimension_numbers as a tree of callables.""" + params = {'w1': jnp.ones((10, 10)), 'w2': jnp.ones((2, 10))} + + # A tree where each leaf is a callable returning MuonDimensionNumbers. + weight_dim_nums_tree = { + 'w1': lambda updates: _muon.MuonDimensionNumbers(0, 1), + 'w2': lambda updates: None, + } + opt = _muon.muon( + learning_rate=1e-3, + preconditioning=preconditioning, + muon_weight_dimension_numbers=weight_dim_nums_tree, + ) + state = opt.init(params) + updates, _ = opt.update(params, state, params=params) + + # Compare against a single callable that returns the same tree. + def weight_dim_nums_fn(updates): + return {'w1': _muon.MuonDimensionNumbers(0, 1), 'w2': None} + + opt_ref = _muon.muon( + learning_rate=1e-3, + preconditioning=preconditioning, + muon_weight_dimension_numbers=weight_dim_nums_fn, + ) + state_ref = opt_ref.init(params) + updates_ref, _ = opt_ref.update(params, state_ref, params=params) + + test_utils.assert_trees_all_close(updates, updates_ref, rtol=1e-8, + atol=1e-8) + @parameterized.named_parameters( ('frobenius', 'frobenius'), ('aol', 'aol'), ('schatten', 'schatten') )