diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index ec3ee0422a..8e089c6c68 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -385,12 +385,13 @@ def initializer(a: mx.array) -> mx.array: raise ValueError("Only tensors with 2 dimensions are supported") rows, cols = a.shape - num_zeros = int(math.ceil(sparsity * cols)) + num_zeros = int(math.ceil(sparsity * rows)) - order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1) + order = mx.argsort(mx.random.uniform(shape=a.shape), axis=0) a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) - a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0 + zeros = mx.zeros((num_zeros, cols), dtype=dtype) + a = mx.put_along_axis(a, order[:num_zeros, :], zeros, axis=0) return a diff --git a/python/tests/test_init.py b/python/tests/test_init.py index 0c26373199..b4b9a80a4a 100644 --- a/python/tests/test_init.py +++ b/python/tests/test_init.py @@ -100,12 +100,27 @@ def test_sparse(self): with self.subTest(shape=shape): self.assertEqual(result.shape, shape) self.assertEqual(result.dtype, dtype) - self.assertEqual( - (mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True - ) with self.assertRaises(ValueError): result = initializer(mx.zeros((1,))) + initializer = init.sparse(sparsity, mean=1.0, std=0.0, dtype=dtype) + for shape in [(3, 2), (2, 2), (4, 3)]: + result = initializer(mx.array(np.empty(shape))) + expected_zeros = int(np.ceil(sparsity * result.shape[0])) + with self.subTest(shape=shape, dtype=dtype): + self.assertTrue( + mx.all(mx.sum(result == 0, axis=0) == expected_zeros).item() + ) + + for sparsity in [0.0, 0.25, 0.5, 0.75, 1.0]: + result = init.sparse(sparsity, mean=1.0, std=0.0)( + mx.array(np.empty((5, 7))) + ) + expected_zeros = int(np.ceil(sparsity * result.shape[0])) + self.assertTrue( + mx.all(mx.sum(result == 0, axis=0) == expected_zeros).item() + ) + def test_orthogonal(self): initializer = init.orthogonal(gain=1.0, dtype=mx.float32)