Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ endfunction(build_kernel)

build_kernel(arg_reduce)
build_kernel(conv steel/conv/params.h)
build_kernel(dot)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
Expand Down
125 changes: 125 additions & 0 deletions mlx/backend/metal/kernels/dot.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright © 2026 Apple Inc.

#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

// First pass of a two-pass dot product: every threadgroup writes one float
// partial sum of a[i] * b[i] over its slice of the inputs to output[tg_id].
//
// Work layout (fixed by the constants below):
//   * each thread accumulates up to ITEMS_PER_THREAD (32) products, loaded in
//     chunks of VEC elements (VEC = 16 bytes / sizeof(T));
//   * a simdgroup covers 32 * 32 = 1024 consecutive elements, with the lanes
//     reading adjacent VEC-sized chunks on every loop iteration;
//   * a threadgroup covers 512 * 32 = 16384 consecutive elements.
// NOTE(review): the 512-thread threadgroup and 32-wide simdgroup are assumed
// here (and by smem[16]); they must match the host-side dispatch.
//
// Parameters:
//   a, b    input vectors with at least n elements each
//   output  one float partial sum per threadgroup
//   n       number of elements to reduce
template <typename T>
[[kernel]] void dot_product(
    const device T* a [[buffer(0)]],
    const device T* b [[buffer(1)]],
    device float* output [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint tid [[thread_position_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]],
    uint tg_id [[threadgroup_position_in_grid]]) {
  constexpr int ITEMS_PER_THREAD = 32;
  constexpr int VEC = 16 / sizeof(T);
  // The fast path below consumes elements four at a time.
  static_assert(
      VEC >= 4 && VEC % 4 == 0,
      "dot_product requires element types of at most 4 bytes");

  // Element indices are kept unsigned: the last threadgroup may compute
  // indices up to ~16K past n, which would overflow a signed int (and wrap to
  // a negative, in-bounds-looking index) when n is close to INT_MAX.
  const uint count = n > 0 ? uint(n) : 0u;
  const uint start =
      (tg_id * 512 + simd_id * 32) * ITEMS_PER_THREAD + lane * VEC;

  // Four independent accumulators; the summation order matches the original
  // implementation so results are bit-identical.
  float c0 = 0.0f;
  float c1 = 0.0f;
  float c2 = 0.0f;
  float c3 = 0.0f;

  MLX_MTL_PRAGMA_UNROLL
  for (int i = 0; i < ITEMS_PER_THREAD; i += VEC) {
    // Each iteration advances by (VEC * 32) elements, i.e. one chunk per lane.
    const uint idx = start + uint(i * ITEMS_PER_THREAD);
    if (idx + VEC <= count) {
      // Fast path: the whole chunk is in bounds.
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < VEC; j += 4) {
        c0 += float(a[idx + j + 0]) * float(b[idx + j + 0]);
        c1 += float(a[idx + j + 1]) * float(b[idx + j + 1]);
        c2 += float(a[idx + j + 2]) * float(b[idx + j + 2]);
        c3 += float(a[idx + j + 3]) * float(b[idx + j + 3]);
      }
    } else {
      // Tail path: bounds-check every element, keeping the same
      // element-to-accumulator assignment as the fast path.
      MLX_MTL_PRAGMA_UNROLL
      for (int j = 0; j < VEC; ++j) {
        const uint nidx = idx + j;
        if (nidx < count) {
          float v = float(a[nidx]) * float(b[nidx]);
          switch (j & 3) {
            case 0:
              c0 += v;
              break;
            case 1:
              c1 += v;
              break;
            case 2:
              c2 += v;
              break;
            default:
              c3 += v;
              break;
          }
        }
      }
    }
  }

  // One slot per simdgroup of the threadgroup.
  threadgroup float smem[16];

  float c = c0 + c1 + c2 + c3;
  c = simd_sum(c);

  if (lane == 0) {
    smem[simd_id] = c;
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first 16 threads combine the per-simdgroup sums; thread 0 publishes
  // the threadgroup's partial result.
  if (tid < 16) {
    c = smem[tid];
    c = simd_sum(c);
    if (tid == 0) {
      output[tg_id] = c;
    }
  }
}

// Second pass of the dot product: a single threadgroup sums the `n` float
// partials in `input` and stores the total, converted to T, in output[0].
//
// Each thread first adds up the elements tid, tid + 512, tid + 1024, ...;
// the per-thread sums are then combined per simdgroup and finally across
// simdgroups through threadgroup memory.
// NOTE(review): the stride of 512 and the 16 scratch slots assume a
// 512-thread threadgroup made of 32-wide simdgroups — confirm against the
// host-side dispatch.
template <typename T>
[[kernel]] void dot_reduce(
    const device float* input [[buffer(0)]],
    device T* output [[buffer(1)]],
    const constant int& n [[buffer(2)]],
    uint tid [[thread_position_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]]) {
  // Scratch space holding one sum per simdgroup.
  threadgroup float simd_totals[16];

  // Strided accumulation over the partials owned by this thread.
  float acc = 0.0f;
  int pos = int(tid);
  while (pos < n) {
    acc += input[pos];
    pos += 512;
  }

  // Combine within the simdgroup; lane 0 records the simdgroup's sum.
  const float simd_total = simd_sum(acc);
  if (lane == 0) {
    simd_totals[simd_id] = simd_total;
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first 16 threads fold the per-simdgroup sums; thread 0 writes the
  // final value.
  if (tid < 16) {
    const float total = simd_sum(simd_totals[tid]);
    if (tid == 0) {
      output[0] = T(total);
    }
  }
}

// Explicit kernel instantiations. The host side looks kernels up by the names
// "dot_product_<suffix>" and "dot_reduce_<suffix>", so the suffix passed as
// the first macro argument must stay in sync with the host's dtype naming.
// `instantiate_kernel` itself is defined elsewhere (not in this file).

// dot_product<in_t>, exported as "dot_product_<tname>".
#define instantiate_dot_product_kernel(tname, in_t) \
  instantiate_kernel("dot_product_" #tname, dot_product, in_t)

// dot_reduce<out_t>, exported as "dot_reduce_<tname>".
#define instantiate_dot_reduce_kernel(tname, out_t) \
  instantiate_kernel("dot_reduce_" #tname, dot_reduce, out_t)

// float32
instantiate_dot_product_kernel(float32, float);
instantiate_dot_reduce_kernel(float32, float);

// float16
instantiate_dot_product_kernel(float16, half);
instantiate_dot_reduce_kernel(float16, half);

// bfloat16
instantiate_dot_product_kernel(bfloat16, bfloat16_t);
instantiate_dot_reduce_kernel(bfloat16, bfloat16_t);
56 changes: 56 additions & 0 deletions mlx/backend/metal/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,50 @@ void steel_matmul_axpby(
// GEMV dispatch
///////////////////////////////////////////////////////////////////////////////

// Encodes a two-pass GPU dot product of `a` and `b` (K elements each) that
// writes a single value into `out`.
//
// Pass 1 runs "dot_product_<dtype of a>": every 512-thread threadgroup reduces
// 512 * 32 consecutive elements into one float32 entry of a temporary
// `partials` array.
// Pass 2 runs "dot_reduce_<dtype of out>" with a single threadgroup that sums
// the partials and stores the result in `out`.
//
// Parameters:
//   s, d    stream and device used to fetch the encoder and the kernels
//   a, b    input arrays, bound as kernel buffers 0 and 1
//   out     output array, bound as the reduce kernel's output buffer
//   K       number of elements to reduce
//   copies  temporaries collected by the caller; `partials` is appended and
//           the whole vector is then handed to the encoder via std::move, so
//           callers should not rely on its contents afterwards
//
// NOTE(review): K == 0 yields zero threadgroups for the first pass;
// presumably the caller handles empty reductions before reaching this
// function — confirm.
void dot_product(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int K,
    std::vector<array>& copies) {
  // These must match the constants hard-coded in the Metal kernels.
  constexpr int thread_group_size = 512;
  constexpr int items_per_thread = 32;
  auto& compute_encoder = metal::get_command_encoder(s);

  int n = K;
  // Round up in 64-bit arithmetic: `n + items_per_thread - 1` overflows a
  // 32-bit int when K is within 31 of INT_MAX.
  const long long threads =
      (static_cast<long long>(n) + items_per_thread - 1) / items_per_thread;
  int blocks = static_cast<int>(
      (threads + thread_group_size - 1) / thread_group_size);

  // One float32 partial sum per threadgroup of the first pass.
  array partials({blocks}, float32, nullptr, {});
  partials.set_data(allocator::malloc(partials.nbytes()));
  copies.push_back(partials);

  // Pass 1: per-threadgroup partial dot products.
  std::string kname = "dot_product_" + type_to_name(a);
  auto kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(partials, 2);
  compute_encoder.set_bytes(n, 3);
  compute_encoder.dispatch_threads(
      MTL::Size(size_t(blocks) * thread_group_size, 1, 1),
      MTL::Size(thread_group_size, 1, 1));

  // Pass 2: a single threadgroup folds the partials into `out`.
  kname = "dot_reduce_" + type_to_name(out);
  kernel = d.get_kernel(kname);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(partials, 0);
  compute_encoder.set_bytes(blocks, 2);
  compute_encoder.set_output_array(out, 1);

  compute_encoder.dispatch_threads(
      MTL::Size(thread_group_size, 1, 1), MTL::Size(thread_group_size, 1, 1));

  // Keep the temporaries (including `partials`) alive until the GPU work is
  // done.
  compute_encoder.add_temporaries(std::move(copies));
}

template <bool CHECK_AB = true>
void gemv_axbpy(
const Stream& s,
Expand Down Expand Up @@ -1290,6 +1334,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization

if (M == 1 && N == 1 && batch_size_out == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous && a.dtype() != complex64) {
return dot_product(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& out = */ out,
/* int K = */ K,
/* std::vector<array>& copies = */ copies);
}

// Route to gemv if needed
if (std::min(M, N) == 1) {
return gemv(
Expand Down