From 28423a8703e95eb1ae4753614383dba7c51991b6 Mon Sep 17 00:00:00 2001
From: obchain
Date: Mon, 15 Jun 2026 17:51:53 +0530
Subject: [PATCH] Guard scatter_axis against 64-bit outputs on the GPU

put_along_axis / scatter_add_axis with int64/uint64 values failed the
Metal library JIT build instead of raising a clean error: packing_size in
atomic.h is sizeof(uint)/sizeof(T) == 0 for 8-byte T, producing a
zero-length array and a divide-by-zero in the fallback atomic union.

The plain Scatter path already guards this in scatter() and
Scatter::eval_gpu; ScatterAxis had no equivalent guard. Mirror it in
scatter_axis() (GPU only, matching scatter()) and ScatterAxis::eval_gpu,
and add a test.
---
 mlx/backend/metal/indexing.cpp |  6 ++++++
 mlx/ops.cpp                    |  8 ++++++++
 python/tests/test_ops.py       | 14 ++++++++++++++
 3 files changed, 28 insertions(+)

diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index cdb4a03bb1..a3df0e6e28 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -519,6 +519,12 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
+  if (size_of(out.dtype()) == 8) {
+    std::ostringstream msg;
+    msg << "[ScatterAxis::eval_gpu] Does not support " << out.dtype();
+    throw std::invalid_argument(msg.str());
+  }
+
   auto& src = inputs[0];
   auto& idx = inputs[1];
   auto& upd = inputs[2];
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index e4ce3d750f..2475957047 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3529,6 +3529,14 @@ array scatter_axis(
     return a;
   }
 
+  // TODO, remove when scatter_axis supports 64-bit outputs
+  if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {
+    std::ostringstream msg;
+    msg << prefix << " GPU scatter does not yet support " << a.dtype()
+        << " for the input or updates.";
+    throw std::invalid_argument(msg.str());
+  }
+
   auto upd = astype(values, a.dtype(), s);
 
   // Squeeze leading singletons out of update
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 484ae47c3a..b396ad6f0a 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1353,6 +1353,20 @@ def test_put_along_axis(self):
         self.assertEqual(b.size, 0)
         self.assertEqual(b.shape, a.shape)
 
+        # 64-bit outputs are not supported on the GPU and should raise a clean
+        # error rather than failing the Metal JIT build.
+        idx = mx.array([[0], [1], [2], [3]])
+        for dt in (mx.int64, mx.uint64):
+            x = mx.zeros((4, 8), dtype=dt)
+            upd = mx.ones((4, 1), dtype=dt)
+            if mx.default_device() == mx.gpu:
+                with self.assertRaises(ValueError):
+                    mx.eval(mx.put_along_axis(x, idx, upd, axis=1))
+            else:
+                out = mx.put_along_axis(x, idx, upd, axis=1)
+                self.assertEqual(out.dtype, dt)
+                mx.eval(out)
+
     def test_split(self):
         a = mx.array([1, 2, 3])
         splits = mx.split(a, 3)