From 7835acef9bd9e6039771ff83e24b5e149d18c5c4 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Fri, 19 Jun 2026 18:45:41 +0300 Subject: [PATCH 1/9] Remove Vulkan host sync fallbacks --- mlx/backend/vulkan/arange.cpp | 34 -- mlx/backend/vulkan/gather.cpp | 264 +++------- mlx/backend/vulkan/kernels/random_bits.comp | 18 +- mlx/backend/vulkan/random.cpp | 113 ++--- mlx/backend/vulkan/scatter.cpp | 509 ++------------------ 5 files changed, 140 insertions(+), 798 deletions(-) diff --git a/mlx/backend/vulkan/arange.cpp b/mlx/backend/vulkan/arange.cpp index 38a0fb9083..72e08a7fba 100644 --- a/mlx/backend/vulkan/arange.cpp +++ b/mlx/backend/vulkan/arange.cpp @@ -1,9 +1,5 @@ // Copyright © 2024 Apple Inc. -#include -#include - -#include "mlx/backend/gpu/copy.h" #include "mlx/backend/vulkan/primitives_utils.h" #include "mlx/backend/vulkan/vulkan.h" #include "mlx/dtype.h" @@ -49,36 +45,6 @@ bool try_eval_arange_vulkan( return false; } - if ((out.dtype() == float16 || out.dtype() == bfloat16) && - start == std::trunc(start) && step == std::trunc(step)) { - // Match CPU low-precision arange semantics by advancing in the target dtype - // instead of recomputing each element from float32 math in the shader. - const auto n = out.size(); - out.set_data(allocator::malloc(out.nbytes())); - if (n == 0) { - return true; - } - if (out.dtype() == float16) { - auto* dst = out.data(); - float16_t value(static_cast(start)); - const float16_t step_value(static_cast(step)); - for (size_t i = 0; i < n; ++i) { - dst[i] = value; - value = float16_t(static_cast(value) + static_cast(step_value)); - } - return true; - } - - auto* dst = out.data(); - bfloat16_t value(static_cast(start)); - const bfloat16_t step_value(static_cast(step)); - for (size_t i = 0; i < n; ++i) { - dst[i] = value; - value = bfloat16_t(static_cast(value) + static_cast(step_value)); - } - return true; - } - auto shader_id = arange_shader_id(out.dtype()); if (!shader_id.has_value()) { return false; diff --git a/mlx/backend/vulkan/gather.cpp b/mlx/backend/vulkan/gather.cpp index 67e896407a..d7ea9df47a 100644 --- a/mlx/backend/vulkan/gather.cpp +++ b/mlx/backend/vulkan/gather.cpp @@ -40,17 +40,6 @@ bool ensure_vulkan_storage(array& arr, Stream s) { return vulkan::is_vulkan_storage_array(arr); } -array ensure_host_readable_row_contiguous(array arr, Stream s) { - if (arr.has_primitive()) { - arr.eval(); - } - if (needs_row_contiguous(arr)) { - arr = contiguous_copy_gpu(arr, s); - } - arr.wait(); - return arr; -} - std::pair make_output_work(array& out) { const bool staged_output = needs_row_contiguous(out); array out_work = @@ -69,6 +58,12 @@ checked_shape_product(const array& arr, int begin, int end, const char* label) { return product; } +bool is_host_readable_index_constant(const array& idx) { + auto data = idx.data_shared_ptr(); + return !idx.has_primitive() && data != nullptr && data->buffer.ptr() != nullptr && + !vulkan::is_vulkan_buffer(data->buffer); +} + std::string build_complex_gather_axis_shader() { std::ostringstream os; os << vulkan::emit_dynamic_shader_preamble(complex64, complex64, false); @@ -257,7 +252,7 @@ bool try_eval_i64_gather_axis_vulkan( } std::optional scalar_index_value(const array& idx) { - if (idx.ndim() != 0) { + if (idx.ndim() != 0 || !is_host_readable_index_constant(idx)) { return std::nullopt; } switch (idx.dtype()) { @@ -280,7 +275,7 @@ std::optional scalar_index_value(const array& idx) { } std::optional singleton_index_value(const array& idx) { - if (idx.size() != 1) { + if (idx.size() != 1 || !is_host_readable_index_constant(idx)) { return std::nullopt; } switch (idx.dtype()) { @@ -314,44 +309,6 @@ int64_t normalize_gather_index(int64_t idx, int64_t axis_size) { return idx; } -int64_t read_contiguous_index(const array& idx, int i) { - switch (idx.dtype()) { - case int32: - return idx.data()[i]; - case int64: - return idx.data()[i]; - case uint32: - return idx.data()[i]; - case uint64: { - auto val = idx.data()[i]; - if (val > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("uint64 index exceeds max int64_t value"); - } - return static_cast(val); - } - default: - throw std::runtime_error("Unsupported index dtype for Vulkan gather."); - } -} - -bool is_full_range_index_for_axis( - const array& idx, - int64_t axis_size, - Stream s) { - if (axis_size <= 0 || idx.size() == 0 || (idx.size() % axis_size) != 0) { - return false; - } - auto flat_idx = ensure_row_contiguous( - reshape(idx, {static_cast(idx.size())}, s), s); - flat_idx.eval(); - for (int i = 0; i < flat_idx.size(); ++i) { - if (read_contiguous_index(flat_idx, i) != (i % axis_size)) { - return false; - } - } - return true; -} - constexpr uint32_t kMaxGatherPushConstants = 128; std::string build_generic_gather_shader( @@ -670,83 +627,8 @@ bool try_eval_gather_vulkan( return true; } - if (trace_fallback_enabled()) { - trace_fallback("generic_gather_gpu_unavailable fallback=host_loop"); - } - - std::vector flat_indices; - flat_indices.reserve(inputs.size() - 1); - for (int i = 1; i < inputs.size(); ++i) { - flat_indices.push_back(ensure_host_readable_row_contiguous( - reshape(inputs[i], {static_cast(index_count)}, s), s)); - } - - auto [out_work, staged_output] = make_output_work(out); - if (out_work.size() == 0) { - if (staged_output) { - copy_gpu(out_work, out, CopyType::GeneralGeneral, s); - } - return true; - } - - Strides out_slice_strides( - out_work.strides().begin() + idx_ndim, out_work.strides().end()); - size_t out_slice_elems = 1; - for (auto dim : out_slice_shape) { - out_slice_elems *= static_cast(dim); - } - auto [out_slice_data_size, out_slice_row_contig, out_slice_col_contig] = - check_contiguity(out_slice_shape, out_slice_strides); - array::Flags out_slice_flags = { - out_slice_data_size == out_slice_elems, - out_slice_row_contig, - out_slice_col_contig}; - - Strides index_shape_strides(idx_ndim, 1); - for (int i = idx_ndim - 2; i >= 0; --i) { - index_shape_strides[i] = - index_shape_strides[i + 1] * inputs[1].shape(i + 1); - } - - for (uint32_t i = 0; i < index_count; ++i) { - Shape start(src_input.ndim(), 0); - Shape stop = slice_sizes; - Shape unit_strides(src_input.ndim(), 1); - for (int j = 0; j < norm_axes.size(); ++j) { - const int axis = norm_axes[j]; - start[axis] = normalize_gather_index( - read_contiguous_index(flat_indices[j], i), src_input.shape(axis)); - stop[axis] += start[axis]; - if (stop[axis] > src_input.shape(axis)) { - return false; - } - } - - array gathered = slice(src_input, start, stop, unit_strides, s); - - int64_t out_offset = 0; - size_t remainder = i; - for (int d = 0; d < idx_ndim; ++d) { - const size_t coord = remainder / index_shape_strides[d]; - remainder %= index_shape_strides[d]; - out_offset += coord * out_work.strides(d); - } - - array out_slice(out_slice_shape, out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out_work, - out_slice_strides, - out_slice_flags, - out_slice_data_size, - out_offset); - out_slice.set_status(array::Status::available); - copy_gpu_inplace(gathered, out_slice, CopyType::GeneralGeneral, s); - } - - if (staged_output) { - copy_gpu(out_work, out, CopyType::GeneralGeneral, s); - } - return true; + trace_vulkan_unsupported("Gather", "generic gather GPU dispatch failed"); + return false; } if (axes.size() == 2) { @@ -803,78 +685,8 @@ bool try_eval_gather_vulkan( return true; } - idx0 = ensure_host_readable_row_contiguous( - reshape(idx0, {static_cast(idx0.size())}, s), s); - idx1 = ensure_host_readable_row_contiguous( - reshape(idx1, {static_cast(idx1.size())}, s), s); - - auto [out_work, staged_output] = make_output_work(out); - if (out_work.size() == 0) { - if (staged_output) { - copy_gpu(out_work, out, CopyType::GeneralGeneral, s); - } - return true; - } - - Strides out_slice_strides( - out_work.strides().begin() + idx_ndim, out_work.strides().end()); - size_t out_slice_elems = 1; - for (auto dim : out_slice_shape) { - out_slice_elems *= static_cast(dim); - } - auto [out_slice_data_size, out_slice_row_contig, out_slice_col_contig] = - check_contiguity(out_slice_shape, out_slice_strides); - array::Flags out_slice_flags = { - out_slice_data_size == out_slice_elems, - out_slice_row_contig, - out_slice_col_contig}; - - Strides index_shape_strides(idx_ndim, 1); - for (int i = idx_ndim - 2; i >= 0; --i) { - index_shape_strides[i] = - index_shape_strides[i + 1] * inputs[1].shape(i + 1); - } - - for (uint32_t i = 0; i < index_count; ++i) { - Shape start(src_input.ndim(), 0); - Shape stop = slice_sizes; - Shape unit_strides(src_input.ndim(), 1); - start[axis0] = normalize_gather_index( - read_contiguous_index(idx0, i), src_input.shape(axis0)); - start[axis1] = normalize_gather_index( - read_contiguous_index(idx1, i), src_input.shape(axis1)); - stop[axis0] += start[axis0]; - stop[axis1] += start[axis1]; - if (stop[axis0] > src_input.shape(axis0) || - stop[axis1] > src_input.shape(axis1)) { - return false; - } - - array gathered = slice(src_input, start, stop, unit_strides, s); - - int64_t out_offset = 0; - size_t remainder = i; - for (int d = 0; d < idx_ndim; ++d) { - const size_t coord = remainder / index_shape_strides[d]; - remainder %= index_shape_strides[d]; - out_offset += coord * out_work.strides(d); - } - - array out_slice(out_slice_shape, out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out_work, - out_slice_strides, - out_slice_flags, - out_slice_data_size, - out_offset); - out_slice.set_status(array::Status::available); - copy_gpu_inplace(gathered, out_slice, CopyType::GeneralGeneral, s); - } - - if (staged_output) { - copy_gpu(out_work, out, CopyType::GeneralGeneral, s); - } - return true; + trace_vulkan_unsupported("Gather", "pair gather GPU dispatch failed"); + return false; } array src = ensure_row_contiguous(src_input, s); @@ -950,6 +762,10 @@ bool try_eval_gather_vulkan( trace_vulkan_unsupported("Gather", "axis is out of range"); return false; } + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return true; + } if (auto scalar_index = scalar_index_value(idx); scalar_index.has_value()) { Shape start(src_input.ndim(), 0); Shape stop = slice_sizes; @@ -1000,17 +816,61 @@ bool try_eval_gather_vulkan( s); return true; } + + bool take_like_single_axis = true; for (int i = 0; i < src_input.ndim(); ++i) { const int64_t expected = (i == axis) ? 1 : src_input.shape(i); if (slice_sizes[i] != expected) { - trace_vulkan_unsupported( - "Gather", "only take-like single-axis gathers are supported"); + take_like_single_axis = false; + break; + } + } + if (!take_like_single_axis) { + const int idx_ndim = idx.ndim(); + if (out.ndim() != idx_ndim + src_input.ndim()) { return false; } + Shape out_slice_shape(out.shape().begin() + idx_ndim, out.shape().end()); + if (out_slice_shape != slice_sizes) { + return false; + } + std::vector generic_inputs = {src_input, idx}; + std::vector norm_axes = {axis}; + if (try_dispatch_generic_gather( + generic_inputs, + norm_axes, + slice_sizes, + idx_ndim, + checked_u32_size(idx.size(), "gather_single_axis index_count"), + out, + s)) { + return true; + } + trace_vulkan_unsupported( + "Gather", "single-axis generic gather GPU dispatch failed"); + return false; } const auto shader_id = gather_shader_id(src_input.dtype(), idx.dtype()); if (!shader_id.has_value()) { + const int idx_ndim = idx.ndim(); + if (out.ndim() == idx_ndim + src_input.ndim()) { + Shape out_slice_shape(out.shape().begin() + idx_ndim, out.shape().end()); + if (out_slice_shape == slice_sizes) { + std::vector generic_inputs = {src_input, idx}; + std::vector norm_axes = {axis}; + if (try_dispatch_generic_gather( + generic_inputs, + norm_axes, + slice_sizes, + idx_ndim, + checked_u32_size(idx.size(), "gather_take_generic index_count"), + out, + s)) { + return true; + } + } + } trace_vulkan_unsupported( "Gather", "value/index dtype combination is not supported by Vulkan gather"); diff --git a/mlx/backend/vulkan/kernels/random_bits.comp b/mlx/backend/vulkan/kernels/random_bits.comp index 9ef5719f09..114af0197a 100644 --- a/mlx/backend/vulkan/kernels/random_bits.comp +++ b/mlx/backend/vulkan/kernels/random_bits.comp @@ -2,6 +2,8 @@ #include "types.glsl" +#extension GL_EXT_scalar_block_layout : require + layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Params { @@ -12,7 +14,7 @@ layout(push_constant) uniform Params { } p; layout(binding = 0) readonly buffer Keys { uint keys[]; }; -layout(binding = 1) writeonly buffer Out { uint out_data[]; }; +layout(scalar, binding = 1) writeonly buffer Out { uint8_t out_data[]; }; const uint ROTATIONS[2][4] = uint[2][4]( uint[4](13, 15, 26, 6), @@ -68,18 +70,28 @@ void main() { bool drop_last = !even && (half_idx == half_size); uvec2 count = uvec2(half_idx, drop_last ? 0 : half_idx + grid_y); uvec2 bits = threefry2x32_hash(key, count); + + uint key_byte_offset = key_idx * p.bytes_per_key; + + #define WRITE_WORD(word_idx, value) { \ + uint byte_offset = (word_idx) * 4; \ + if (byte_offset + 0 < p.bytes_per_key) out_data[key_byte_offset + byte_offset + 0] = uint8_t((value) & 0xffu); \ + if (byte_offset + 1 < p.bytes_per_key) out_data[key_byte_offset + byte_offset + 1] = uint8_t(((value) >> 8) & 0xffu); \ + if (byte_offset + 2 < p.bytes_per_key) out_data[key_byte_offset + byte_offset + 2] = uint8_t(((value) >> 16) & 0xffu); \ + if (byte_offset + 3 < p.bytes_per_key) out_data[key_byte_offset + byte_offset + 3] = uint8_t(((value) >> 24) & 0xffu); \ + } // Write first 4 bytes. uint idx1 = out_offset + half_idx; if (idx1 < out_offset + p.out_skip) { - out_data[idx1] = bits.x; + WRITE_WORD(half_idx, bits.x); } // Write second 4 bytes (if not dropping last). if (!drop_last) { uint idx2 = out_offset + half_idx + grid_y; if (idx2 < out_offset + p.out_skip) { - out_data[idx2] = bits.y; + WRITE_WORD(half_idx + grid_y, bits.y); } } } diff --git a/mlx/backend/vulkan/random.cpp b/mlx/backend/vulkan/random.cpp index 430bc4c287..88a57561bf 100644 --- a/mlx/backend/vulkan/random.cpp +++ b/mlx/backend/vulkan/random.cpp @@ -1,8 +1,7 @@ // Copyright © 2024 Apple Inc. -#include +#include -#include "mlx/backend/cpu/threefry.h" #include "mlx/backend/vulkan/primitives_utils.h" namespace mlx::core { @@ -30,98 +29,40 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { "RandomBits failed on Vulkan (only uint32 keys supported)."); } - if (width_ == 4 && bytes_per_key % 4 == 0) { - if (!keys.flags().contiguous || keys.offset() != 0 || keys.strides().back() != 1) { - keys = contiguous_copy_gpu(keys, stream()); - } - keys.wait(); - - auto* kptr = keys.data(); - std::vector host_out(out.size()); - size_t out_skip = bytes_per_key / 4; - size_t half_size = out_skip / 2; - bool even = out_skip % 2 == 0; - for (size_t i = 0; i < num_keys; ++i) { - auto key = std::make_pair(kptr[2 * i], kptr[2 * i + 1]); - auto* dst = host_out.data() + i * out_skip; - - std::pair count{0, half_size + !even}; - for (; count.first + 1 < half_size; count.first++, count.second++) { - std::tie(dst[count.first], dst[count.second]) = - random::threefry2x32_hash(key, count); - } - if (count.first < half_size) { - auto rb = random::threefry2x32_hash(key, count); - dst[count.first++] = rb.first; - dst[count.second] = rb.second; - } - if (!even) { - count.second = 0; - dst[half_size] = random::threefry2x32_hash(key, count).first; - } - } - copy_gpu( - array(host_out.begin(), out.shape(), uint32), - out, - CopyType::GeneralGeneral, - stream()); - return; - } - - if (!keys.flags().contiguous || keys.offset() != 0 || keys.strides().back() != 1) { + if (!keys.flags().contiguous || keys.offset() != 0 || + keys.strides().back() != 1) { keys = contiguous_copy_gpu(keys, stream()); } - keys.wait(); - - auto* kptr = keys.data(); - auto host_out = std::make_shared>(out.nbytes()); - auto* cptr = host_out->data(); - auto copy_word = [&](char* dst, size_t word_idx, uint32_t v) { - const size_t byte_offset = 4 * word_idx; - if (byte_offset + 4 <= bytes_per_key) { - std::copy( - reinterpret_cast(&v), - reinterpret_cast(&v) + 4, - dst + byte_offset); - } else { - std::copy( - reinterpret_cast(&v), - reinterpret_cast(&v) + (bytes_per_key - byte_offset), - dst + byte_offset); - } - }; + out.set_data(allocator::malloc(out.nbytes())); size_t out_skip = (bytes_per_key + 4 - 1) / 4; size_t half_size = out_skip / 2; - bool even = out_skip % 2 == 0; - for (size_t i = 0; i < num_keys; ++i, cptr += bytes_per_key) { - auto key = std::make_pair(kptr[2 * i], kptr[2 * i + 1]); + bool odd = out_skip % 2 != 0; - std::pair count{0, half_size + !even}; - for (; count.first + 1 < half_size; count.first++, count.second++) { - auto rb = random::threefry2x32_hash(key, count); - copy_word(cptr, count.first, rb.first); - copy_word(cptr, count.second, rb.second); - } - if (count.first < half_size) { - auto rb = random::threefry2x32_hash(key, count); - copy_word(cptr, count.first++, rb.first); - copy_word(cptr, count.second, rb.second); - } - if (!even) { - count.second = 0; - copy_word(cptr, half_size, random::threefry2x32_hash(key, count).first); - } + if (num_keys > std::numeric_limits::max() || + bytes_per_key > std::numeric_limits::max() || + out_skip > std::numeric_limits::max()) { + throw std::runtime_error("RandomBits failed on Vulkan (shape too large)."); } - copy_gpu( - array( - static_cast(host_out->data()), - out.shape(), - out.dtype(), - [host_out](void*) {}), + + vulkan::RandomBitsPushConstants push_constants{ + static_cast(num_keys), + static_cast(bytes_per_key), + odd ? 1u : 0u, + static_cast(out_skip)}; + + auto command_buffer = vulkan::begin_command_recording(stream().index); + vulkan::dispatch_random_bits_op( + keys, out, - CopyType::GeneralGeneral, - stream()); + vulkan::StaticShaderId::random_bits_f32, + command_buffer, + stream(), + push_constants, + {static_cast(num_keys), + static_cast(half_size + odd), + 1}); + vulkan::end_command_recording(stream().index); } } // namespace mlx::core diff --git a/mlx/backend/vulkan/scatter.cpp b/mlx/backend/vulkan/scatter.cpp index c65de02c79..dd652964bc 100644 --- a/mlx/backend/vulkan/scatter.cpp +++ b/mlx/backend/vulkan/scatter.cpp @@ -1,7 +1,5 @@ // Copyright © 2024 Apple Inc. -#include -#include #include #include "mlx/backend/common/slicing.h" @@ -40,22 +38,6 @@ bool ensure_vulkan_storage(array& arr, Stream s) { return vulkan::is_vulkan_storage_array(arr); } -array ensure_host_readable_row_contiguous(array arr, Stream s) { - if (arr.has_primitive()) { - arr.eval(); - } - if (arr.data_size() == 1 && arr.size() != 1) { - array materialized(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, materialized, CopyType::Scalar, s); - arr = std::move(materialized); - } - if (needs_row_contiguous(arr)) { - arr = contiguous_copy_gpu(arr, s); - } - arr.wait(); - return arr; -} - CopyType source_copy_type(const array& src) { if (src.data_size() == 1) { return CopyType::Scalar; @@ -84,104 +66,6 @@ checked_shape_product(const array& arr, int begin, int end, const char* label) { return product; } -std::optional scalar_index_value(const array& idx) { - if (idx.ndim() != 0) { - return std::nullopt; - } - switch (idx.dtype()) { - case int32: - return idx.item(); - case int64: - return idx.item(); - case uint32: - return static_cast(idx.item()); - case uint64: { - auto value = idx.item(); - if (value > static_cast(std::numeric_limits::max())) { - return std::nullopt; - } - return static_cast(value); - } - default: - return std::nullopt; - } -} - -std::optional singleton_index_value(const array& idx) { - if (idx.size() != 1) { - return std::nullopt; - } - auto scalar = idx.ndim() == 0 - ? idx - : slice(idx, Shape(idx.ndim(), 0), Shape(idx.ndim(), 1)); - switch (idx.dtype()) { - case int32: - return scalar.item(); - case int64: - return scalar.item(); - case uint32: - return static_cast(scalar.item()); - case uint64: { - auto value = scalar.item(); - if (value > static_cast(std::numeric_limits::max())) { - return std::nullopt; - } - return static_cast(value); - } - default: - return std::nullopt; - } -} - -int64_t normalize_scatter_index(int64_t idx, int64_t axis_size) { - if (idx < 0) { - idx += axis_size; - } - if (idx < 0 || idx >= axis_size) { - throw std::out_of_range( - "scatter index " + std::to_string(idx) + " out of bounds " + - std::to_string(axis_size)); - } - return idx; -} - -int64_t read_flat_index_item(const array& idx, int i, Stream s) { - auto scalar = slice( - idx, {static_cast(i)}, {static_cast(i + 1)}, s); - switch (idx.dtype()) { - case int32: - return scalar.item(); - case int64: - return scalar.item(); - case uint32: - return static_cast(scalar.item()); - case uint64: - return static_cast(scalar.item()); - default: - throw std::runtime_error("Unsupported index dtype for Vulkan scatter."); - } -} - -int64_t read_contiguous_index(const array& idx, int i) { - switch (idx.dtype()) { - case int32: - return idx.data()[i]; - case int64: - return idx.data()[i]; - case uint32: - return idx.data()[i]; - case uint64: { - auto val = idx.data()[i]; - if (val > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("uint64 index exceeds max int64_t value"); - } - return static_cast(val); - } - default: - throw std::runtime_error("Unsupported index dtype for Vulkan scatter."); - } -} - constexpr uint32_t kMaxScatterPushConstants = 128; bool supports_dynamic_scatter_sum_dtype(Dtype dtype) { @@ -195,6 +79,9 @@ std::string build_generic_scatter_shader( int nidx, Scatter::ReduceType reduce_type) { std::ostringstream os; + const bool use_float_atomic_cas = value_dtype == float32 && + (reduce_type == Scatter::Prod || reduce_type == Scatter::Max || + reduce_type == Scatter::Min); os << "#version 450\n"; os << "#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n"; @@ -216,8 +103,7 @@ std::string build_generic_scatter_shader( os << "#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n"; os << "#extension GL_EXT_shader_8bit_storage : require\n"; } - if ((reduce_type == Scatter::Sum || reduce_type == Scatter::Prod) && - value_dtype == float32) { + if (reduce_type == Scatter::Sum && value_dtype == float32) { os << "#extension GL_EXT_shader_atomic_float : require\n"; } @@ -238,7 +124,9 @@ std::string build_generic_scatter_shader( os << "#define INDEX_TYPE int\n"; } - os << "#define VALUE_TYPE " << vulkan::dtype_to_glsl_storage_type(value_dtype) + os << "#define VALUE_TYPE " + << (use_float_atomic_cas ? "uint" + : vulkan::dtype_to_glsl_storage_type(value_dtype)) << "\n"; os << "\nlayout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;\n\n"; @@ -270,6 +158,27 @@ std::string build_generic_scatter_shader( os << "#endif\n"; os << "}\n\n"; + if (use_float_atomic_cas) { + os << "float read_update(uint idx) { return uintBitsToFloat(upd_data[idx]); }\n"; + os << "void atomic_reduce(uint dst_offset, float value) {\n"; + os << " uint old_bits = out_data[dst_offset];\n"; + os << " while (true) {\n"; + os << " float old_value = uintBitsToFloat(old_bits);\n"; + if (reduce_type == Scatter::Prod) { + os << " float new_value = old_value * value;\n"; + } else if (reduce_type == Scatter::Max) { + os << " float new_value = max(old_value, value);\n"; + } else { + os << " float new_value = min(old_value, value);\n"; + } + os << " uint new_bits = floatBitsToUint(new_value);\n"; + os << " uint prev_bits = atomicCompSwap(out_data[dst_offset], old_bits, new_bits);\n"; + os << " if (prev_bits == old_bits) break;\n"; + os << " old_bits = prev_bits;\n"; + os << " }\n"; + os << "}\n\n"; + } + os << "void main() {\n"; os << " uint linear_idx = gl_GlobalInvocationID.x;\n"; os << " if (linear_idx >= p.ne) return;\n\n"; @@ -291,6 +200,8 @@ std::string build_generic_scatter_shader( if (reduce_type == Scatter::Sum) { os << "\n atomicAdd(out_data[dst_offset], upd_data[linear_idx]);\n"; + } else if (use_float_atomic_cas) { + os << "\n atomic_reduce(dst_offset, read_update(linear_idx));\n"; } else if (reduce_type == Scatter::Prod) { os << "\n out_data[dst_offset] *= upd_data[linear_idx];\n"; } else if (reduce_type == Scatter::Max) { @@ -444,142 +355,6 @@ bool try_dispatch_generic_scatter( } } -template -bool try_host_scatter_sum_single_axis_typed( - const array& src, - const array& idx, - const array& upd, - array& out, - int axis, - const Shape& update_shape, - uint32_t slice_elems, - Stream s) { - const uint32_t index_count = - checked_u32_size(idx.size(), "scatter_host index_count"); - array src_host = ensure_host_readable_row_contiguous(src, s); - array idx_host = ensure_host_readable_row_contiguous( - reshape(idx, {static_cast(index_count)}, s), s); - array upd_host = ensure_host_readable_row_contiguous( - add(upd, array(0.0f, upd.dtype()), s), s); - - auto src_strides = make_contiguous_strides(src.shape()); - auto update_strides = make_contiguous_strides(update_shape); - std::vector result(src.size()); - std::copy( - src_host.data(), src_host.data() + src.size(), result.begin()); - - for (uint32_t i = 0; i < index_count; ++i) { - auto normalized_index = normalize_scatter_index( - read_contiguous_index(idx_host, i), src.shape(axis)); - const size_t update_base = static_cast(i) * slice_elems; - for (uint32_t linear = 0; linear < slice_elems; ++linear) { - size_t remainder = linear; - size_t dst_offset = 0; - for (int d = 0; d < src.ndim(); ++d) { - size_t coord = 0; - if (!update_shape.empty()) { - coord = remainder / update_strides[d]; - remainder %= update_strides[d]; - } - if (d == axis) { - coord += normalized_index; - } - dst_offset += coord * src_strides[d]; - } - result[dst_offset] += upd_host.data()[update_base + linear]; - } - } - - array host_result(result.begin(), src.shape(), src.dtype()); - out.copy_shared_buffer(host_result); - return true; -} - -bool try_host_scatter_sum_single_axis( - const array& src, - const array& idx, - const array& upd, - array& out, - int axis, - const Shape& update_shape, - uint32_t slice_elems, - Stream s) { - if (src.dtype() != out.dtype() || upd.dtype() != src.dtype()) { - return false; - } - switch (src.dtype()) { - case float32: - return try_host_scatter_sum_single_axis_typed( - src, idx, upd, out, axis, update_shape, slice_elems, s); - case int32: - return try_host_scatter_sum_single_axis_typed( - src, idx, upd, out, axis, update_shape, slice_elems, s); - case int64: - return try_host_scatter_sum_single_axis_typed( - src, idx, upd, out, axis, update_shape, slice_elems, s); - case uint32: - return try_host_scatter_sum_single_axis_typed( - src, idx, upd, out, axis, update_shape, slice_elems, s); - case uint64: - return try_host_scatter_sum_single_axis_typed( - src, idx, upd, out, axis, update_shape, slice_elems, s); - default: - return false; - } -} - -bool try_host_scatter_none_single_axis( - const array& src, - int64_t normalized_index, - const array& upd, - array& out, - int axis, - const Shape& update_shape, - uint32_t slice_elems, - Stream s) { - array src_host = ensure_host_readable_row_contiguous(src, s); - array upd_host = - ensure_host_readable_row_contiguous(reshape(upd, update_shape, s), s); - - auto src_strides = make_contiguous_strides(src.shape()); - auto update_strides = make_contiguous_strides(update_shape); - const size_t item_size = size_of(src.dtype()); - std::vector result(src.nbytes()); - std::memcpy(result.data(), src_host.data(), src.nbytes()); - const auto* upd_bytes = static_cast(upd_host.data()); - - for (uint32_t linear = 0; linear < slice_elems; ++linear) { - size_t remainder = linear; - size_t dst_offset = 0; - for (int d = 0; d < src.ndim(); ++d) { - size_t coord = 0; - if (!update_shape.empty()) { - coord = remainder / update_strides[d]; - remainder %= update_strides[d]; - } - if (d == axis) { - coord += normalized_index; - } - dst_offset += coord * src_strides[d]; - } - std::memcpy( - result.data() + dst_offset * item_size, - upd_bytes + static_cast(linear) * item_size, - item_size); - } - - void* host_result = std::malloc(result.size()); - if (host_result == nullptr) { - throw std::bad_alloc(); - } - std::memcpy(host_result, result.data(), result.size()); - array host_array( - host_result, src.shape(), src.dtype(), [](void* ptr) { std::free(ptr); }); - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace(host_array, out, CopyType::GeneralGeneral, s); - return true; -} - bool try_eval_scatter_vulkan( const std::vector& inputs, array& out, @@ -680,58 +455,6 @@ bool try_eval_scatter_vulkan( inputs[1].size(), "scatter_generic index_count"), slice_elems, "scatter_generic update_size") == upd.size()) { - std::vector flat_indices; - flat_indices.reserve(inputs.size() - 2); - for (int i = 1; i < inputs.size() - 1; ++i) { - flat_indices.push_back(reshape( - inputs[i], {static_cast(inputs[i].size())}, s)); - } - - std::vector reduced_axes; - std::vector reduced_indices; - reduced_axes.reserve(norm_axes.size()); - reduced_indices.reserve(norm_axes.size()); - bool dropped_full_slice_axis = false; - for (int j = 0; j < norm_axes.size(); ++j) { - const int axis = norm_axes[j]; - bool can_drop_axis = update_shape[axis] == src.shape(axis); - if (can_drop_axis) { - auto flat_idx = - ensure_host_readable_row_contiguous(flat_indices[j], s); - for (int i = 0; i < flat_idx.size(); ++i) { - if (read_contiguous_index(flat_idx, i) != 0) { - can_drop_axis = false; - break; - } - } - } - - if (can_drop_axis) { - dropped_full_slice_axis = true; - continue; - } - - reduced_axes.push_back(axis); - reduced_indices.push_back( - ensure_host_readable_row_contiguous(flat_indices[j], s)); - } - - if (dropped_full_slice_axis && reduced_axes.size() == 2 && - reduce_type == Scatter::Prod) { - std::vector reduced_inputs; - reduced_inputs.reserve(reduced_indices.size() + 2); - reduced_inputs.push_back(src); - reduced_inputs.insert( - reduced_inputs.end(), - reduced_indices.begin(), - reduced_indices.end()); - reduced_inputs.push_back(upd); - if (try_eval_scatter_vulkan( - reduced_inputs, out, reduce_type, reduced_axes, s)) { - return true; - } - } - if (try_dispatch_generic_scatter( inputs, norm_axes, @@ -786,7 +509,6 @@ bool try_eval_scatter_vulkan( checked_u32_size(idx0.size(), "scatter_pair index_count"); if (upd.ndim() == idx_ndim + src.ndim()) { Shape update_shape(upd.shape().begin() + idx_ndim, upd.shape().end()); - Shape target_slice_shape = update_shape; uint32_t slice_elems = 1; for (auto dim : update_shape) { slice_elems = checked_mul_u32( @@ -800,103 +522,9 @@ bool try_eval_scatter_vulkan( idx0.size(), "scatter_pair composed index_count"), slice_elems, "scatter_pair composed update_size") == upd.size() && - (reduce_type == Scatter::Prod || reduce_type == Scatter::Max || + (reduce_type == Scatter::None || reduce_type == Scatter::Sum || + reduce_type == Scatter::Prod || reduce_type == Scatter::Max || reduce_type == Scatter::Min)) { - idx0 = ensure_host_readable_row_contiguous(idx0, s); - idx1 = ensure_host_readable_row_contiguous(idx1, s); - upd = ensure_row_contiguous(upd, s); - Shape flat_shape = { - static_cast(index_count), - static_cast(slice_elems)}; - array flat_upd = reshape(upd, flat_shape, s); - array result = contiguous_copy_gpu(src, s); - result.set_status(array::Status::available); - for (int i = 0; i < idx0.size(); ++i) { - Shape start(src.ndim(), 0); - Shape stop = target_slice_shape; - Shape unit_strides(src.ndim(), 1); - auto idx0_value = normalize_scatter_index( - read_contiguous_index(idx0, i), src.shape(axis0)); - auto idx1_value = normalize_scatter_index( - read_contiguous_index(idx1, i), src.shape(axis1)); - start[axis0] = idx0_value; - start[axis1] = idx1_value; - stop[axis0] += start[axis0]; - stop[axis1] += start[axis1]; - if (stop[axis0] > src.shape(axis0) || - stop[axis1] > src.shape(axis1)) { - return false; - } - - array update_value = reshape( - slice( - flat_upd, - {static_cast(i), 0}, - {static_cast(i + 1), - static_cast(slice_elems)}, - s), - update_shape, - s); - - switch (reduce_type) { - case Scatter::None: { - array next(src.shape(), src.dtype(), nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - next.set_status(array::Status::available); - SliceUpdate op(s, SliceUpdate::None, start, stop, unit_strides); - op.eval_gpu({result, update_value}, next); - result = std::move(next); - break; - } - case Scatter::Sum: { - array next(src.shape(), src.dtype(), nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - next.set_status(array::Status::available); - SliceUpdate op(s, SliceUpdate::Sum, start, stop, unit_strides); - op.eval_gpu({result, update_value}, next); - result = std::move(next); - break; - } - case Scatter::Prod: { - array next(src.shape(), src.dtype(), nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - next.set_status(array::Status::available); - SliceUpdate op(s, SliceUpdate::Prod, start, stop, unit_strides); - op.eval_gpu({result, update_value}, next); - result = std::move(next); - break; - } - case Scatter::Max: { - array next(src.shape(), src.dtype(), nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - next.set_status(array::Status::available); - SliceUpdate op(s, SliceUpdate::Max, start, stop, unit_strides); - op.eval_gpu({result, update_value}, next); - result = std::move(next); - break; - } - case Scatter::Min: { - array next(src.shape(), src.dtype(), nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - next.set_status(array::Status::available); - SliceUpdate op(s, SliceUpdate::Min, start, stop, unit_strides); - op.eval_gpu({result, update_value}, next); - result = std::move(next); - break; - } - } - } - out.copy_shared_buffer(result); - return true; - } - - if (update_shape.size() == src.ndim() && - checked_mul_u32( - checked_u32_size( - idx0.size(), "scatter_pair composed index_count"), - slice_elems, - "scatter_pair composed update_size") == upd.size() && - (reduce_type == Scatter::None || reduce_type == Scatter::Sum)) { std::vector norm_axes = {axis0, axis1}; std::vector generic_inputs = {src, idx0, idx1, upd}; if (try_dispatch_generic_scatter( @@ -909,6 +537,7 @@ bool try_eval_scatter_vulkan( reduce_type)) { return true; } + return false; } } if (reduce_type != Scatter::None && reduce_type != Scatter::Sum) { @@ -1044,7 +673,6 @@ bool try_eval_scatter_vulkan( if (upd.ndim() == idx.ndim() + src.ndim()) { Shape update_shape(upd.shape().begin() + idx.ndim(), upd.shape().end()); - Shape target_slice_shape = update_shape; uint32_t slice_elems = 1; for (auto dim : update_shape) { slice_elems = checked_mul_u32( @@ -1057,7 +685,9 @@ bool try_eval_scatter_vulkan( checked_u32_size(idx.size(), "scatter composed index_count"), slice_elems, "scatter composed update_size") == upd.size() && - (reduce_type == Scatter::None || reduce_type == Scatter::Sum)) { + (reduce_type == Scatter::None || reduce_type == Scatter::Sum || + reduce_type == Scatter::Prod || reduce_type == Scatter::Max || + reduce_type == Scatter::Min)) { std::vector norm_axes = {axis}; if (try_dispatch_generic_scatter( inputs, @@ -1071,73 +701,6 @@ bool try_eval_scatter_vulkan( } return false; } - - if (update_shape.size() == src.ndim() && - checked_mul_u32( - checked_u32_size(idx.size(), "scatter composed index_count"), - slice_elems, - "scatter composed update_size") == upd.size() && - (reduce_type == Scatter::Prod || reduce_type == Scatter::Max || - reduce_type == Scatter::Min)) { - idx = ensure_host_readable_row_contiguous( - reshape(idx, {static_cast(index_count)}, s), s); - upd = ensure_row_contiguous(upd, s); - Shape flat_shape = { - static_cast(index_count), - static_cast(slice_elems)}; - array flat_upd = reshape(upd, flat_shape, s); - array result(src.shape(), src.dtype(), nullptr, {}); - result.set_data(allocator::malloc(result.nbytes())); - result.set_status(array::Status::available); - copy_gpu(src, result, source_copy_type(src), s); - - for (uint32_t i = 0; i < index_count; ++i) { - auto normalized_index = normalize_scatter_index( - read_contiguous_index(idx, i), src.shape(axis)); - Shape start(src.ndim(), 0); - Shape stop = target_slice_shape; - Shape unit_strides(src.ndim(), 1); - start[axis] = normalized_index; - stop[axis] += normalized_index; - if (stop[axis] > src.shape(axis)) { - return false; - } - - array update_value = reshape( - slice( - flat_upd, - {static_cast(i), 0}, - {static_cast(i + 1), - static_cast(slice_elems)}, - s), - update_shape, - s); - - array next(src.shape(), src.dtype(), nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - next.set_status(array::Status::available); - SliceUpdate::ReduceType op_reduce = SliceUpdate::None; - switch (reduce_type) { - case Scatter::Prod: - op_reduce = SliceUpdate::Prod; - break; - case Scatter::Max: - op_reduce = SliceUpdate::Max; - break; - case Scatter::Min: - op_reduce = SliceUpdate::Min; - break; - default: - break; - } - SliceUpdate op(s, op_reduce, start, stop, unit_strides); - op.eval_gpu({result, update_value}, next); - result = std::move(next); - } - - copy_gpu(result, out, CopyType::GeneralGeneral, s); - return true; - } } const uint32_t slice_size = take_slice_size; const uint32_t expected_update_size = checked_mul_u32( From 7a95237ec5af27a70b5d7cbb5ee91ad5b9b103d4 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sat, 20 Jun 2026 23:50:50 +0300 Subject: [PATCH 2/9] Address Vulkan scatter review feedback --- mlx/backend/vulkan/kernels/random_bits.comp | 1 + mlx/backend/vulkan/scatter.cpp | 41 ++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/mlx/backend/vulkan/kernels/random_bits.comp b/mlx/backend/vulkan/kernels/random_bits.comp index 114af0197a..667341f33d 100644 --- a/mlx/backend/vulkan/kernels/random_bits.comp +++ b/mlx/backend/vulkan/kernels/random_bits.comp @@ -3,6 +3,7 @@ #include "types.glsl" #extension GL_EXT_scalar_block_layout : require +#extension GL_EXT_shader_8bit_storage : require layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/mlx/backend/vulkan/scatter.cpp b/mlx/backend/vulkan/scatter.cpp index dd652964bc..e48dc60b42 100644 --- a/mlx/backend/vulkan/scatter.cpp +++ b/mlx/backend/vulkan/scatter.cpp @@ -72,6 +72,22 @@ bool supports_dynamic_scatter_sum_dtype(Dtype dtype) { return dtype == float32 || dtype == int32 || dtype == uint32; } +bool supports_dynamic_scatter_reduction_dtype( + Dtype dtype, + Scatter::ReduceType reduce_type) { + switch (reduce_type) { + case Scatter::None: + return true; + case Scatter::Sum: + return supports_dynamic_scatter_sum_dtype(dtype); + case Scatter::Prod: + case Scatter::Max: + case Scatter::Min: + return dtype == float32 || dtype == int32 || dtype == uint32; + } + return false; +} + std::string build_generic_scatter_shader( Dtype value_dtype, Dtype index_dtype, @@ -82,6 +98,10 @@ std::string build_generic_scatter_shader( const bool use_float_atomic_cas = value_dtype == float32 && (reduce_type == Scatter::Prod || reduce_type == Scatter::Max || reduce_type == Scatter::Min); + const bool use_integer_atomic_cas = + (value_dtype == int32 || value_dtype == uint32) && + (reduce_type == Scatter::Prod || reduce_type == Scatter::Max || + reduce_type == Scatter::Min); os << "#version 450\n"; os << "#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require\n"; @@ -177,6 +197,22 @@ std::string build_generic_scatter_shader( os << " old_bits = prev_bits;\n"; os << " }\n"; os << "}\n\n"; + } else if (use_integer_atomic_cas) { + os << "void atomic_reduce(uint dst_offset, VALUE_TYPE value) {\n"; + os << " VALUE_TYPE old_value = out_data[dst_offset];\n"; + os << " while (true) {\n"; + if (reduce_type == Scatter::Prod) { + os << " VALUE_TYPE new_value = old_value * value;\n"; + } else if (reduce_type == Scatter::Max) { + os << " VALUE_TYPE new_value = max(old_value, value);\n"; + } else { + os << " VALUE_TYPE new_value = min(old_value, value);\n"; + } + os << " VALUE_TYPE prev_value = atomicCompSwap(out_data[dst_offset], old_value, new_value);\n"; + os << " if (prev_value == old_value) break;\n"; + os << " old_value = prev_value;\n"; + os << " }\n"; + os << "}\n\n"; } os << "void main() {\n"; @@ -202,6 +238,8 @@ std::string build_generic_scatter_shader( os << "\n atomicAdd(out_data[dst_offset], upd_data[linear_idx]);\n"; } else if (use_float_atomic_cas) { os << "\n atomic_reduce(dst_offset, read_update(linear_idx));\n"; + } else if (use_integer_atomic_cas) { + os << "\n atomic_reduce(dst_offset, upd_data[linear_idx]);\n"; } else if (reduce_type == Scatter::Prod) { os << "\n out_data[dst_offset] *= upd_data[linear_idx];\n"; } else if (reduce_type == Scatter::Max) { @@ -236,8 +274,7 @@ bool try_dispatch_generic_scatter( const Dtype value_dtype = src_input.dtype(); const Dtype index_dtype = inputs[1].dtype(); - if (reduce_type == Scatter::Sum && - !supports_dynamic_scatter_sum_dtype(value_dtype)) { + if (!supports_dynamic_scatter_reduction_dtype(value_dtype, reduce_type)) { return false; } From c5875165ef778219797857d1550e239ff56fddac Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sun, 21 Jun 2026 00:02:32 +0300 Subject: [PATCH 3/9] Address Vulkan random and arange review --- mlx/backend/vulkan/kernels.cpp | 22 ++++++++++++++++------ mlx/backend/vulkan/kernels/arange.comp | 24 ++++++++++++++++++++++++ mlx/backend/vulkan/random.cpp | 2 +- 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index 1a74ecba9f..ea1dbc2e8d 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/vulkan/kernels.h" #include #include +#include #include #include #include @@ -795,7 +796,7 @@ make_arange_push_constants_t(uint32_t num_elements, double start, double step) { ArangePushConstants push_constants{}; push_constants.KX = num_elements; - push_constants.KY = 1; + push_constants.KY = 0; push_constants.start_i64 = static_cast(start_t); push_constants.step_i64 = static_cast(step_t); push_constants.start_f32 = static_cast(start_t); @@ -823,11 +824,20 @@ ArangePushConstants make_arange_push_constants( return make_arange_push_constants_t(num_elements, start, step); case int64: return make_arange_push_constants_t(num_elements, start, step); - case float16: - return make_arange_push_constants_t(num_elements, start, step); - case bfloat16: - return make_arange_push_constants_t( - num_elements, start, step); + case float16: { + auto push_constants = + make_arange_push_constants_t(num_elements, start, step); + push_constants.KY = + start == std::trunc(start) && step == std::trunc(step) ? 1 : 0; + return push_constants; + } + case bfloat16: { + auto push_constants = + make_arange_push_constants_t(num_elements, start, step); + push_constants.KY = + start == std::trunc(start) && step == std::trunc(step) ? 1 : 0; + return push_constants; + } case float32: return make_arange_push_constants_t(num_elements, start, step); default: diff --git a/mlx/backend/vulkan/kernels/arange.comp b/mlx/backend/vulkan/kernels/arange.comp index a11819de6d..f0724ccbe4 100644 --- a/mlx/backend/vulkan/kernels/arange.comp +++ b/mlx/backend/vulkan/kernels/arange.comp @@ -23,6 +23,30 @@ void main() { return; } +#ifdef DATA_D_F16 + if (p.KY != 0) { + float16_t value = float16_t(p.start_f32); + float16_t step = float16_t(p.step_f32); + for (uint j = 0; j < i; ++j) { + value = float16_t(value + step); + } + data_d[i] = value; + return; + } +#endif + +#ifdef DATA_D_BF16 + if (p.KY != 0) { + uint32_t value = fp32_to_bf16(p.start_f32); + uint32_t step = fp32_to_bf16(p.step_f32); + for (uint j = 0; j < i; ++j) { + value = fp32_to_bf16(bf16_to_fp32(value) + bf16_to_fp32(step)); + } + data_d[i] = uint16_t(value); + return; + } +#endif + float value = p.start_f32 + p.step_f32 * float(i); #ifdef DATA_D_BF16 data_d[i] = uint16_t(fp32_to_bf16(value)); diff --git a/mlx/backend/vulkan/random.cpp b/mlx/backend/vulkan/random.cpp index 88a57561bf..2f10eaac70 100644 --- a/mlx/backend/vulkan/random.cpp +++ b/mlx/backend/vulkan/random.cpp @@ -59,7 +59,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { command_buffer, stream(), push_constants, - {static_cast(num_keys), + {static_cast((num_keys + 255) / 256), static_cast(half_size + odd), 1}); vulkan::end_command_recording(stream().index); From 9e4ecb7dd9e3b5490ba611d1a29edde52925bf10 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sun, 21 Jun 2026 00:23:27 +0300 Subject: [PATCH 4/9] Address Vulkan arange and scatter review --- mlx/backend/vulkan/kernels.cpp | 13 +++++++++---- mlx/backend/vulkan/scatter.cpp | 8 ++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index ea1dbc2e8d..1e67851a0a 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -21,6 +21,7 @@ namespace mlx::core::vulkan { constexpr uint32_t kMaxMulMatVecCols = 8; +constexpr uint32_t kMaxSequentialLowPrecisionArange = 4096; uint32_t matvec_rows_per_workgroup() { static const uint32_t value = []() { @@ -827,15 +828,19 @@ ArangePushConstants make_arange_push_constants( case float16: { auto push_constants = make_arange_push_constants_t(num_elements, start, step); - push_constants.KY = - start == std::trunc(start) && step == std::trunc(step) ? 1 : 0; + const bool use_sequential_low_precision = + num_elements <= kMaxSequentialLowPrecisionArange && + start == std::trunc(start) && step == std::trunc(step); + push_constants.KY = use_sequential_low_precision ? 1 : 0; return push_constants; } case bfloat16: { auto push_constants = make_arange_push_constants_t(num_elements, start, step); - push_constants.KY = - start == std::trunc(start) && step == std::trunc(step) ? 1 : 0; + const bool use_sequential_low_precision = + num_elements <= kMaxSequentialLowPrecisionArange && + start == std::trunc(start) && step == std::trunc(step); + push_constants.KY = use_sequential_low_precision ? 1 : 0; return push_constants; } case float32: diff --git a/mlx/backend/vulkan/scatter.cpp b/mlx/backend/vulkan/scatter.cpp index e48dc60b42..530ac2f168 100644 --- a/mlx/backend/vulkan/scatter.cpp +++ b/mlx/backend/vulkan/scatter.cpp @@ -574,7 +574,9 @@ bool try_eval_scatter_vulkan( reduce_type)) { return true; } - return false; + if (reduce_type != Scatter::None && reduce_type != Scatter::Sum) { + return false; + } } } if (reduce_type != Scatter::None && reduce_type != Scatter::Sum) { @@ -736,7 +738,9 @@ bool try_eval_scatter_vulkan( reduce_type)) { return true; } - return false; + if (reduce_type != Scatter::None && reduce_type != Scatter::Sum) { + return false; + } } } const uint32_t slice_size = take_slice_size; From fb5d93b5ce484d0fdb5056d9df4d7c101ced1185 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sun, 21 Jun 2026 01:26:36 +0300 Subject: [PATCH 5/9] Address Vulkan gather and arange review --- mlx/backend/vulkan/gather.cpp | 4 ++++ mlx/backend/vulkan/kernels.cpp | 3 --- mlx/backend/vulkan/kernels/arange.comp | 14 ++++++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mlx/backend/vulkan/gather.cpp b/mlx/backend/vulkan/gather.cpp index d7ea9df47a..6e36dddba7 100644 --- a/mlx/backend/vulkan/gather.cpp +++ b/mlx/backend/vulkan/gather.cpp @@ -428,6 +428,10 @@ bool try_dispatch_generic_gather( const int nidx = static_cast(norm_axes.size()); const Dtype value_dtype = src_input.dtype(); const Dtype index_dtype = inputs[1].dtype(); + if (index_dtype != int32 && index_dtype != uint32 && index_dtype != int64 && + index_dtype != uint64) { + return false; + } std::vector flat_indices; flat_indices.reserve(nidx); diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index 1e67851a0a..8720fadf93 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -21,7 +21,6 @@ namespace mlx::core::vulkan { constexpr uint32_t kMaxMulMatVecCols = 8; -constexpr uint32_t kMaxSequentialLowPrecisionArange = 4096; uint32_t matvec_rows_per_workgroup() { static const uint32_t value = []() { @@ -829,7 +828,6 @@ ArangePushConstants make_arange_push_constants( auto push_constants = make_arange_push_constants_t(num_elements, start, step); const bool use_sequential_low_precision = - num_elements <= kMaxSequentialLowPrecisionArange && start == std::trunc(start) && step == std::trunc(step); push_constants.KY = use_sequential_low_precision ? 1 : 0; return push_constants; @@ -838,7 +836,6 @@ ArangePushConstants make_arange_push_constants( auto push_constants = make_arange_push_constants_t(num_elements, start, step); const bool use_sequential_low_precision = - num_elements <= kMaxSequentialLowPrecisionArange && start == std::trunc(start) && step == std::trunc(step); push_constants.KY = use_sequential_low_precision ? 1 : 0; return push_constants; diff --git a/mlx/backend/vulkan/kernels/arange.comp b/mlx/backend/vulkan/kernels/arange.comp index f0724ccbe4..f8622243e0 100644 --- a/mlx/backend/vulkan/kernels/arange.comp +++ b/mlx/backend/vulkan/kernels/arange.comp @@ -25,24 +25,30 @@ void main() { #ifdef DATA_D_F16 if (p.KY != 0) { + if (i != 0) { + return; + } float16_t value = float16_t(p.start_f32); float16_t step = float16_t(p.step_f32); - for (uint j = 0; j < i; ++j) { + for (uint j = 0; j < p.KX; ++j) { + data_d[j] = value; value = float16_t(value + step); } - data_d[i] = value; return; } #endif #ifdef DATA_D_BF16 if (p.KY != 0) { + if (i != 0) { + return; + } uint32_t value = fp32_to_bf16(p.start_f32); uint32_t step = fp32_to_bf16(p.step_f32); - for (uint j = 0; j < i; ++j) { + for (uint j = 0; j < p.KX; ++j) { + data_d[j] = uint16_t(value); value = fp32_to_bf16(bf16_to_fp32(value) + bf16_to_fp32(step)); } - data_d[i] = uint16_t(value); return; } #endif From a304c37f30c9714a24db1333f607e5bbdf663d2b Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sun, 21 Jun 2026 23:33:07 +0300 Subject: [PATCH 6/9] Address Vulkan random and scatter review --- mlx/backend/vulkan/kernels/random_bits.comp | 12 ++++++------ mlx/backend/vulkan/random.cpp | 19 +++++++++++++++++-- mlx/backend/vulkan/scatter.cpp | 4 ++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/mlx/backend/vulkan/kernels/random_bits.comp b/mlx/backend/vulkan/kernels/random_bits.comp index 667341f33d..242676425e 100644 --- a/mlx/backend/vulkan/kernels/random_bits.comp +++ b/mlx/backend/vulkan/kernels/random_bits.comp @@ -46,8 +46,7 @@ uvec2 threefry2x32_hash(uvec2 key, uvec2 count) { void main() { uint key_idx = gl_GlobalInvocationID.x; - uint half_idx = gl_GlobalInvocationID.y; - uint grid_y = gl_NumWorkGroups.y; + uint half_idx = gl_GlobalInvocationID.y + gl_WorkGroupID.z * gl_NumWorkGroups.y; if (key_idx >= p.num_keys) { return; @@ -55,8 +54,9 @@ void main() { uint half_size = p.out_skip / 2; bool even = (p.out_skip % 2) == 0; + uint logical_grid_y = half_size + (even ? 0 : 1); - if (half_idx >= half_size + (even ? 0 : 1)) { + if (half_idx >= logical_grid_y) { return; } @@ -69,7 +69,7 @@ void main() { // Generate random bits bool drop_last = !even && (half_idx == half_size); - uvec2 count = uvec2(half_idx, drop_last ? 0 : half_idx + grid_y); + uvec2 count = uvec2(half_idx, drop_last ? 0 : half_idx + logical_grid_y); uvec2 bits = threefry2x32_hash(key, count); uint key_byte_offset = key_idx * p.bytes_per_key; @@ -90,9 +90,9 @@ void main() { // Write second 4 bytes (if not dropping last). if (!drop_last) { - uint idx2 = out_offset + half_idx + grid_y; + uint idx2 = out_offset + half_idx + logical_grid_y; if (idx2 < out_offset + p.out_skip) { - WRITE_WORD(half_idx + grid_y, bits.y); + WRITE_WORD(half_idx + logical_grid_y, bits.y); } } } diff --git a/mlx/backend/vulkan/random.cpp b/mlx/backend/vulkan/random.cpp index 2f10eaac70..19df93bcd0 100644 --- a/mlx/backend/vulkan/random.cpp +++ b/mlx/backend/vulkan/random.cpp @@ -1,8 +1,10 @@ // Copyright © 2024 Apple Inc. +#include #include #include "mlx/backend/vulkan/primitives_utils.h" +#include "mlx/backend/vulkan/vulkan.h" namespace mlx::core { @@ -51,6 +53,19 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { odd ? 1u : 0u, static_cast(out_skip)}; + const auto limits = vulkan::VulkanContext::get() + .physical_device() + .getProperties() + .limits; + const uint32_t half_work = static_cast(half_size + odd); + const uint32_t grid_y = std::min(half_work, limits.maxComputeWorkGroupCount[1]); + const uint32_t grid_z = + static_cast((half_work + grid_y - 1) / grid_y); + if (grid_z > limits.maxComputeWorkGroupCount[2]) { + throw std::runtime_error( + "RandomBits failed on Vulkan (dispatch shape too large)."); + } + auto command_buffer = vulkan::begin_command_recording(stream().index); vulkan::dispatch_random_bits_op( keys, @@ -60,8 +75,8 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { stream(), push_constants, {static_cast((num_keys + 255) / 256), - static_cast(half_size + odd), - 1}); + grid_y, + grid_z}); vulkan::end_command_recording(stream().index); } diff --git a/mlx/backend/vulkan/scatter.cpp b/mlx/backend/vulkan/scatter.cpp index 530ac2f168..92e7048edb 100644 --- a/mlx/backend/vulkan/scatter.cpp +++ b/mlx/backend/vulkan/scatter.cpp @@ -273,6 +273,10 @@ bool try_dispatch_generic_scatter( const int nidx = static_cast(norm_axes.size()); const Dtype value_dtype = src_input.dtype(); const Dtype index_dtype = inputs[1].dtype(); + if (index_dtype != int32 && index_dtype != uint32 && index_dtype != int64 && + index_dtype != uint64) { + return false; + } if (!supports_dynamic_scatter_reduction_dtype(value_dtype, reduce_type)) { return false; From 9d129bb8d7c72f9a88ff3522f933900fcdc503b5 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sun, 21 Jun 2026 23:46:51 +0300 Subject: [PATCH 7/9] Address Vulkan scatter fallback review --- mlx/backend/vulkan/random.cpp | 9 +- mlx/backend/vulkan/scatter.cpp | 151 +++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 3 deletions(-) diff --git a/mlx/backend/vulkan/random.cpp b/mlx/backend/vulkan/random.cpp index 19df93bcd0..9a61e1d0f8 100644 --- a/mlx/backend/vulkan/random.cpp +++ b/mlx/backend/vulkan/random.cpp @@ -61,6 +61,11 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { const uint32_t grid_y = std::min(half_work, limits.maxComputeWorkGroupCount[1]); const uint32_t grid_z = static_cast((half_work + grid_y - 1) / grid_y); + const uint32_t grid_x = static_cast((num_keys + 255) / 256); + if (grid_x > limits.maxComputeWorkGroupCount[0]) { + throw std::runtime_error( + "RandomBits failed on Vulkan (dispatch shape too large)."); + } if (grid_z > limits.maxComputeWorkGroupCount[2]) { throw std::runtime_error( "RandomBits failed on Vulkan (dispatch shape too large)."); @@ -74,9 +79,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { command_buffer, stream(), push_constants, - {static_cast((num_keys + 255) / 256), - grid_y, - grid_z}); + {grid_x, grid_y, grid_z}); vulkan::end_command_recording(stream().index); } diff --git a/mlx/backend/vulkan/scatter.cpp b/mlx/backend/vulkan/scatter.cpp index 92e7048edb..2ecaa42601 100644 --- a/mlx/backend/vulkan/scatter.cpp +++ b/mlx/backend/vulkan/scatter.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. +#include #include #include "mlx/backend/common/slicing.h" @@ -66,6 +67,130 @@ checked_shape_product(const array& arr, int begin, int end, const char* label) { return product; } +bool is_host_readable_index_constant(const array& idx) { + auto data = idx.data_shared_ptr(); + return !idx.has_primitive() && data != nullptr && data->buffer.ptr() != nullptr && + !vulkan::is_vulkan_buffer(data->buffer) && idx.flags().row_contiguous && + idx.offset() == 0 && idx.data_size() == idx.size(); +} + +int64_t read_contiguous_index(const array& idx, int i) { + switch (idx.dtype()) { + case int8: + return idx.data()[i]; + case int16: + return idx.data()[i]; + case int32: + return idx.data()[i]; + case int64: + return idx.data()[i]; + case uint8: + return idx.data()[i]; + case uint16: + return idx.data()[i]; + case uint32: + return idx.data()[i]; + case uint64: { + auto value = idx.data()[i]; + if (value > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("uint64 index exceeds max int64_t value"); + } + return static_cast(value); + } + default: + throw std::runtime_error("Unsupported index dtype for Vulkan scatter."); + } +} + +int64_t normalize_scatter_index(int64_t idx, int64_t axis_size) { + if (idx < 0) { + idx += axis_size; + } + return idx; +} + +SliceUpdate::ReduceType slice_update_reduce_type( + Scatter::ReduceType reduce_type) { + switch (reduce_type) { + case Scatter::Prod: + return SliceUpdate::Prod; + case Scatter::Max: + return SliceUpdate::Max; + case Scatter::Min: + return SliceUpdate::Min; + case Scatter::Sum: + return SliceUpdate::Sum; + case Scatter::None: + return SliceUpdate::None; + } + return SliceUpdate::None; +} + +bool try_slice_update_scatter_composed( + const array& src, + const std::vector& indices, + const array& upd, + array& out, + const std::vector& axes, + const Shape& update_shape, + uint32_t index_count, + uint32_t slice_elems, + Scatter::ReduceType reduce_type, + Stream s) { + for (const auto& idx : indices) { + if (!is_host_readable_index_constant(idx)) { + return false; + } + } + + array flat_upd = reshape( + ensure_row_contiguous(upd, s), + {static_cast(index_count), + static_cast(slice_elems)}, + s); + array result(src.shape(), src.dtype(), nullptr, {}); + result.set_data(allocator::malloc(result.nbytes())); + result.set_status(array::Status::available); + copy_gpu(src, result, source_copy_type(src), s); + + const auto op_reduce = slice_update_reduce_type(reduce_type); + for (uint32_t i = 0; i < index_count; ++i) { + Shape start(src.ndim(), 0); + Shape stop = update_shape; + Shape unit_strides(src.ndim(), 1); + for (int j = 0; j < axes.size(); ++j) { + const int axis = axes[j]; + const auto normalized_index = normalize_scatter_index( + read_contiguous_index(indices[j], i), src.shape(axis)); + start[axis] = normalized_index; + stop[axis] += normalized_index; + if (stop[axis] > src.shape(axis)) { + return false; + } + } + + array update_value = reshape( + slice( + flat_upd, + {static_cast(i), 0}, + {static_cast(i + 1), + static_cast(slice_elems)}, + s), + update_shape, + s); + + array next(src.shape(), src.dtype(), nullptr, {}); + next.set_data(allocator::malloc(next.nbytes())); + next.set_status(array::Status::available); + SliceUpdate op(s, op_reduce, start, stop, unit_strides); + op.eval_gpu({result, update_value}, next); + result = std::move(next); + } + + copy_gpu(result, out, CopyType::GeneralGeneral, s); + return true; +} + constexpr uint32_t kMaxScatterPushConstants = 128; bool supports_dynamic_scatter_sum_dtype(Dtype dtype) { @@ -579,6 +704,19 @@ bool try_eval_scatter_vulkan( return true; } if (reduce_type != Scatter::None && reduce_type != Scatter::Sum) { + if (try_slice_update_scatter_composed( + src, + {idx0, idx1}, + upd, + out, + norm_axes, + update_shape, + index_count, + slice_elems, + reduce_type, + s)) { + return true; + } return false; } } @@ -743,6 +881,19 @@ bool try_eval_scatter_vulkan( return true; } if (reduce_type != Scatter::None && reduce_type != Scatter::Sum) { + if (try_slice_update_scatter_composed( + src, + {idx}, + upd, + out, + norm_axes, + update_shape, + index_count, + slice_elems, + reduce_type, + s)) { + return true; + } return false; } } From 9ac80fd6a2fd32f6d90ce76f05244911afe751a7 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Sun, 21 Jun 2026 23:57:57 +0300 Subject: [PATCH 8/9] Address Vulkan random key and arange review --- mlx/backend/vulkan/kernels.cpp | 6 ++++++ mlx/backend/vulkan/random.cpp | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index 8720fadf93..485fecd5f4 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -829,6 +829,9 @@ ArangePushConstants make_arange_push_constants( make_arange_push_constants_t(num_elements, start, step); const bool use_sequential_low_precision = start == std::trunc(start) && step == std::trunc(step); + if (use_sequential_low_precision) { + push_constants.step_f32 = static_cast(float16_t(step)); + } push_constants.KY = use_sequential_low_precision ? 1 : 0; return push_constants; } @@ -837,6 +840,9 @@ ArangePushConstants make_arange_push_constants( make_arange_push_constants_t(num_elements, start, step); const bool use_sequential_low_precision = start == std::trunc(start) && step == std::trunc(step); + if (use_sequential_low_precision) { + push_constants.step_f32 = static_cast(bfloat16_t(step)); + } push_constants.KY = use_sequential_low_precision ? 1 : 0; return push_constants; } diff --git a/mlx/backend/vulkan/random.cpp b/mlx/backend/vulkan/random.cpp index 9a61e1d0f8..bf23e58d38 100644 --- a/mlx/backend/vulkan/random.cpp +++ b/mlx/backend/vulkan/random.cpp @@ -31,7 +31,8 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { "RandomBits failed on Vulkan (only uint32 keys supported)."); } - if (!keys.flags().contiguous || keys.offset() != 0 || + if (!vulkan::is_vulkan_storage_array(keys) || !keys.flags().contiguous || + keys.offset() != 0 || keys.strides().back() != 1) { keys = contiguous_copy_gpu(keys, stream()); } From ee3007477388c99615beae01797ff72164d5cf41 Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Mon, 22 Jun 2026 00:11:49 +0300 Subject: [PATCH 9/9] Fix random bits shader extension order --- mlx/backend/vulkan/kernels/random_bits.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/vulkan/kernels/random_bits.comp b/mlx/backend/vulkan/kernels/random_bits.comp index 242676425e..6be271c6f3 100644 --- a/mlx/backend/vulkan/kernels/random_bits.comp +++ b/mlx/backend/vulkan/kernels/random_bits.comp @@ -1,10 +1,10 @@ #version 450 -#include "types.glsl" - #extension GL_EXT_scalar_block_layout : require #extension GL_EXT_shader_8bit_storage : require +#include "types.glsl" + layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Params {