Skip to content

Fix sparse initializer column sparsity#3725

Open
ishtihoss wants to merge 1 commit into
ml-explore:mainfrom
ishtihoss:fix-sparse-init-columns
Open

Fix sparse initializer column sparsity#3725
ishtihoss wants to merge 1 commit into
ml-explore:mainfrom
ishtihoss:fix-sparse-init-columns

Conversation

@ishtihoss

Copy link
Copy Markdown
Contributor

Summary

  • make nn.init.sparse zero the documented fraction of entries in each column
  • align the zero-mask orientation with PyTorch's sparse initializer
  • add a deterministic column-count regression test

Tests

  • python -m unittest test_init

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the equivalent API in pytorch that I can verify with?

@ishtihoss

Copy link
Copy Markdown
Contributor Author

The equivalent PyTorch API is torch.nn.init.sparse_:

import math
import torch
import torch.nn as nn

rows, cols = 5, 7
sparsity = 0.25

w = torch.empty(rows, cols)
nn.init.sparse_(w, sparsity=sparsity, std=1.0)

print((w == 0).sum(dim=0))
print("expected:", math.ceil(sparsity * rows))

PyTorch applies sparsity per column: for a (rows, cols) tensor it computes ceil(sparsity * rows) and zeros that many row entries independently in each column.

The MLX equivalent for the same parameters is:

init_fn = nn.init.sparse(sparsity, mean=0.0, std=1.0)
w = init_fn(mx.zeros((rows, cols)))

One API difference: PyTorch's sparse_ is in-place, hardcodes mean=0 for the normal draw, and defaults to std=0.01; MLX returns a new array and exposes mean (default 0.0, std=1.0).

@ishtihoss

Copy link
Copy Markdown
Contributor Author

I also ran a broader copy/paste check against torch.nn.init.sparse_. The semantic property being checked is: for a 2D (rows, cols) tensor, each column has exactly ceil(sparsity * rows) zeroed entries.

PyTorch reference check:

python - <<'PY'
import math
import torch

print(f"torch {torch.__version__}")
shapes = [(1, 5), (2, 2), (3, 2), (5, 7), (8, 4), (11, 3)]
sparsities = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]

for rows, cols in shapes:
    for sparsity in sparsities:
        torch.manual_seed(10_000 + rows * 100 + cols)
        x = torch.empty(rows, cols)
        torch.nn.init.sparse_(x, sparsity=sparsity, std=1.0)
        expected = math.ceil(sparsity * rows)
        col_counts = (x == 0).sum(dim=0).tolist()
        assert col_counts == [expected] * cols, (
            (rows, cols), sparsity, expected, col_counts, x
        )
print(f"column-count cases passed: {len(shapes) * len(sparsities)}")

for shape in [(3,), (2, 3, 4)]:
    try:
        torch.nn.init.sparse_(torch.empty(shape), sparsity=0.5, std=1.0)
    except ValueError as exc:
        assert "2 dimensions" in str(exc)
    else:
        raise AssertionError(f"torch accepted invalid shape {shape}")
print("invalid-rank cases passed: 2")

rows, cols, sparsity = 5, 7, 0.25
torch.manual_seed(123)
x = torch.empty(rows, cols)
torch.nn.init.sparse_(x, sparsity=sparsity, std=1.0)
print("example shape=(5,7) sparsity=0.25 torch column zero counts:", (x == 0).sum(dim=0).tolist())
print("expected per column:", math.ceil(sparsity * rows))
PY

Matching MLX check on this branch:

python - <<'PY'
import math
import mlx.core as mx
import mlx.nn as nn

print("mlx", mx.__version__)
shapes = [(1, 5), (2, 2), (3, 2), (5, 7), (8, 4), (11, 3)]
sparsities = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]

for rows, cols in shapes:
    for sparsity in sparsities:
        mx.random.seed(10_000 + rows * 100 + cols)
        result = nn.init.sparse(sparsity=sparsity, mean=1.0, std=0.0)(
            mx.zeros((rows, cols))
        )
        expected = math.ceil(sparsity * rows)
        col_counts = mx.sum(result == 0, axis=0).tolist()
        assert col_counts == [expected] * cols, (
            (rows, cols), sparsity, expected, col_counts, result
        )
print(f"column-count cases passed: {len(shapes) * len(sparsities)}")

for shape in [(3,), (2, 3, 4)]:
    try:
        nn.init.sparse(sparsity=0.5)(mx.zeros(shape))
    except ValueError as exc:
        assert "2 dimensions" in str(exc)
    else:
        raise AssertionError(f"MLX accepted invalid shape {shape}")
print("invalid-rank cases passed: 2")

rows, cols, sparsity = 5, 7, 0.25
mx.random.seed(123)
result = nn.init.sparse(sparsity=sparsity, mean=1.0, std=0.0)(mx.zeros((rows, cols)))
print("example shape=(5,7) sparsity=0.25 MLX column zero counts:", mx.sum(result == 0, axis=0).tolist())
print("expected per column:", math.ceil(sparsity * rows))
PY

The output I got locally was:

torch 2.12.1
column-count cases passed: 36
invalid-rank cases passed: 2
example shape=(5,7) sparsity=0.25 torch column zero counts: [2, 2, 2, 2, 2, 2, 2]
expected per column: 2

mlx 0.31.2
column-count cases passed: 36
invalid-rank cases passed: 2
example shape=(5,7) sparsity=0.25 MLX column zero counts: [2, 2, 2, 2, 2, 2, 2]
expected per column: 2

I used mean=1.0, std=0.0 in the MLX check to make the intentionally zeroed entries unambiguous. PyTorch's API does not expose mean, so the PyTorch reference check uses deterministic seeds with std=1.0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants