
[BUG] int64/uint64 put_along_axis / scatter_add_axis crashes the Metal JIT build #3690

@programVeins

Summary

The ScatterAxis ops (mx.put_along_axis, mx.scatter_add_axis) fail for 64-bit element dtypes (int64/uint64). Instead of raising a clean "unsupported dtype" error, they crash the Metal library JIT compile. The plain Scatter path (x[idx]=v, mx.scatter*, x.at[idx].add) already guards against this and errors cleanly; ScatterAxis is simply missing the equivalent guard. Both of its modes fail (None via put_along_axis, and Sum via scatter_add_axis).

Repro (Metal device)

import mlx.core as mx
mx.set_default_device(mx.gpu)
x   = mx.zeros((4, 8), dtype=mx.int64)
idx = mx.array([[0],[1],[2],[3]])
upd = mx.ones((4, 1), dtype=mx.int64)
mx.eval(mx.put_along_axis(x, idx, upd, axis=1))   # raises RuntimeError: Unable to build metal library from source

It's the value/array dtype that triggers it (int64/uint64 fail; int32/uint32/float*/bf16/bool OK). Index dtype is irrelevant.

Root cause

In mlx/backend/metal/kernels/atomic.h, packing_size<T> = sizeof(uint) / sizeof(T) uses integer division, so it evaluates to 4 / 8 == 0 for any 64-bit T. With packing_size<T> == 0, uint_or_packed<T> declares T val[0] (a zero-length array, which is a hard C++ error), and offset / packing_size<T> becomes a division by zero. More generally, the fallback-atomic union is invalid for any sizeof(T) > sizeof(uint). Scatter avoids this via an 8-byte-dtype guard in both scatter() (ops.cpp) and Scatter::eval_gpu (indexing.cpp); ScatterAxis has no such guard in either scatter_axis() or ScatterAxis::eval_gpu.
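The integer-division arithmetic behind the failure can be modeled in a few lines. This is a hypothetical Python model, not MLX code: it mimics packing_size<T> = sizeof(uint) / sizeof(T) with truncating division for the dtypes named in this issue and flags the sizes where the fallback union breaks. SIZEOF_UINT = 4 assumes Metal's 32-bit uint; the helper names are illustrative only.

```python
# Hypothetical Python model of packing_size<T> in atomic.h (not MLX code).
# C++ integer division truncates, so 4 / 8 evaluates to 0.
SIZEOF_UINT = 4  # Metal `uint` is 32 bits

# Element sizes in bytes for the dtypes discussed in this issue.
DTYPE_SIZES = {
    "bool": 1,
    "float16": 2,
    "bfloat16": 2,
    "float32": 4,
    "int32": 4,
    "uint32": 4,
    "int64": 8,
    "uint64": 8,
}


def packing_size(itemsize: int) -> int:
    """Number of T elements packed into one uint, using truncating division."""
    return SIZEOF_UINT // itemsize


def broken_dtypes() -> list:
    """Dtypes whose packing_size is 0, i.e. `T val[0]` and `offset / 0`."""
    return [name for name, size in DTYPE_SIZES.items() if packing_size(size) == 0]


if __name__ == "__main__":
    for name, size in DTYPE_SIZES.items():
        p = packing_size(size)
        note = "ok" if p > 0 else "zero-length array + divide-by-zero"
        print(f"{name:<9} sizeof={size} packing_size={p}  {note}")
    print("broken:", broken_dtypes())
```

Running it reports int64 and uint64 as the only broken entries, which matches the observed failures.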

Actual error (verbatim, M5 Pro / macOS 27.0 beta)

RuntimeError: [metal::Device] Unable to build metal library from source
mlx/backend/metal/kernels/atomic.h:168:9: error: zero-length arrays are not permitted in C++
  T val[packing_size<T>];
        ^~~~~~~~~~~~~~~
mlx/backend/metal/kernels/atomic.h:190:21: note: in instantiation of template class '(anonymous namespace)::uint_or_packed<long>' requested here
  uint_or_packed<T> expected;
                    ^
mlx/backend/metal/kernels/atomic.h:280:3: note: in instantiation of function template specialization '(anonymous namespace)::mlx_atomic_update_and_store<long, (anonymous namespace)::__None<long>>' requested here
  mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
  ^
mlx/backend/metal/kernels/reduction/ops.h:31:5: note: in instantiation of function template specialization 'mlx_atomic_store_explicit<long, true>' requested here
    mlx_atomic_store_explicit(out, val, offset);
    ^
mlx/backend/metal/kernels/indexing/scatter_axis.h:50:6: note: in instantiation of function template specialization 'None::atomic_update<long>' requested here
  op.atomic_update(out, upd[upd_idx], out_idx);
     ^
mlx/backend/metal/kernels/atomic.h:186:31: warning: division by zero is undefined [-Wdivision-by-zero]
  size_t pack_offset = offset / packing_size<T>;
                              ^ ~~~~~~~~~~~~~~~
mlx/backend/metal/kernels/atomic.h:187:31: warning: remainder by zero is undefined [-Wdivision-by-zero]
  size_t elem_offset = offset % packing_size<T>;
                              ^ ~~~~~~~~~~~~~~~

Expected

Either support int64/uint64 for these ops on Metal or, at minimum, reject 8-byte element dtypes in scatter_axis() with the same clean ValueError the Scatter path already raises, before any Metal library build is attempted.
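A sketch of the minimum fix's logic, written as a Python pre-flight check. The function name and error message are hypothetical and are not MLX's API or its actual error text; the real guard would live in scatter_axis() (ops.cpp) and ScatterAxis::eval_gpu. It keys on the root-cause condition sizeof(T) > sizeof(uint) rather than listing int64/uint64 by name.

```python
# Hypothetical pre-flight check mirroring the guard this issue asks for in
# scatter_axis(). Name and message are illustrative only, not MLX's API.
SIZEOF_UINT = 4  # the fallback atomics pack elements into a 32-bit uint


def check_scatter_axis_dtype(dtype_name: str, itemsize: int) -> None:
    """Raise ValueError for element sizes the fallback atomic union cannot pack.

    Any itemsize > sizeof(uint) makes packing_size<T> zero in atomic.h, so the
    op is rejected before a Metal library build is ever attempted.
    """
    if itemsize > SIZEOF_UINT:
        raise ValueError(
            f"[scatter_axis] {dtype_name} ({itemsize}-byte elements) is not "
            "supported on the GPU; use a dtype of at most 4 bytes."
        )


if __name__ == "__main__":
    check_scatter_axis_dtype("int32", 4)  # passes silently
    try:
        check_scatter_axis_dtype("int64", 8)
    except ValueError as err:
        print("rejected:", err)
```

From MLX code, the arguments could come from the array's dtype (e.g. str(x.dtype) and x.dtype.size, assuming mx.Dtype exposes its byte size as .size). Testing against sizeof(uint) also covers any other element type wider than 4 bytes, since all of them hit the same zero packing_size.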

Environment

MLX 0.32.0.dev (source, base a6ec712; atomic.h unmodified vs upstream main) · macOS 27.0 beta (26A5353q) · Apple M5 Pro · Metal 32023.917, target air64-apple-darwin27.0.0. Workaround (total): keep scattered arrays at int32. Happy to test a patch on this setup.
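Applying the int32 workaround to data that is currently int64 means narrowing it, which is only lossless when every value lies in the signed 32-bit range. A minimal pure-Python check, using a hypothetical helper name:

```python
# Hypothetical helper for the int32 workaround: verify that values can be
# narrowed from int64 to int32 without overflow before casting.
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def fits_in_int32(values) -> bool:
    """True if every integer in `values` lies within the signed 32-bit range."""
    return all(INT32_MIN <= v <= INT32_MAX for v in values)


if __name__ == "__main__":
    print(fits_in_int32([0, 1, -5, 2**31 - 1]))  # True
    print(fits_in_int32([2**31]))                # False
```

With MLX arrays, the same bounds test could be applied to x.min().item() and x.max().item() before calling x.astype(mx.int32).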
