Reject 8-byte dtypes in GPU scatter_axis with a clear error.

What this patch does:
  * mlx/backend/metal/indexing.cpp: ScatterAxis::eval_gpu throws
    std::invalid_argument up front when size_of(out.dtype()) == 8.
  * mlx/ops.cpp: scatter_axis throws std::invalid_argument when the stream's
    device is Device::gpu and size_of(a.dtype()) == 8, before the updates are
    cast with astype. The in-code TODO marks this as temporary.
  * python/tests/test_ops.py: test_put_along_axis expects ValueError for
    int64/uint64 when the default device is the GPU, and otherwise checks
    that the op evaluates and keeps the dtype.

NOTE(review): the guards test size_of(...) == 8, so they match every 8-byte
dtype, not only int64/uint64 - confirm that rejecting all of them is intended.
NOTE(review): indentation of the hunk bodies was reconstructed from the
projects' usual style (2 spaces for C++, 4 for Python) - confirm the context
lines match the target files before applying.

diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index cdb4a03bb1..a3df0e6e28 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -519,6 +519,12 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
+  if (size_of(out.dtype()) == 8) {
+    std::ostringstream msg;
+    msg << "[ScatterAxis::eval_gpu] Does not support " << out.dtype();
+    throw std::invalid_argument(msg.str());
+  }
+
   auto& src = inputs[0];
   auto& idx = inputs[1];
   auto& upd = inputs[2];
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index e4ce3d750f..2475957047 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3529,6 +3529,14 @@ array scatter_axis(
     return a;
   }
 
+  // TODO, remove when scatter_axis supports 64-bit outputs
+  if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {
+    std::ostringstream msg;
+    msg << prefix << " GPU scatter does not yet support " << a.dtype()
+        << " for the input or updates.";
+    throw std::invalid_argument(msg.str());
+  }
+
   auto upd = astype(values, a.dtype(), s);
 
   // Squeeze leading singletons out of update
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 484ae47c3a..b396ad6f0a 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1353,6 +1353,20 @@ def test_put_along_axis(self):
         self.assertEqual(b.size, 0)
         self.assertEqual(b.shape, a.shape)
 
+        # 64-bit outputs are not supported on the GPU and should raise a clean
+        # error rather than failing the Metal JIT build.
+        idx = mx.array([[0], [1], [2], [3]])
+        for dt in (mx.int64, mx.uint64):
+            x = mx.zeros((4, 8), dtype=dt)
+            upd = mx.ones((4, 1), dtype=dt)
+            if mx.default_device() == mx.gpu:
+                with self.assertRaises(ValueError):
+                    mx.eval(mx.put_along_axis(x, idx, upd, axis=1))
+            else:
+                out = mx.put_along_axis(x, idx, upd, axis=1)
+                self.assertEqual(out.dtype, dt)
+                mx.eval(out)
+
     def test_split(self):
         a = mx.array([1, 2, 3])
         splits = mx.split(a, 3)