Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions mlx/backend/vulkan/arange.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
// Copyright © 2024 Apple Inc.

#include <cmath>
#include <vector>

#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/vulkan/primitives_utils.h"
#include "mlx/backend/vulkan/vulkan.h"
#include "mlx/dtype.h"
Expand Down Expand Up @@ -49,36 +45,6 @@ bool try_eval_arange_vulkan(
return false;
}

if ((out.dtype() == float16 || out.dtype() == bfloat16) &&
start == std::trunc(start) && step == std::trunc(step)) {
// Match CPU low-precision arange semantics by advancing in the target dtype
// instead of recomputing each element from float32 math in the shader.
const auto n = out.size();
out.set_data(allocator::malloc(out.nbytes()));
if (n == 0) {
return true;
}
if (out.dtype() == float16) {
auto* dst = out.data<float16_t>();
float16_t value(static_cast<float>(start));
const float16_t step_value(static_cast<float>(step));
for (size_t i = 0; i < n; ++i) {
dst[i] = value;
value = float16_t(static_cast<float>(value) + static_cast<float>(step_value));
}
return true;
}

auto* dst = out.data<bfloat16_t>();
bfloat16_t value(static_cast<float>(start));
const bfloat16_t step_value(static_cast<float>(step));
for (size_t i = 0; i < n; ++i) {
dst[i] = value;
value = bfloat16_t(static_cast<float>(value) + static_cast<float>(step_value));
}
return true;
}

auto shader_id = arange_shader_id(out.dtype());
if (!shader_id.has_value()) {
return false;
Expand Down
268 changes: 66 additions & 202 deletions mlx/backend/vulkan/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,6 @@ bool ensure_vulkan_storage(array& arr, Stream s) {
return vulkan::is_vulkan_storage_array(arr);
}

array ensure_host_readable_row_contiguous(array arr, Stream s) {
if (arr.has_primitive()) {
arr.eval();
}
if (needs_row_contiguous(arr)) {
arr = contiguous_copy_gpu(arr, s);
}
arr.wait();
return arr;
}

std::pair<array, bool> make_output_work(array& out) {
const bool staged_output = needs_row_contiguous(out);
array out_work =
Expand All @@ -69,6 +58,12 @@ checked_shape_product(const array& arr, int begin, int end, const char* label) {
return product;
}

bool is_host_readable_index_constant(const array& idx) {
auto data = idx.data_shared_ptr();
return !idx.has_primitive() && data != nullptr && data->buffer.ptr() != nullptr &&
!vulkan::is_vulkan_buffer(data->buffer);
}

std::string build_complex_gather_axis_shader() {
std::ostringstream os;
os << vulkan::emit_dynamic_shader_preamble(complex64, complex64, false);
Expand Down Expand Up @@ -257,7 +252,7 @@ bool try_eval_i64_gather_axis_vulkan(
}

std::optional<int64_t> scalar_index_value(const array& idx) {
if (idx.ndim() != 0) {
if (idx.ndim() != 0 || !is_host_readable_index_constant(idx)) {
return std::nullopt;
}
switch (idx.dtype()) {
Expand All @@ -280,7 +275,7 @@ std::optional<int64_t> scalar_index_value(const array& idx) {
}

std::optional<int64_t> singleton_index_value(const array& idx) {
if (idx.size() != 1) {
if (idx.size() != 1 || !is_host_readable_index_constant(idx)) {
return std::nullopt;
}
switch (idx.dtype()) {
Expand Down Expand Up @@ -314,44 +309,6 @@ int64_t normalize_gather_index(int64_t idx, int64_t axis_size) {
return idx;
}

int64_t read_contiguous_index(const array& idx, int i) {
switch (idx.dtype()) {
case int32:
return idx.data<int32_t>()[i];
case int64:
return idx.data<int64_t>()[i];
case uint32:
return idx.data<uint32_t>()[i];
case uint64: {
auto val = idx.data<uint64_t>()[i];
if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
throw std::runtime_error("uint64 index exceeds max int64_t value");
}
return static_cast<int64_t>(val);
}
default:
throw std::runtime_error("Unsupported index dtype for Vulkan gather.");
}
}

bool is_full_range_index_for_axis(
const array& idx,
int64_t axis_size,
Stream s) {
if (axis_size <= 0 || idx.size() == 0 || (idx.size() % axis_size) != 0) {
return false;
}
auto flat_idx = ensure_row_contiguous(
reshape(idx, {static_cast<ShapeElem>(idx.size())}, s), s);
flat_idx.eval();
for (int i = 0; i < flat_idx.size(); ++i) {
if (read_contiguous_index(flat_idx, i) != (i % axis_size)) {
return false;
}
}
return true;
}

constexpr uint32_t kMaxGatherPushConstants = 128;

std::string build_generic_gather_shader(
Expand Down Expand Up @@ -471,6 +428,10 @@ bool try_dispatch_generic_gather(
const int nidx = static_cast<int>(norm_axes.size());
const Dtype value_dtype = src_input.dtype();
const Dtype index_dtype = inputs[1].dtype();
if (index_dtype != int32 && index_dtype != uint32 && index_dtype != int64 &&
index_dtype != uint64) {
return false;
}

std::vector<array> flat_indices;
flat_indices.reserve(nidx);
Expand Down Expand Up @@ -670,83 +631,8 @@ bool try_eval_gather_vulkan(
return true;
}

if (trace_fallback_enabled()) {
trace_fallback("generic_gather_gpu_unavailable fallback=host_loop");
}

std::vector<array> flat_indices;
flat_indices.reserve(inputs.size() - 1);
for (int i = 1; i < inputs.size(); ++i) {
flat_indices.push_back(ensure_host_readable_row_contiguous(
reshape(inputs[i], {static_cast<ShapeElem>(index_count)}, s), s));
}

auto [out_work, staged_output] = make_output_work(out);
if (out_work.size() == 0) {
if (staged_output) {
copy_gpu(out_work, out, CopyType::GeneralGeneral, s);
}
return true;
}

Strides out_slice_strides(
out_work.strides().begin() + idx_ndim, out_work.strides().end());
size_t out_slice_elems = 1;
for (auto dim : out_slice_shape) {
out_slice_elems *= static_cast<size_t>(dim);
}
auto [out_slice_data_size, out_slice_row_contig, out_slice_col_contig] =
check_contiguity(out_slice_shape, out_slice_strides);
array::Flags out_slice_flags = {
out_slice_data_size == out_slice_elems,
out_slice_row_contig,
out_slice_col_contig};

Strides index_shape_strides(idx_ndim, 1);
for (int i = idx_ndim - 2; i >= 0; --i) {
index_shape_strides[i] =
index_shape_strides[i + 1] * inputs[1].shape(i + 1);
}

for (uint32_t i = 0; i < index_count; ++i) {
Shape start(src_input.ndim(), 0);
Shape stop = slice_sizes;
Shape unit_strides(src_input.ndim(), 1);
for (int j = 0; j < norm_axes.size(); ++j) {
const int axis = norm_axes[j];
start[axis] = normalize_gather_index(
read_contiguous_index(flat_indices[j], i), src_input.shape(axis));
stop[axis] += start[axis];
if (stop[axis] > src_input.shape(axis)) {
return false;
}
}

array gathered = slice(src_input, start, stop, unit_strides, s);

int64_t out_offset = 0;
size_t remainder = i;
for (int d = 0; d < idx_ndim; ++d) {
const size_t coord = remainder / index_shape_strides[d];
remainder %= index_shape_strides[d];
out_offset += coord * out_work.strides(d);
}

array out_slice(out_slice_shape, out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out_work,
out_slice_strides,
out_slice_flags,
out_slice_data_size,
out_offset);
out_slice.set_status(array::Status::available);
copy_gpu_inplace(gathered, out_slice, CopyType::GeneralGeneral, s);
}

if (staged_output) {
copy_gpu(out_work, out, CopyType::GeneralGeneral, s);
}
return true;
trace_vulkan_unsupported("Gather", "generic gather GPU dispatch failed");
return false;
}

if (axes.size() == 2) {
Expand Down Expand Up @@ -803,78 +689,8 @@ bool try_eval_gather_vulkan(
return true;
}

idx0 = ensure_host_readable_row_contiguous(
reshape(idx0, {static_cast<ShapeElem>(idx0.size())}, s), s);
idx1 = ensure_host_readable_row_contiguous(
reshape(idx1, {static_cast<ShapeElem>(idx1.size())}, s), s);

auto [out_work, staged_output] = make_output_work(out);
if (out_work.size() == 0) {
if (staged_output) {
copy_gpu(out_work, out, CopyType::GeneralGeneral, s);
}
return true;
}

Strides out_slice_strides(
out_work.strides().begin() + idx_ndim, out_work.strides().end());
size_t out_slice_elems = 1;
for (auto dim : out_slice_shape) {
out_slice_elems *= static_cast<size_t>(dim);
}
auto [out_slice_data_size, out_slice_row_contig, out_slice_col_contig] =
check_contiguity(out_slice_shape, out_slice_strides);
array::Flags out_slice_flags = {
out_slice_data_size == out_slice_elems,
out_slice_row_contig,
out_slice_col_contig};

Strides index_shape_strides(idx_ndim, 1);
for (int i = idx_ndim - 2; i >= 0; --i) {
index_shape_strides[i] =
index_shape_strides[i + 1] * inputs[1].shape(i + 1);
}

for (uint32_t i = 0; i < index_count; ++i) {
Shape start(src_input.ndim(), 0);
Shape stop = slice_sizes;
Shape unit_strides(src_input.ndim(), 1);
start[axis0] = normalize_gather_index(
read_contiguous_index(idx0, i), src_input.shape(axis0));
start[axis1] = normalize_gather_index(
read_contiguous_index(idx1, i), src_input.shape(axis1));
stop[axis0] += start[axis0];
stop[axis1] += start[axis1];
if (stop[axis0] > src_input.shape(axis0) ||
stop[axis1] > src_input.shape(axis1)) {
return false;
}

array gathered = slice(src_input, start, stop, unit_strides, s);

int64_t out_offset = 0;
size_t remainder = i;
for (int d = 0; d < idx_ndim; ++d) {
const size_t coord = remainder / index_shape_strides[d];
remainder %= index_shape_strides[d];
out_offset += coord * out_work.strides(d);
}

array out_slice(out_slice_shape, out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out_work,
out_slice_strides,
out_slice_flags,
out_slice_data_size,
out_offset);
out_slice.set_status(array::Status::available);
copy_gpu_inplace(gathered, out_slice, CopyType::GeneralGeneral, s);
}

if (staged_output) {
copy_gpu(out_work, out, CopyType::GeneralGeneral, s);
}
return true;
trace_vulkan_unsupported("Gather", "pair gather GPU dispatch failed");
return false;
}

array src = ensure_row_contiguous(src_input, s);
Expand Down Expand Up @@ -950,6 +766,10 @@ bool try_eval_gather_vulkan(
trace_vulkan_unsupported("Gather", "axis is out of range");
return false;
}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return true;
}
if (auto scalar_index = scalar_index_value(idx); scalar_index.has_value()) {
Shape start(src_input.ndim(), 0);
Shape stop = slice_sizes;
Expand Down Expand Up @@ -1000,17 +820,61 @@ bool try_eval_gather_vulkan(
s);
return true;
}

bool take_like_single_axis = true;
for (int i = 0; i < src_input.ndim(); ++i) {
const int64_t expected = (i == axis) ? 1 : src_input.shape(i);
if (slice_sizes[i] != expected) {
trace_vulkan_unsupported(
"Gather", "only take-like single-axis gathers are supported");
take_like_single_axis = false;
break;
}
}
if (!take_like_single_axis) {
const int idx_ndim = idx.ndim();
if (out.ndim() != idx_ndim + src_input.ndim()) {
return false;
}
Shape out_slice_shape(out.shape().begin() + idx_ndim, out.shape().end());
if (out_slice_shape != slice_sizes) {
return false;
}
std::vector<array> generic_inputs = {src_input, idx};
std::vector<int> norm_axes = {axis};
if (try_dispatch_generic_gather(
generic_inputs,
norm_axes,
slice_sizes,
idx_ndim,
checked_u32_size(idx.size(), "gather_single_axis index_count"),
out,
s)) {
return true;
}
trace_vulkan_unsupported(
"Gather", "single-axis generic gather GPU dispatch failed");
return false;
}

const auto shader_id = gather_shader_id(src_input.dtype(), idx.dtype());
if (!shader_id.has_value()) {
const int idx_ndim = idx.ndim();
if (out.ndim() == idx_ndim + src_input.ndim()) {
Shape out_slice_shape(out.shape().begin() + idx_ndim, out.shape().end());
if (out_slice_shape == slice_sizes) {
std::vector<array> generic_inputs = {src_input, idx};
std::vector<int> norm_axes = {axis};
if (try_dispatch_generic_gather(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep narrow indices out of generic gather

When a take-like gather uses int8/int16/uint8/uint16 indices, gather_shader_id is empty, so this new fallback dispatches the generic shader. That shader maps every non-int64/uint64/uint32 index dtype to int and does not enable 8/16-bit index storage, so an int16 index buffer is read as 32-bit entries; e.g. take(x, array([1, 2], int16), axis) can gather from packed pairs/out-of-bounds instead of elements 1 and 2. Either reject these index widths here or emit the correct index storage type.

Useful? React with 👍 / 👎.

generic_inputs,
norm_axes,
slice_sizes,
idx_ndim,
checked_u32_size(idx.size(), "gather_take_generic index_count"),
out,
s)) {
return true;
}
}
}
trace_vulkan_unsupported(
"Gather",
"value/index dtype combination is not supported by Vulkan gather");
Expand Down
Loading