From d5e673dfb905b352d5dd36f2c5a573378fcc6269 Mon Sep 17 00:00:00 2001 From: ishtihoss Date: Fri, 19 Jun 2026 12:09:25 -0700 Subject: [PATCH] Fix sparse initializer column sparsity --- python/mlx/nn/init.py | 6 +++--- python/tests/test_init.py | 9 +++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index ec3ee0422a..d8a022ab70 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -385,12 +385,12 @@ def initializer(a: mx.array) -> mx.array: raise ValueError("Only tensors with 2 dimensions are supported") rows, cols = a.shape - num_zeros = int(math.ceil(sparsity * cols)) + num_zeros = int(math.ceil(sparsity * rows)) - order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1) + order = mx.argsort(mx.random.uniform(shape=a.shape), axis=0) a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) - a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0 + a[order[:num_zeros, :], mx.arange(cols).reshape(1, cols)] = 0 return a diff --git a/python/tests/test_init.py b/python/tests/test_init.py index 0c26373199..eb1491680e 100644 --- a/python/tests/test_init.py +++ b/python/tests/test_init.py @@ -106,6 +106,15 @@ def test_sparse(self): with self.assertRaises(ValueError): result = initializer(mx.zeros((1,))) + for sparsity in [0.0, 0.25, 0.5, 0.75, 1.0]: + result = init.sparse(sparsity, mean=1.0, std=0.0)( + mx.array(np.empty((5, 7))) + ) + expected_zeros = int(np.ceil(sparsity * result.shape[0])) + self.assertTrue( + mx.all(mx.sum(result == 0, axis=0) == expected_zeros).item() + ) + def test_orthogonal(self): initializer = init.orthogonal(gain=1.0, dtype=mx.float32)