Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlx/backend/metal/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,12 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}

void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
if (size_of(out.dtype()) == 8) {
std::ostringstream msg;
msg << "[ScatterAxis::eval_gpu] Does not support " << out.dtype();
throw std::invalid_argument(msg.str());
}

auto& src = inputs[0];
auto& idx = inputs[1];
auto& upd = inputs[2];
Expand Down
8 changes: 8 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3529,6 +3529,14 @@ array scatter_axis(
return a;
}

// TODO, remove when scatter_axis supports 64-bit outputs
if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check in metal/indexing.cpp along is enough, and it actually works in the cuda backend, the test should also be updated for metal only

std::ostringstream msg;
msg << prefix << " GPU scatter does not yet support " << a.dtype()
<< " for the input or updates.";
throw std::invalid_argument(msg.str());
}

auto upd = astype(values, a.dtype(), s);

// Squeeze leading singletons out of update
Expand Down
14 changes: 14 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,20 @@ def test_put_along_axis(self):
self.assertEqual(b.size, 0)
self.assertEqual(b.shape, a.shape)

# 64-bit outputs are not supported on the GPU and should raise a clean
# error rather than failing the Metal JIT build.
idx = mx.array([[0], [1], [2], [3]])
for dt in (mx.int64, mx.uint64):
x = mx.zeros((4, 8), dtype=dt)
upd = mx.ones((4, 1), dtype=dt)
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
mx.eval(mx.put_along_axis(x, idx, upd, axis=1))
else:
out = mx.put_along_axis(x, idx, upd, axis=1)
self.assertEqual(out.dtype, dt)
mx.eval(out)

def test_split(self):
a = mx.array([1, 2, 3])
splits = mx.split(a, 3)
Expand Down