
GramMuon and Sharding for Muon #1660


@pillow37

Hi,
Thank you for your valuable work on this library.

I have two feature proposals that I think could be covered by a single implementation:

  • Sharding for Muon. See also this blog post on parallelizing Muon across tensor-parallel matrices.
  • GramMuon: a faster drop-in replacement for Muon using Gram Newton Schulz iterations.

A lambda for Newton-Schulz

From a user's perspective, I think it would be easiest to directly pass a function (e.g. a lambda) that implements the Newton-Schulz iterations, instead of specifying `MuonDimensionNumbers`.
This would let users plug in all kinds of implementations flexibly.

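import jax.numpy as jnp

def newton_schulz(X, default_iters=5):
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients of the reference Muon implementation
    X = X / (jnp.linalg.norm(X) + 1e-7)  # Frobenius norm: scales singular values into [0, 1]
    for _ in range(default_iters):
        # Sharded matmuls could be handled here, e.g. by explicitly specifying sharding rules.
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X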
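For the sharding proposal, the comment inside the loop could be made concrete with `jax.lax.with_sharding_constraint`. The following is a minimal sketch, not necessarily the scheme from the blog post. It assumes a 1-D mesh over all devices with a hypothetical axis name `"model"` and a column count divisible by the number of devices. The coefficients are the quintic ones from the reference Muon implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all devices with a hypothetical axis name "model".
# On a single CPU device the sharding constraints below are effectively no-ops.
mesh = Mesh(np.array(jax.devices()), ("model",))
COL_SHARDED = NamedSharding(mesh, P(None, "model"))  # X split along its columns
REPLICATED = NamedSharding(mesh, P())                 # small Gram matrix on every device


def sharded_newton_schulz(X, default_iters=5):
    """Quintic Newton-Schulz with explicit sharding constraints (sketch)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # reference Muon quintic coefficients
    X = X / (jnp.linalg.norm(X) + 1e-7)  # Frobenius norm: singular values in [0, 1]
    for _ in range(default_iters):
        X = jax.lax.with_sharding_constraint(X, COL_SHARDED)
        # Contracting over the sharded column axis needs a cross-device reduction;
        # the (rows x rows) result is small enough to keep replicated.
        A = jax.lax.with_sharding_constraint(X @ X.T, REPLICATED)
        # Replicated (rows x rows) times column-sharded X stays column-sharded.
        X = a * X + (b * A + c * A @ A) @ X
    return X


ns = jax.jit(sharded_newton_schulz, static_argnames="default_iters")
X = jax.random.normal(jax.random.PRNGKey(0), (64, 128))
Y = ns(X)
```

On a real multi-device mesh, `Y.sharding` can be inspected to check whether the column sharding survives the iteration.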

This function could then be plugged in here, in place of `MuonDimensionNumbers(-2, -1)`:

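import jax
from flax import nnx
from optax.contrib import MuonDimensionNumbers


def _path_contains_words(path, words):
    '''Returns True if any of `words` occurs in the dotted pytree path.'''
    def get_name_or_key(item):
        # jax.tree_util key types: GetAttrKey (.name), DictKey (.key), SequenceKey (.idx)
        if hasattr(item, 'name'):
            return item.name
        if hasattr(item, 'key'):
            return str(item.key)
        if hasattr(item, 'idx'):
            return str(item.idx)
        raise ValueError(f'Unsupported path entry: {item!r}')
    pretty_path = '.'.join(get_name_or_key(k) for k in path)
    return any(word in pretty_path for word in words)


def get_muon_dimensions(
        state,
        deny_list=('token_embedding', 'lm_head',
                   'bias', 'scale', 'sinks'),
        return_shape_for_inspection=False,
    ):
    '''Masks out non-matrices and deny-listed parameters, and assigns Muon
    dimension numbers to the remaining matrices. See
        https://optax.readthedocs.io/en/stable/api/contrib.html#optax.contrib.muon
    '''

    def fn(path, leaf):
        # nnx.Param wraps the array; read shape/ndim from the underlying value.
        value = leaf.value if isinstance(leaf, nnx.Param) else leaf
        dims = None
        if not _path_contains_words(path, deny_list) and value.ndim >= 2:
            # Orthogonalize over the last two axes (the matrix dimensions).
            dims = MuonDimensionNumbers(-2, -1)

        if return_shape_for_inspection:
            return (value.shape, dims is not None)
        return dims

    return jax.tree_util.tree_map_with_path(
        fn, state, is_leaf=lambda value: isinstance(value, nnx.Param))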
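To make the proposal concrete, here is a minimal sketch of what a function-based API could look like. Each leaf's spec is either `None` (no orthogonalization) or a user-supplied function of the matrix. The names `get_orthogonalizers`, `orthogonalize_updates` and `polar_via_svd` are hypothetical and not part of Optax. The exact SVD-based polar factor stands in for a Newton-Schulz routine only to show that any matrix function can be plugged in.

```python
import jax
import jax.numpy as jnp

DENY_LIST = ("token_embedding", "lm_head", "bias", "scale", "sinks")


def polar_via_svd(X):
    """Exact orthogonalization U @ V^T, standing in for a Newton-Schulz routine."""
    U, _, Vt = jnp.linalg.svd(X, full_matrices=False)
    return U @ Vt


def get_orthogonalizers(params, orthogonalize_fn, deny_list=DENY_LIST):
    """Hypothetical: per-leaf spec is the user's function for matrices, else None."""
    def fn(path, leaf):
        name = jax.tree_util.keystr(path)  # e.g. "['layers'][0]['kernel']"
        if leaf.ndim >= 2 and not any(word in name for word in deny_list):
            return orthogonalize_fn
        return None
    return jax.tree_util.tree_map_with_path(fn, params)


def orthogonalize_updates(updates, specs):
    """Applies each leaf's function; leaves whose spec is None pass through."""
    return jax.tree_util.tree_map(
        lambda f, u: u if f is None else f(u),
        specs, updates,
        is_leaf=lambda x: x is None,  # keep None specs as leaves
    )


params = {
    "token_embedding": jnp.ones((10, 4)),
    "layers": [{"kernel": jax.random.normal(jax.random.PRNGKey(0), (8, 16)),
                "bias": jnp.zeros(16)}],
}
specs = get_orthogonalizers(params, polar_via_svd)
out = orthogonalize_updates(params, specs)
```

Because the spec is an arbitrary function, a sharded or Gram Newton-Schulz variant could be swapped in without touching the masking logic.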

This would make it easy to use linear-algebra optimizations for Muon, such as GramMuon.
For example, the GramMuon authors provide this pseudocode:

Gram Newton-Schulz

Input: $X \in \mathbb{R}^{n \times m}$ with $n \leq m$, coefficients $\{(a_t, b_t, c_t)\}_{t=1}^5$

  1. $X \gets X / (\|X\|_{F} + \epsilon)$   // Normalize singular values to $[0, 1]$; $\epsilon = 10^{-7}$
  2. $X \gets \texttt{float16}(X)$   // Cast to half precision for speed
  3. If $m < n$:   $X \gets X^\top$   // Trick to make $XX^\top$ cheaper
  4. $R_0 \gets XX^\top$
  5. $Q_0 \gets I$
  6. For $t = 1, \ldots, 5$:
    • If $t = 3$:   // Restart to stabilize
      • $X \gets Q_2 X$
      • $R_2 \gets XX^\top$
      • $Q_2 \gets I$
    • $Z_t \gets b_t R_{t-1} + c_t R_{t-1}^2$
    • $Q_t \gets Q_{t-1} Z_t + a_t Q_{t-1}$
    • $RZ_t \gets R_{t-1} Z_t + a_t R_{t-1}$
    • $R_t \gets Z_t (RZ_t) + a_t (RZ_t)$
  7. $X \gets Q_5 X$
  8. If $m < n$:   $X \gets X^\top$   // Undo trick
  9. Return $X$
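The pseudocode translates fairly directly into JAX. Below is a minimal sketch that follows the steps above, with a plain Newton-Schulz loop as a cross-check. It assumes one thing not in the issue: the $(a_t, b_t, c_t)$ values are not listed, so every step reuses the reference Muon quintic coefficients as a stand-in. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumption: the issue does not list the (a_t, b_t, c_t) values, so this sketch reuses
# the reference Muon quintic coefficients for all five steps as a stand-in.
COEFFS = ((3.4445, -4.7750, 2.0315),) * 5


def gram_newton_schulz(X, coeffs=COEFFS, eps=1e-7, dtype=jnp.float16, restart_at=3):
    """Sketch of Gram Newton-Schulz following the pseudocode steps."""
    X = X / (jnp.linalg.norm(X) + eps)    # step 1: Frobenius-normalize
    X = X.astype(dtype)                   # step 2: half precision for speed
    transposed = X.shape[1] < X.shape[0]  # step 3: if m < n, transpose so X X^T is small
    if transposed:
        X = X.T
    eye = jnp.eye(X.shape[0], dtype=dtype)
    R = X @ X.T                           # step 4: R_0 = X X^T
    Q = eye                               # step 5: Q_0 = I
    for t, (a, b, c) in enumerate(coeffs, start=1):
        if t == restart_at:               # restart: fold Q into X and recompute R
            X = Q @ X
            R = X @ X.T
            Q = eye
        Z = b * R + c * (R @ R)
        Q = Q @ Z + a * Q
        RZ = R @ Z + a * R
        R = Z @ RZ + a * RZ
    X = Q @ X                             # step 7: apply the accumulated polynomial
    return X.T if transposed else X       # step 8: undo the transpose


def newton_schulz_reference(X, coeffs=COEFFS, eps=1e-7):
    """Plain Newton-Schulz with the same coefficients, used as a cross-check."""
    X = X / (jnp.linalg.norm(X) + eps)
    transposed = X.shape[1] < X.shape[0]
    if transposed:
        X = X.T
    for a, b, c in coeffs:
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X
```

In `float32` the two functions agree up to rounding, since the Gram form only regroups the same matrix polynomial. With the default `float16` the match is approximate, and the restart at $t = 3$ is the step the pseudocode uses for stabilization.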

Thank you for your efforts.
