From 53e3ff9dc0bbeea0e226be8e4624939571397900 Mon Sep 17 00:00:00 2001 From: Vedant Date: Fri, 22 May 2026 14:44:51 +0530 Subject: [PATCH 1/7] Route large 1D dot products through batched matmul --- mlx/ops.cpp | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 410f0a267f..a2284e2a67 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -58,6 +58,28 @@ Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } +array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) { + constexpr int kChunkSize = 16384; + + int main_size = (a.size() / kChunkSize) * kChunkSize; + int batches = main_size / kChunkSize; + + auto a_main = + reshape(slice(a, {0}, {main_size}, s), {batches, 1, kChunkSize}, s); + auto b_main = + reshape(slice(b, {0}, {main_size}, s), {batches, kChunkSize, 1}, s); + + // Route large 1D dot products through batched matmul so gemv parallelizes + array total = sum(reshape(matmul(a_main, b_main, s), {batches}, s), false, s); + if (main_size == a.size()) { + return total; + } + + auto a_tail = slice(a, {main_size}, {static_cast(a.size())}, s); + auto b_tail = slice(b, {main_size}, {static_cast(b.size())}, s); + return add(total, sum(multiply(a_tail, b_tail, s), false, s), s); +} + array indices_or_default( std::optional indices, const array& x, @@ -5466,6 +5488,18 @@ array tensordot( } } + auto stream = to_stream(s); + auto device = stream.device; + if (a.has_primitive()) { + device = a.primitive().stream().device; + } else if (b.has_primitive()) { + device = b.primitive().stream().device; + } + if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 && + axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 16384) { + return inner_1d_gpu_chunked(a, b, stream); + } + std::vector cdims1(x.ndim(), false); std::vector cdims2(y.ndim(), false); for (const auto n : axes_a) { From 88fdfb6dfbb15b8bafa57ff504c4c7c0aeac210e Mon Sep 17 00:00:00 2001 From: Vedant Date: Fri, 22 May 2026 15:24:57 +0530 Subject: [PATCH 2/7] Change chunk size --- mlx/ops.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a2284e2a67..febcab3151 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -59,7 +59,7 @@ Dtype at_least_float(const Dtype& d) { } array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) { - constexpr int kChunkSize = 16384; + constexpr int kChunkSize = 4096; int main_size = (a.size() / kChunkSize) * kChunkSize; int batches = main_size / kChunkSize; @@ -5496,7 +5496,7 @@ array tensordot( device = b.primitive().stream().device; } if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 && - axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 16384) { + axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 4096) { return inner_1d_gpu_chunked(a, b, stream); } From 47356cbf23d8d5c84aadbc7bb7ed831d03455b32 Mon Sep 17 00:00:00 2001 From: Vedant Date: Sun, 24 May 2026 10:44:58 +0530 Subject: [PATCH 3/7] Remove operand level checking --- mlx/ops.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index febcab3151..69bf1560d2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -5489,14 +5489,9 @@ array tensordot( } auto stream = to_stream(s); - auto device = stream.device; - if (a.has_primitive()) { - device = a.primitive().stream().device; - } else if (b.has_primitive()) { - device = b.primitive().stream().device; - } if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 && - axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 4096) { + axes_b.size() == 1 && stream.device == Device::gpu && + csize >= 32 * 4096) { return inner_1d_gpu_chunked(a, b, stream); } From 4085c5f10ed125f80d6fdc5659a38f26d91fd909 Mon Sep 17 00:00:00 2001 From: Vedant Date: Tue, 16 Jun 2026 21:57:48 +0530 Subject: [PATCH 4/7] specialised kernel --- mlx/backend/metal/kernels/CMakeLists.txt | 2 + mlx/backend/metal/kernels/dot.metal | 123 +++++++++++++++++++++++ mlx/backend/metal/matmul.cpp | 74 ++++++++++++++ mlx/ops.cpp | 29 ------ 4 files changed, 199 insertions(+), 29 deletions(-) create mode 100644 mlx/backend/metal/kernels/dot.metal diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index edc169eec4..446f5c14d7 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -48,6 +48,8 @@ endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) +build_kernel(dot) +build_kernel(gemv steel/utils.h) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) diff --git a/mlx/backend/metal/kernels/dot.metal b/mlx/backend/metal/kernels/dot.metal new file mode 100644 index 0000000000..98c61f76b0 --- /dev/null +++ b/mlx/backend/metal/kernels/dot.metal @@ -0,0 +1,123 @@ +// Copyright © 2026 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void dot_product( + const device T* a [[buffer(0)]], + const device T* b [[buffer(1)]], + device float* output [[buffer(2)]], + const constant int& n [[buffer(3)]], + uint gid [[thread_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint tg_id [[threadgroup_position_in_grid]]) { + constexpr int ITEMS_PER_THREAD = 32; + + int start = gid * ITEMS_PER_THREAD; + + float c0 = 0.0f; + float c1 = 0.0f; + float c2 = 0.0f; + float c3 = 0.0f; + + // Fast path: no per-element branches. + if (start + ITEMS_PER_THREAD <= n) { + MLX_MTL_PRAGMA_UNROLL + for (int i = 0; i < ITEMS_PER_THREAD; i += 4) { + c0 += float(a[start + i + 0]) * float(b[start + i + 0]); + c1 += float(a[start + i + 1]) * float(b[start + i + 1]); + c2 += float(a[start + i + 2]) * float(b[start + i + 2]); + c3 += float(a[start + i + 3]) * float(b[start + i + 3]); + } + } else { + // Tail path only for the last few threads. + MLX_MTL_PRAGMA_UNROLL + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + int idx = start + i; + if (idx < n) { + float v = float(a[idx]) * float(b[idx]); + switch (i & 3) { + case 0: + c0 += v; + break; + case 1: + c1 += v; + break; + case 2: + c2 += v; + break; + default: + c3 += v; + break; + } + } + } + } + + threadgroup float smem[16]; + + float c = c0 + c1 + c2 + c3; + c = simd_sum(c); + + if (lane == 0) { + smem[simd_id] = c; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < 16) { + c = smem[tid]; + c = simd_sum(c); + if (tid == 0) { + output[tg_id] = c; + } + } +} + +template +[[kernel]] void dot_reduce( + const device float* input [[buffer(0)]], + device T* output [[buffer(1)]], + const constant int& n [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint tg_id [[threadgroup_position_in_grid]]) { + float c = gid < uint(n) ? float(input[gid]) : 0.0f; + + threadgroup float smem[16]; + + c = simd_sum(c); + if (lane == 0) { + smem[simd_id] = c; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < 16) { + c = smem[tid]; + c = simd_sum(c); + if (tid == 0) { + output[tg_id] = T(c); + } + } +} + +#define instantiate_dot_product_kernel(name, itype) \ + instantiate_kernel("dot_product_" #name, dot_product, itype) + +#define instantiate_dot_reduce_kernel(name, otype) \ + instantiate_kernel("dot_reduce_" #name, dot_reduce, otype) + +instantiate_dot_product_kernel(float32, float); +instantiate_dot_product_kernel(float16, half); +instantiate_dot_product_kernel(bfloat16, bfloat16_t); +instantiate_dot_reduce_kernel(float32, float); +instantiate_dot_reduce_kernel(float16, half); +instantiate_dot_reduce_kernel(bfloat16, bfloat16_t); \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 01f7185fb7..2c6812d68c 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1039,6 +1039,68 @@ void steel_matmul_axpby( // GEMV dispatch /////////////////////////////////////////////////////////////////////////////// +void dot_product( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int K, + std::vector& copies) { + constexpr int thread_group_size = 512; + constexpr int items_per_thread = 32; + auto& compute_encoder = metal::get_command_encoder(s); + std::string kname = "dot_product_" + type_to_name(a); + auto kernel = d.get_kernel(kname); + + int n = K; + int threads = (n + items_per_thread - 1) / items_per_thread; + int blocks = (threads + thread_group_size - 1) / thread_group_size; + + array partials({blocks}, float32, nullptr, {}); + partials.set_data(allocator::malloc(partials.nbytes())); + copies.push_back(partials); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(partials, 2); + compute_encoder.set_bytes(n, 3); + compute_encoder.dispatch_threads( + MTL::Size(size_t(blocks) * thread_group_size, 1, 1), + MTL::Size(thread_group_size, 1, 1)); + + array current = partials; + kname = "dot_reduce_" + type_to_name(out); + auto kernel_final = d.get_kernel(kname); + auto kernel_intermediate = d.get_kernel("dot_reduce_float32"); + + while (blocks > 1) { + n = blocks; + blocks = (n + thread_group_size - 1) / thread_group_size; + + auto kernel = (blocks == 1) ? kernel_final : kernel_intermediate; + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(current, 0); + compute_encoder.set_bytes(n, 2); + + if (blocks == 1) { + compute_encoder.set_output_array(out, 1); + } else { + array next({blocks}, float32, nullptr, {}); + next.set_data(allocator::malloc(next.nbytes())); + copies.push_back(next); + compute_encoder.set_output_array(next, 1); + current = next; + } + + compute_encoder.dispatch_threads( + MTL::Size(size_t(blocks) * thread_group_size, 1, 1), + MTL::Size(thread_group_size, 1, 1)); + } + + compute_encoder.add_temporaries(std::move(copies)); +} + template void gemv_axbpy( const Stream& s, @@ -1290,6 +1352,18 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Gemv specialization + if (M == 1 && N == 1 && K > 16384 && batch_size_out == 1 && !a_transposed && + !b_transposed && a.flags().row_contiguous && b.flags().row_contiguous) { + return dot_product( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int K = */ K, + /* std::vector& copies = */ copies); + } + // Route to gemv if needed if (std::min(M, N) == 1) { return gemv( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 69bf1560d2..410f0a267f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -58,28 +58,6 @@ Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } -array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) { - constexpr int kChunkSize = 4096; - - int main_size = (a.size() / kChunkSize) * kChunkSize; - int batches = main_size / kChunkSize; - - auto a_main = - reshape(slice(a, {0}, {main_size}, s), {batches, 1, kChunkSize}, s); - auto b_main = - reshape(slice(b, {0}, {main_size}, s), {batches, kChunkSize, 1}, s); - - // Route large 1D dot products through batched matmul so gemv parallelizes - array total = sum(reshape(matmul(a_main, b_main, s), {batches}, s), false, s); - if (main_size == a.size()) { - return total; - } - - auto a_tail = slice(a, {main_size}, {static_cast(a.size())}, s); - auto b_tail = slice(b, {main_size}, {static_cast(b.size())}, s); - return add(total, sum(multiply(a_tail, b_tail, s), false, s), s); -} - array indices_or_default( std::optional indices, const array& x, @@ -5488,13 +5466,6 @@ array tensordot( } } - auto stream = to_stream(s); - if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 && - axes_b.size() == 1 && stream.device == Device::gpu && - csize >= 32 * 4096) { - return inner_1d_gpu_chunked(a, b, stream); - } - std::vector cdims1(x.ndim(), false); std::vector cdims2(y.ndim(), false); for (const auto n : axes_a) { From 9c34307d77e0c64107151e5258123f3c8782b61a Mon Sep 17 00:00:00 2001 From: Vedant Date: Wed, 17 Jun 2026 19:31:50 +0530 Subject: [PATCH 5/7] Restructuring kernel --- mlx/backend/metal/kernels/dot.metal | 76 +++++++++++++++-------------- mlx/backend/metal/matmul.cpp | 36 ++++---------- 2 files changed, 48 insertions(+), 64 deletions(-) diff --git a/mlx/backend/metal/kernels/dot.metal b/mlx/backend/metal/kernels/dot.metal index 98c61f76b0..7c9259b6ef 100644 --- a/mlx/backend/metal/kernels/dot.metal +++ b/mlx/backend/metal/kernels/dot.metal @@ -11,49 +11,50 @@ template const device T* b [[buffer(1)]], device float* output [[buffer(2)]], const constant int& n [[buffer(3)]], - uint gid [[thread_position_in_grid]], uint tid [[thread_position_in_threadgroup]], uint lane [[thread_index_in_simdgroup]], uint simd_id [[simdgroup_index_in_threadgroup]], uint tg_id [[threadgroup_position_in_grid]]) { constexpr int ITEMS_PER_THREAD = 32; - - int start = gid * ITEMS_PER_THREAD; + constexpr int VEC = 16 / sizeof(T); + int start = (tg_id * 512 + simd_id * 32) * ITEMS_PER_THREAD + lane * VEC; float c0 = 0.0f; float c1 = 0.0f; float c2 = 0.0f; float c3 = 0.0f; - // Fast path: no per-element branches. - if (start + ITEMS_PER_THREAD <= n) { - MLX_MTL_PRAGMA_UNROLL - for (int i = 0; i < ITEMS_PER_THREAD; i += 4) { - c0 += float(a[start + i + 0]) * float(b[start + i + 0]); - c1 += float(a[start + i + 1]) * float(b[start + i + 1]); - c2 += float(a[start + i + 2]) * float(b[start + i + 2]); - c3 += float(a[start + i + 3]) * float(b[start + i + 3]); - } - } else { - // Tail path only for the last few threads. - MLX_MTL_PRAGMA_UNROLL - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - int idx = start + i; - if (idx < n) { - float v = float(a[idx]) * float(b[idx]); - switch (i & 3) { - case 0: - c0 += v; - break; - case 1: - c1 += v; - break; - case 2: - c2 += v; - break; - default: - c3 += v; - break; + MLX_MTL_PRAGMA_UNROLL + for (int i = 0; i < ITEMS_PER_THREAD; i += VEC) { + int idx = start + i * ITEMS_PER_THREAD; + if (idx + VEC <= n) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < VEC; j += 4) { + c0 += float(a[idx + j + 0]) * float(b[idx + j + 0]); + c1 += float(a[idx + j + 1]) * float(b[idx + j + 1]); + c2 += float(a[idx + j + 2]) * float(b[idx + j + 2]); + c3 += float(a[idx + j + 3]) * float(b[idx + j + 3]); + } + } else { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < VEC; ++j) { + int nidx = idx + j; + if (nidx < n) { + float v = float(a[nidx]) * float(b[nidx]); + switch (j & 3) { + case 0: + c0 += v; + break; + case 1: + c1 += v; + break; + case 2: + c2 += v; + break; + default: + c3 += v; + break; + } } } } @@ -84,12 +85,13 @@ template const device float* input [[buffer(0)]], device T* output [[buffer(1)]], const constant int& n [[buffer(2)]], - uint gid [[thread_position_in_grid]], uint tid [[thread_position_in_threadgroup]], uint lane [[thread_index_in_simdgroup]], - uint simd_id [[simdgroup_index_in_threadgroup]], - uint tg_id [[threadgroup_position_in_grid]]) { - float c = gid < uint(n) ? float(input[gid]) : 0.0f; + uint simd_id [[simdgroup_index_in_threadgroup]]) { + float c = 0.0f; + for (int i = int(tid); i < n; i += 512) { + c += input[i]; + } threadgroup float smem[16]; @@ -104,7 +106,7 @@ template c = smem[tid]; c = simd_sum(c); if (tid == 0) { - output[tg_id] = T(c); + output[0] = T(c); } } } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 2c6812d68c..136d9ecc23 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1071,32 +1071,14 @@ void dot_product( array current = partials; kname = "dot_reduce_" + type_to_name(out); - auto kernel_final = d.get_kernel(kname); - auto kernel_intermediate = d.get_kernel("dot_reduce_float32"); - - while (blocks > 1) { - n = blocks; - blocks = (n + thread_group_size - 1) / thread_group_size; - - auto kernel = (blocks == 1) ? kernel_final : kernel_intermediate; - compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(current, 0); - compute_encoder.set_bytes(n, 2); - - if (blocks == 1) { - compute_encoder.set_output_array(out, 1); - } else { - array next({blocks}, float32, nullptr, {}); - next.set_data(allocator::malloc(next.nbytes())); - copies.push_back(next); - compute_encoder.set_output_array(next, 1); - current = next; - } + kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(partials, 0); + compute_encoder.set_bytes(blocks, 2); + compute_encoder.set_output_array(out, 1); - compute_encoder.dispatch_threads( - MTL::Size(size_t(blocks) * thread_group_size, 1, 1), - MTL::Size(thread_group_size, 1, 1)); - } + compute_encoder.dispatch_threads( + MTL::Size(thread_group_size, 1, 1), MTL::Size(thread_group_size, 1, 1)); compute_encoder.add_temporaries(std::move(copies)); } @@ -1352,8 +1334,8 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Gemv specialization - if (M == 1 && N == 1 && K > 16384 && batch_size_out == 1 && !a_transposed && - !b_transposed && a.flags().row_contiguous && b.flags().row_contiguous) { + if (M == 1 && N == 1 && batch_size_out == 1 && a.flags().row_contiguous && + b.flags().row_contiguous) { return dot_product( /* const Stream& s = */ s, /* metal::Device& d = */ d, From b1f0a191fbfaa3e06b93634bceb760bb7e738072 Mon Sep 17 00:00:00 2001 From: Vedant Date: Sun, 21 Jun 2026 03:06:03 +0530 Subject: [PATCH 6/7] add dtype check --- mlx/backend/metal/matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 136d9ecc23..315bf5fe27 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1335,7 +1335,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Gemv specialization if (M == 1 && N == 1 && batch_size_out == 1 && a.flags().row_contiguous && - b.flags().row_contiguous) { + b.flags().row_contiguous && a.dtype() != complex64) { return dot_product( /* const Stream& s = */ s, /* metal::Device& d = */ d, From 3f64c8b011b32f3036ab1d3458631e98e6ef3886 Mon Sep 17 00:00:00 2001 From: Vedant Date: Sun, 21 Jun 2026 03:24:57 +0530 Subject: [PATCH 7/7] Fix Cmake --- mlx/backend/metal/kernels/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 446f5c14d7..42677e1358 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -49,7 +49,6 @@ endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) build_kernel(dot) -build_kernel(gemv steel/utils.h) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm)