Hi,
Thank you for your valuable work on this library.
I have two feature proposals that I think could share a single implementation:
- Sharding for Muon. Also see this blog post on parallelizing Muon across tensor-parallel matrices.
- GramMuon: a faster drop-in replacement for Muon using Gram Newton-Schulz iterations.
A lambda for Newton-Schulz
From a user's perspective, I think it would be easiest to directly pass a callable (e.g. a lambda) that implements the Newton-Schulz iterations, instead of specifying MuonDimensionNumbers.
This way, all kinds of implementations could be handled very flexibly.
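import jax.numpy as jnp

def newton_schulz(X, default_iters=5, eps=1e-7):
    a, b, c = 3.4445, -4.7750, 2.0315  # e.g. the standard Muon quintic coefficients
    X = X / (jnp.linalg.norm(X) + eps)  # Frobenius-normalize so singular values are <= 1
    for _ in range(default_iters):
        # sharded matmuls could be handled here, e.g. by specifying sharding rules explicitly
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X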
Such a function could then be plugged in below, in place of MuonDimensionNumbers(-2, -1):
import jax
from flax import nnx
from optax.contrib import MuonDimensionNumbers

def _path_contains_words(path, words):
    def get_name_or_key(item):
        # jax.tree_util path entries: GetAttrKey (.name), DictKey (.key), SequenceKey (.idx)
        if hasattr(item, 'name'):
            return item.name
        if hasattr(item, 'key'):
            return str(item.key)
        if hasattr(item, 'idx'):
            return str(item.idx)
        raise ValueError(f'Unsupported path entry: {item!r}')
    pretty_path = '.'.join(get_name_or_key(k) for k in path)
    return any(word in pretty_path for word in words)

def get_muon_dimensions(
    state,
    deny_list=('token_embedding', 'lm_head', 'bias', 'scale', 'sinks'),
    return_shape_for_inspection=False,
):
    '''Maps every matrix parameter (ndim >= 2) whose path matches no entry of
    deny_list to MuonDimensionNumbers(-2, -1), and everything else to None,
    which excludes it from the Muon orthogonalization. See
    https://optax.readthedocs.io/en/stable/api/contrib.html#optax.contrib.muon
    '''
    def fn(path, leaf):
        value = getattr(leaf, 'value', leaf)  # unwrap nnx.Param to the underlying array
        dims = None
        if not _path_contains_words(path, deny_list) and value.ndim >= 2:
            dims = MuonDimensionNumbers(-2, -1)
        if return_shape_for_inspection:
            return (value.shape, dims is not None)
        return dims
    return jax.tree_util.tree_map_with_path(
        fn, state, is_leaf=lambda value: isinstance(value, nnx.Param))
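To make the proposed hook concrete, here is a minimal NumPy sketch under simplifying assumptions: parameters are a flat dict keyed by dotted path names (instead of the nnx pytree walked above), and each entry maps to a callable or None rather than to MuonDimensionNumbers. The names make_orthogonalizers, orthogonalize_updates and polar_via_svd are hypothetical and not part of optax. The callable passed in here is an exact SVD-based polar factor, only to show that any implementation (plain Newton-Schulz, Gram Newton-Schulz, a sharded variant, or a lambda wrapping one of them) could be swapped in.

```python
import numpy as np

# Hypothetical helpers illustrating the proposal; not optax API.
DENY_LIST = ('token_embedding', 'lm_head', 'bias', 'scale', 'sinks')

def polar_via_svd(G):
    """Exact orthogonal polar factor U @ V^T of G (what Newton-Schulz approximates)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def make_orthogonalizers(params, orthogonalize_fn, deny_list=DENY_LIST):
    """Analogue of get_muon_dimensions: a callable per Muon matrix, None otherwise."""
    return {
        name: None if any(w in name for w in deny_list) or np.ndim(p) < 2 else orthogonalize_fn
        for name, p in params.items()
    }

def orthogonalize_updates(updates, orthogonalizers):
    """Apply each leaf's callable; leaves mapped to None are passed through unchanged
    (in Muon they would be handled by the fallback optimizer instead)."""
    return {
        name: g if orthogonalizers[name] is None else orthogonalizers[name](g)
        for name, g in updates.items()
    }

# Usage: random arrays stand in for gradient updates.
rng = np.random.default_rng(0)
params = {
    'layer0.kernel': rng.standard_normal((4, 8)),
    'layer0.bias': rng.standard_normal(8),
    'lm_head.kernel': rng.standard_normal((8, 16)),
}
fns = make_orthogonalizers(params, lambda G: polar_via_svd(G))
out = orthogonalize_updates(params, fns)
```

The design choice is that the optimizer would only need to know "which leaves, and with which function", so swapping Newton-Schulz for GramMuon would not require touching the optimizer itself.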
This would also make it easy to plug in linear-algebra optimizations of Muon such as GramMuon.
For example, the GramMuon authors provide this pseudocode:
Gram Newton-Schulz
Input: $X \in \mathbb{R}^{n \times m}$ with $n \leq m$, coefficients $\{(a_t, b_t, c_t)\}_{t=1}^{5}$
- $X \gets X / (\|X\|_{F} + \epsilon)$ // Normalize singular values to $[0, 1]$; $\epsilon = 10^{-7}$
- $X \gets \texttt{float16}(X)$ // Cast to half precision for speed
- If $m < n$: $X \gets X^\top$ // Trick to make $XX^\top$ cheaper
- $R_0 \gets XX^\top$
- $Q_0 \gets I$
- For $t = 1, \ldots, 5$:
  - If $t = 3$: // Restart to stabilize
    - $X \gets Q_2 X$
    - $R_2 \gets XX^\top$
    - $Q_2 \gets I$
  - $Z_t \gets b_t R_{t-1} + c_t R_{t-1}^2$
  - $Q_t \gets Q_{t-1} Z_t + a_t Q_{t-1}$
  - $RZ_t \gets R_{t-1} Z_t + a_t R_{t-1}$
  - $R_t \gets Z_t (RZ_t) + a_t (RZ_t)$
- $X \gets Q_5 X$
- If $m < n$: $X \gets X^\top$ // Undo trick
- Return $X$
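The Gram Newton-Schulz pseudocode can be sketched in NumPy as follows, next to a plain quintic Newton-Schulz for comparison. This is a minimal sketch under stated assumptions: the per-step coefficients are not given in this issue, so the standard Muon quintic coefficients (3.4445, -4.7750, 2.0315) are reused for all five steps; the float16 cast is skipped so both versions can be compared in float64; and the final multiplication uses the Q accumulated after all five steps. Only `@`, `eye` and `linalg.norm` are used, so switching to jax.numpy should be straightforward.

```python
import numpy as np

# ASSUMPTION: standard Muon quintic coefficients reused for every step;
# GramMuon may use different per-step coefficients.
MUON_COEFFS = [(3.4445, -4.7750, 2.0315)] * 5

def newton_schulz(X, coeffs=MUON_COEFFS, eps=1e-7):
    """Plain quintic Newton-Schulz: X <- a X + b (X X^T) X + c (X X^T)^2 X."""
    transposed = X.shape[0] > X.shape[1]  # keep X X^T the smaller Gram matrix
    if transposed:
        X = X.T
    X = X / (np.linalg.norm(X) + eps)
    for a, b, c in coeffs:
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def gram_newton_schulz(X, coeffs=MUON_COEFFS, eps=1e-7, restart_at=3):
    """Gram Newton-Schulz: iterate on R = X X^T and accumulate the polynomial in Q."""
    transposed = X.shape[0] > X.shape[1]  # the "m < n" transpose trick
    if transposed:
        X = X.T
    X = X / (np.linalg.norm(X) + eps)
    n = X.shape[0]
    R = X @ X.T       # R_0 = X X^T
    Q = np.eye(n)     # Q_0 = I
    for t, (a, b, c) in enumerate(coeffs, start=1):
        if t == restart_at:
            X = Q @ X  # fold the accumulated polynomial into X
            R = X @ X.T
            Q = np.eye(n)
        Z = b * R + c * R @ R   # Z_t
        Q = Q @ Z + a * Q       # Q_t = Q_{t-1} (a_t I + Z_t)
        RZ = R @ Z + a * R      # RZ_t = R_{t-1} (a_t I + Z_t)
        R = Z @ RZ + a * RZ     # R_t = (a_t I + Z_t) R_{t-1} (a_t I + Z_t)
    X = Q @ X
    return X.T if transposed else X
```

Both functions should agree up to floating-point error; the Gram variant only multiplies small n×n matrices inside the loop and touches the full matrix X at the start, at the restart, and at the end.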
Thank you for your efforts.