diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index edc169eec4..42677e1358 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -48,6 +48,7 @@ endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) +build_kernel(dot) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) diff --git a/mlx/backend/metal/kernels/dot.metal b/mlx/backend/metal/kernels/dot.metal new file mode 100644 index 0000000000..7c9259b6ef --- /dev/null +++ b/mlx/backend/metal/kernels/dot.metal @@ -0,0 +1,125 @@ +// Copyright © 2026 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void dot_product( + const device T* a [[buffer(0)]], + const device T* b [[buffer(1)]], + device float* output [[buffer(2)]], + const constant int& n [[buffer(3)]], + uint tid [[thread_position_in_threadgroup]], + uint lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint tg_id [[threadgroup_position_in_grid]]) { + constexpr int ITEMS_PER_THREAD = 32; + constexpr int VEC = 16 / sizeof(T); + int start = (tg_id * 512 + simd_id * 32) * ITEMS_PER_THREAD + lane * VEC; + + float c0 = 0.0f; + float c1 = 0.0f; + float c2 = 0.0f; + float c3 = 0.0f; + + MLX_MTL_PRAGMA_UNROLL + for (int i = 0; i < ITEMS_PER_THREAD; i += VEC) { + int idx = start + i * ITEMS_PER_THREAD; + if (idx + VEC <= n) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < VEC; j += 4) { + c0 += float(a[idx + j + 0]) * float(b[idx + j + 0]); + c1 += float(a[idx + j + 1]) * float(b[idx + j + 1]); + c2 += float(a[idx + j + 2]) * float(b[idx + j + 2]); + c3 += float(a[idx + j + 3]) * float(b[idx + j + 3]); + } + } else { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < VEC; ++j) { + int nidx = idx + j; + if (nidx < n) { + float v = float(a[nidx]) * float(b[nidx]); + switch (j & 3) { + case 0: + c0 += v; + break; + case 1: + c1 += v; + break; + case 2: + c2 += v; + break; + default: + c3 += v; + break; + } + } + } + } + } + + threadgroup float smem[16]; + + float c = c0 + c1 + c2 + c3; + c = simd_sum(c); + + if (lane == 0) { + smem[simd_id] = c; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < 16) { + c = smem[tid]; + c = simd_sum(c); + if (tid == 0) { + output[tg_id] = c; + } + } +} + +template +[[kernel]] void dot_reduce( + const device float* input [[buffer(0)]], + device T* output [[buffer(1)]], + const constant int& n [[buffer(2)]], + uint tid [[thread_position_in_threadgroup]], + uint lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]]) { + float c = 0.0f; + for (int i = int(tid); i < n; i += 512) { + c += input[i]; + } + + threadgroup float smem[16]; + + c = simd_sum(c); + if (lane == 0) { + smem[simd_id] = c; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < 16) { + c = smem[tid]; + c = simd_sum(c); + if (tid == 0) { + output[0] = T(c); + } + } +} + +#define instantiate_dot_product_kernel(name, itype) \ + instantiate_kernel("dot_product_" #name, dot_product, itype) + +#define instantiate_dot_reduce_kernel(name, otype) \ + instantiate_kernel("dot_reduce_" #name, dot_reduce, otype) + +instantiate_dot_product_kernel(float32, float); +instantiate_dot_product_kernel(float16, half); +instantiate_dot_product_kernel(bfloat16, bfloat16_t); +instantiate_dot_reduce_kernel(float32, float); +instantiate_dot_reduce_kernel(float16, half); +instantiate_dot_reduce_kernel(bfloat16, bfloat16_t); \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 01f7185fb7..315bf5fe27 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1039,6 +1039,50 @@ void steel_matmul_axpby( // GEMV dispatch /////////////////////////////////////////////////////////////////////////////// +void dot_product( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int K, + std::vector& copies) { + constexpr int thread_group_size = 512; + constexpr int items_per_thread = 32; + auto& compute_encoder = metal::get_command_encoder(s); + std::string kname = "dot_product_" + type_to_name(a); + auto kernel = d.get_kernel(kname); + + int n = K; + int threads = (n + items_per_thread - 1) / items_per_thread; + int blocks = (threads + thread_group_size - 1) / thread_group_size; + + array partials({blocks}, float32, nullptr, {}); + partials.set_data(allocator::malloc(partials.nbytes())); + copies.push_back(partials); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(partials, 2); + compute_encoder.set_bytes(n, 3); + compute_encoder.dispatch_threads( + MTL::Size(size_t(blocks) * thread_group_size, 1, 1), + MTL::Size(thread_group_size, 1, 1)); + + array current = partials; + kname = "dot_reduce_" + type_to_name(out); + kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(partials, 0); + compute_encoder.set_bytes(blocks, 2); + compute_encoder.set_output_array(out, 1); + + compute_encoder.dispatch_threads( + MTL::Size(thread_group_size, 1, 1), MTL::Size(thread_group_size, 1, 1)); + + compute_encoder.add_temporaries(std::move(copies)); +} + template void gemv_axbpy( const Stream& s, @@ -1290,6 +1334,18 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Gemv specialization + if (M == 1 && N == 1 && batch_size_out == 1 && a.flags().row_contiguous && + b.flags().row_contiguous && a.dtype() != complex64) { + return dot_product( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int K = */ K, + /* std::vector& copies = */ copies); + } + // Route to gemv if needed if (std::min(M, N) == 1) { return gemv(