Fix sparse initializer column sparsity#3725
Conversation
zcbenz
left a comment
There was a problem hiding this comment.
What is the equivalent API in pytorch that I can verify with?
|
The equivalent PyTorch API is `torch.nn.init.sparse_`:
import math
import torch
import torch.nn as nn
rows, cols = 5, 7
sparsity = 0.25
w = torch.empty(rows, cols)
nn.init.sparse_(w, sparsity=sparsity, std=1.0)
print((w == 0).sum(dim=0))
print("expected:", math.ceil(sparsity * rows))
PyTorch applies sparsity per column: for a […]
The MLX equivalent for the same parameters is:
init_fn = nn.init.sparse(sparsity, mean=0.0, std=1.0)
w = init_fn(mx.zeros((rows, cols)))
One API difference: PyTorch's […]
|
I also ran a broader copy/paste check against PyTorch. PyTorch reference check:
python - <<'PY'
# Reference check: torch.nn.init.sparse_ zeroes ceil(sparsity * rows) entries
# in every column of a 2-D tensor and rejects tensors of any other rank.
import math

import torch

print(f"torch {torch.__version__}")

# 2-D shapes and sparsity levels to sweep, including the 0.0 and 1.0 extremes.
shapes = [(1, 5), (2, 2), (3, 2), (5, 7), (8, 4), (11, 3)]
sparsities = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]

for rows, cols in shapes:
    for sparsity in sparsities:
        # Seed is derived from the shape so each case is reproducible.
        torch.manual_seed(10_000 + rows * 100 + cols)
        x = torch.empty(rows, cols)
        torch.nn.init.sparse_(x, sparsity=sparsity, std=1.0)
        expected = math.ceil(sparsity * rows)
        # Number of exact zeros in each column.
        col_counts = (x == 0).sum(dim=0).tolist()
        assert col_counts == [expected] * cols, (
            (rows, cols), sparsity, expected, col_counts, x
        )
print(f"column-count cases passed: {len(shapes) * len(sparsities)}")

# A 1-D and a 3-D shape must both be rejected with a ValueError.
for shape in [(3,), (2, 3, 4)]:
    try:
        torch.nn.init.sparse_(torch.empty(shape), sparsity=0.5, std=1.0)
    except ValueError as exc:
        assert "2 dimensions" in str(exc)
    else:
        raise AssertionError(f"torch accepted invalid shape {shape}")
print("invalid-rank cases passed: 2")

# Single worked example whose per-column zero counts are printed for comparison.
rows, cols, sparsity = 5, 7, 0.25
torch.manual_seed(123)
x = torch.empty(rows, cols)
torch.nn.init.sparse_(x, sparsity=sparsity, std=1.0)
print("example shape=(5,7) sparsity=0.25 torch column zero counts:", (x == 0).sum(dim=0).tolist())
print("expected per column:", math.ceil(sparsity * rows))
PY
Matching MLX check on this branch:
python - <<'PY'
# MLX counterpart of the PyTorch reference check: nn.init.sparse should zero
# ceil(sparsity * rows) entries in every column of a 2-D array and reject
# arrays of any other rank.
import math

import mlx.core as mx
import mlx.nn as nn

print("mlx", mx.__version__)

# Same shape / sparsity sweep as the PyTorch reference check.
shapes = [(1, 5), (2, 2), (3, 2), (5, 7), (8, 4), (11, 3)]
sparsities = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]

for rows, cols in shapes:
    for sparsity in sparsities:
        # Seed is derived from the shape so each case is reproducible.
        mx.random.seed(10_000 + rows * 100 + cols)
        # mean=1.0 with std=0.0 is presumably chosen so that every entry that
        # is not zeroed is non-zero, making `result == 0` count only the
        # entries the initializer zeroed out.
        result = nn.init.sparse(sparsity=sparsity, mean=1.0, std=0.0)(
            mx.zeros((rows, cols))
        )
        expected = math.ceil(sparsity * rows)
        # Number of exact zeros in each column.
        col_counts = mx.sum(result == 0, axis=0).tolist()
        assert col_counts == [expected] * cols, (
            (rows, cols), sparsity, expected, col_counts, result
        )
print(f"column-count cases passed: {len(shapes) * len(sparsities)}")

# A 1-D and a 3-D shape must both be rejected with a ValueError.
for shape in [(3,), (2, 3, 4)]:
    try:
        nn.init.sparse(sparsity=0.5)(mx.zeros(shape))
    except ValueError as exc:
        assert "2 dimensions" in str(exc)
    else:
        raise AssertionError(f"MLX accepted invalid shape {shape}")
print("invalid-rank cases passed: 2")

# Single worked example whose per-column zero counts are printed for comparison.
rows, cols, sparsity = 5, 7, 0.25
mx.random.seed(123)
result = nn.init.sparse(sparsity=sparsity, mean=1.0, std=0.0)(mx.zeros((rows, cols)))
print("example shape=(5,7) sparsity=0.25 MLX column zero counts:", mx.sum(result == 0, axis=0).tolist())
print("expected per column:", math.ceil(sparsity * rows))
PY
The output I got locally was: […]
I used […]
Summary
Tests