Summary
The ScatterAxis ops (mx.put_along_axis, mx.scatter_add_axis) fail on Metal for 64-bit element dtypes (int64/uint64). Instead of raising a clean "unsupported dtype" error, they fail during the Metal library JIT compile with a RuntimeError. The plain Scatter path (x[idx]=v, mx.scatter*, x.at[idx].add) already guards against this and errors cleanly; ScatterAxis is simply missing the equivalent guard. Both of its modes fail: None (put_along_axis) and Sum (scatter_add_axis).
Repro (Metal device)
import mlx.core as mx
mx.set_default_device(mx.gpu)
x = mx.zeros((4, 8), dtype=mx.int64)
idx = mx.array([[0],[1],[2],[3]])
upd = mx.ones((4, 1), dtype=mx.int64)
mx.eval(mx.put_along_axis(x, idx, upd, axis=1))  # raises RuntimeError (Metal library build fails)
The value/array dtype is what triggers it: int64/uint64 fail, while int32/uint32/float*/bf16/bool work. The index dtype is irrelevant.
Root cause
In mlx/backend/metal/kernels/atomic.h, packing_size<T> = sizeof(uint)/sizeof(T) evaluates to 4/8 == 0 (integer division) for 64-bit T. As a result, uint_or_packed<T> declares T val[0] (a zero-length array, which is a hard C++ error), and offset / packing_size<T> becomes a division by zero. The fallback-atomic union is invalid for any sizeof(T) > sizeof(uint). Scatter avoids this with an 8-byte dtype guard in both scatter() (ops.cpp) and Scatter::eval_gpu (indexing.cpp); ScatterAxis has no such guard in either scatter_axis() or ScatterAxis::eval_gpu.
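To make the arithmetic concrete, the snippet below replays the packing_size computation in plain Python for common element sizes. It is only an illustration: it does not touch MLX or Metal, the helper names (packing_size, fallback_union_is_valid) are ours, and it assumes Metal's uint is 32-bit and that C++ integer division truncates the same way Python's // does for non-negative operands.

```python
# Illustration only: replays packing_size<T> = sizeof(uint) / sizeof(T) from atomic.h.
# Helper names are hypothetical; byte sizes are the standard widths of each dtype.
SIZEOF_UINT = 4  # Metal `uint` is 32-bit (assumption stated in the lead-in)

ITEMSIZE = {
    "bool": 1,
    "int8": 1,
    "float16": 2,
    "bfloat16": 2,
    "int32": 4,
    "uint32": 4,
    "float32": 4,
    "int64": 8,
    "uint64": 8,
}


def packing_size(itemsize: int) -> int:
    # C++ integer division truncates toward zero, matching // for non-negative ints.
    return SIZEOF_UINT // itemsize


def fallback_union_is_valid(itemsize: int) -> bool:
    # uint_or_packed<T> declares `T val[packing_size<T>]`: a length of 0 is ill-formed,
    # and `offset / packing_size<T>` would divide by zero.
    return packing_size(itemsize) > 0


for name, size in ITEMSIZE.items():
    print(f"{name:9s} packing_size={packing_size(size)} valid={fallback_union_is_valid(size)}")
```

Every 8-byte entry comes out with packing_size=0, which is exactly the case the compiler rejects.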
Actual error (verbatim, M5 Pro / macOS 27.0 beta)
RuntimeError: [metal::Device] Unable to build metal library from source
mlx/backend/metal/kernels/atomic.h:168:9: error: zero-length arrays are not permitted in C++
T val[packing_size<T>];
^~~~~~~~~~~~~~~
mlx/backend/metal/kernels/atomic.h:190:21: note: in instantiation of template class '(anonymous namespace)::uint_or_packed<long>' requested here
uint_or_packed<T> expected;
^
mlx/backend/metal/kernels/atomic.h:280:3: note: in instantiation of function template specialization '(anonymous namespace)::mlx_atomic_update_and_store<long, (anonymous namespace)::__None<long>>' requested here
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
^
mlx/backend/metal/kernels/reduction/ops.h:31:5: note: in instantiation of function template specialization 'mlx_atomic_store_explicit<long, true>' requested here
mlx_atomic_store_explicit(out, val, offset);
^
mlx/backend/metal/kernels/indexing/scatter_axis.h:50:6: note: in instantiation of function template specialization 'None::atomic_update<long>' requested here
op.atomic_update(out, upd[upd_idx], out_idx);
^
mlx/backend/metal/kernels/atomic.h:186:31: warning: division by zero is undefined [-Wdivision-by-zero]
size_t pack_offset = offset / packing_size<T>;
^ ~~~~~~~~~~~~~~~
mlx/backend/metal/kernels/atomic.h:187:31: warning: remainder by zero is undefined [-Wdivision-by-zero]
size_t elem_offset = offset % packing_size<T>;
^ ~~~~~~~~~~~~~~~
Expected
Either support int64/uint64 for these ops on Metal or, at minimum, reject 8-byte element dtypes in scatter_axis() with the same clean ValueError the Scatter path already raises, before any Metal build is attempted.
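A minimal sketch of the shape such a guard could take, written in Python for readability even though the real check would belong in C++ in scatter_axis() (ops.cpp) and ScatterAxis::eval_gpu. The function name, its arguments and the error message are hypothetical, not MLX API. The condition follows the root-cause analysis above (any element wider than a 32-bit uint breaks the fallback union) rather than reproducing the exact check or wording of the existing Scatter guard.

```python
# Hypothetical guard sketch; not MLX code. The real fix would be a C++ check that
# throws std::invalid_argument, which surfaces in Python as ValueError.
SIZEOF_UINT = 4  # bytes in Metal's 32-bit `uint`


def check_scatter_axis_dtype(op_name: str, dtype_name: str, itemsize: int) -> None:
    """Reject element types the Metal fallback atomic path cannot represent,
    before any kernel is JIT-compiled."""
    if itemsize > SIZEOF_UINT:
        raise ValueError(
            f"[{op_name}] {dtype_name} ({itemsize}-byte) elements are not "
            "supported on the GPU."
        )


# Usage: 4-byte types pass silently; 8-byte types fail fast with a readable message.
check_scatter_axis_dtype("put_along_axis", "int32", 4)
try:
    check_scatter_axis_dtype("scatter_add_axis", "int64", 8)
except ValueError as e:
    print(e)
```

Checking `itemsize > 4` rather than `== 8` matches the stated invariant, so it would also cover any future element type wider than a uint.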
Environment
MLX 0.32.0.dev (source, base a6ec712; atomic.h unmodified vs upstream main) · macOS 27.0 beta (26A5353q) · Apple M5 Pro · Metal 32023.917, target air64-apple-darwin27.0.0. Workaround: keeping scattered arrays at int32 avoids the failure entirely. Happy to test a patch on this setup.