From 0523b266099af7260648eeae8dd42b02c6747afa Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 16 Jun 2026 13:31:46 -0700 Subject: [PATCH] Make gemv JIT compilable --- mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/jit/includes.h | 1 + mlx/backend/metal/jit_kernels.cpp | 64 ++ mlx/backend/metal/kernels.h | 26 + mlx/backend/metal/kernels/CMakeLists.txt | 2 +- mlx/backend/metal/kernels/gemv.h | 766 +++++++++++++++++++++++ mlx/backend/metal/kernels/gemv.metal | 758 +--------------------- mlx/backend/metal/matmul.cpp | 17 +- mlx/backend/metal/nojit_kernels.cpp | 30 + 9 files changed, 905 insertions(+), 760 deletions(-) create mode 100644 mlx/backend/metal/kernels/gemv.h diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index bbb08137f6..5eb8f543c7 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -79,6 +79,7 @@ if(MLX_METAL_JIT) make_jit_source(quantized kernels/quantized_utils.h) make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h kernels/fp4.h) + make_jit_source(gemv) make_jit_source(gemv_masked) make_jit_source(steel/attn/kernels/steel_attention) diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index e22efa96d0..ac9fb81e26 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -43,6 +43,7 @@ const char* conv(); const char* steel_conv(); const char* steel_conv_3d(); const char* steel_conv_general(); +const char* gemv(); const char* gemv_masked(); const char* steel_attention(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 9c47b53b40..efdf19dc54 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -698,6 +698,70 @@ MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_gemv_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn, + bool nc, + bool axpby) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::ostringstream kernel_source; + kernel_source << metal::gemv() + << get_template_definition( + lib_name, + transpose_mat ? "gemv_t" : "gemv", + get_type_string(out.dtype()), + bm, + bn, + sm, + sn, + tm, + tn, + nc ? 1 : 0, + axpby ? 1 : 0); + return kernel_source.str(); + }); + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_gemv_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::ostringstream kernel_source; + kernel_source << metal::gemv() + << get_template_definition( + lib_name, + transpose_mat ? "gemv_t_gather" : "gemv_gather", + get_type_string(out.dtype()), + bm, + bn, + sm, + sn, + tm, + tn); + return kernel_source.str(); + }); + return d.get_kernel(kernel_name, lib); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index dc0dab970d..fb291056f2 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -212,6 +212,32 @@ MTL::ComputePipelineState* get_steel_conv_3d_kernel( int wn, bool small_filter); +MTL::ComputePipelineState* get_gemv_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn, + bool nc, + bool axpby); + +MTL::ComputePipelineState* get_gemv_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out, + bool transpose_mat, + int bm, + int bn, + int sm, + int sn, + int tm, + int tn); + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4b13e3ec57..3312fbb1e0 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -48,7 +48,6 @@ endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) -build_kernel(gemv steel/utils.h) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) @@ -152,6 +151,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS}) + build_kernel(gemv steel/utils.h) build_kernel(gemv_masked steel/utils.h) build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS}) diff --git a/mlx/backend/metal/kernels/gemv.h b/mlx/backend/metal/kernels/gemv.h new file mode 100644 index 0000000000..84579516ec --- /dev/null +++ b/mlx/backend/metal/kernels/gemv.h @@ -0,0 +1,766 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_MTL_CONST static constant constexpr const + +template +struct DefaultAccT { + using type = float; +}; +template <> +struct DefaultAccT { + using type = complex64_t; +}; + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ + typename AccT = typename DefaultAccT::type> +struct GEMVKernel { + using acc_type = AccT; + + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 4 || SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 4, 8, 16, or 32"); + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + template + static METAL_FUNC void + load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = static_cast(src[src_offset + tn]); + } + } + + template + static METAL_FUNC void load_safe( + const device T* src, + thread U dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = static_cast(src[src_offset + tn]); + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size + ? static_cast(src[src_offset + tn]) + : U(0); + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread AccT result[TM] = {0}; + thread T inter[TN]; + thread AccT v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Advance matrix + mat += size_t(out_row) * matrix_ld; + + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + load_unsafe(in_vec, v_coeff, bn); + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + + bn += blockN; + } + + if (leftover > 0) { + load_safe(in_vec, v_coeff, bn, in_size); + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + if (kDoAxpby) { + out_vec[out_row + tm] = + static_cast(alpha) * static_cast(result[tm]) + + static_cast(beta) * bias[(out_row + tm) * bias_stride]; + } else { + out_vec[out_row + tm] = static_cast(result[tm]); + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ + typename AccT = typename DefaultAccT::type> +struct GEMVTKernel { + using acc_type = AccT; + + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + AccT result[TN] = {0}; + T inter[TN]; + AccT v_coeff[TM]; + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + auto vc = static_cast(v_coeff[tm]); + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[size_t(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += vc * inter[tn]; + } + } + + bm += blockM; + } + + if (leftover > 0) { + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[size_t(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + if (kDoAxpby) { + out_vec[out_col + j] = + static_cast(alpha) * static_cast(result[j]) + + static_cast(beta) * bias[(out_col + j) * bias_stride]; + } else { + out_vec[out_col + j] = static_cast(result[j]); + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication kernel +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch, /* Batch ndim > 1 */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* bias_batch_stride [[buffer(13)]], + const constant int& bias_stride [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVKernel; + threadgroup typename gemv_kernel::acc_type tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (kDoAxpby) { + bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (kDoAxpby) { + bias += tid.z * bias_batch_stride[0]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* index_batch_strides [[buffer(11)]], + const constant int& vector_batch_ndim [[buffer(12)]], + const constant int* vector_batch_shape [[buffer(13)]], + const constant int64_t* vector_batch_stride [[buffer(14)]], + const constant int& matrix_batch_ndim [[buffer(15)]], + const constant int* matrix_batch_shape [[buffer(16)]], + const constant int64_t* matrix_batch_stride [[buffer(17)]], + const constant uint32_t* vec_indices [[buffer(18)]], + const constant uint32_t* mat_indices [[buffer(19)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVKernel; + threadgroup typename gemv_kernel::acc_type tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + uint32_t indx_vec; + uint32_t indx_mat; + + // Update batch offsets + if (batch_ndim > 1) { + const constant auto* veci_bstrides = index_batch_strides; + const constant auto* mati_bstrides = index_batch_strides + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); + + indx_vec = vec_indices[batch_offsets.x]; + indx_mat = mat_indices[batch_offsets.y]; + + } else { + indx_vec = vec_indices[index_batch_strides[0] * tid.z]; + indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; + } + + if (vector_batch_ndim > 1) { + in_vec += elem_to_loc( + indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); + } else { + in_vec += indx_vec * vector_batch_stride[0]; + } + + if (matrix_batch_ndim > 1) { + mat += elem_to_loc( + indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); + } else { + mat += indx_mat * matrix_batch_stride[0]; + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + batch_ndim, // Not used + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication kernel +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch, /* Batch ndim > 1 */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* bias_batch_stride [[buffer(13)]], + const constant int& bias_stride [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVTKernel; + threadgroup typename gemv_kernel::acc_type tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (kDoAxpby) { + bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (kDoAxpby) { + bias += tid.z * bias_batch_stride[0]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* index_batch_strides [[buffer(11)]], + const constant int& vector_batch_ndim [[buffer(12)]], + const constant int* vector_batch_shape [[buffer(13)]], + const constant int64_t* vector_batch_stride [[buffer(14)]], + const constant int& matrix_batch_ndim [[buffer(15)]], + const constant int* matrix_batch_shape [[buffer(16)]], + const constant int64_t* matrix_batch_stride [[buffer(17)]], + const constant uint32_t* vec_indices [[buffer(18)]], + const constant uint32_t* mat_indices [[buffer(19)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVTKernel; + threadgroup typename gemv_kernel::acc_type tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + uint32_t indx_vec; + uint32_t indx_mat; + + // Update batch offsets + if (batch_ndim > 1) { + const constant auto* veci_bstrides = index_batch_strides; + const constant auto* mati_bstrides = index_batch_strides + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); + + indx_vec = vec_indices[batch_offsets.x]; + indx_mat = mat_indices[batch_offsets.y]; + + } else { + indx_vec = vec_indices[index_batch_strides[0] * tid.z]; + indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; + } + + if (vector_batch_ndim > 1) { + in_vec += elem_to_loc( + indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); + } else { + in_vec += indx_vec * vector_batch_stride[0]; + } + + if (matrix_batch_ndim > 1) { + mat += elem_to_loc( + indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); + } else { + mat += indx_mat * matrix_batch_stride[0]; + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + batch_ndim, // Not used, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index c1eac6f8fb..b306671559 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -3,510 +3,10 @@ #include #include -#include "mlx/backend/metal/kernels/utils.h" - -#include "mlx/backend/metal/kernels/steel/utils.h" +#include "mlx/backend/metal/kernels/gemv.h" using namespace metal; -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -#define MLX_MTL_CONST static constant constexpr const - -template -struct DefaultAccT { - using type = float; -}; -template <> -struct DefaultAccT { - using type = complex64_t; -}; - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ - typename AccT = typename DefaultAccT::type> -struct GEMVKernel { - using acc_type = AccT; - - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - static_assert( - SN == 4 || SN == 8 || SN == 16 || SN == 32, - "gemv block must have a width of 4, 8, 16, or 32"); - - // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM rows - // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for - // the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated blockM outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; - - template - static METAL_FUNC void - load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = static_cast(src[src_offset + tn]); - } - } - - template - static METAL_FUNC void load_safe( - const device T* src, - thread U dst[TN], - const int src_offset = 0, - const int src_size = TN) { - if (src_offset + TN <= src_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = static_cast(src[src_offset + tn]); - } - } else { // Edgecase - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = src_offset + tn < src_size - ? static_cast(src[src_offset + tn]) - : U(0); - } - } - } - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& matrix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& bias_stride [[buffer(14)]], - threadgroup AccT* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - thread AccT result[TM] = {0}; - thread T inter[TN]; - thread AccT v_coeff[TN]; - - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); - const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; - - int bm = (simdM + thrM) * TM; - int bn = (simdN + thrN) * TN; - - // Block position - int out_row = tid.x * blockM + bm; - - // Exit simdgroup if rows out of bound - if (out_row >= out_vec_size) - return; - - // Adjust tail simdgroup to ensure in bound reads - out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; - - // Advance matrix - mat += size_t(out_row) * matrix_ld; - - constexpr const uniform loop_stride = make_uniform(blockN); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Loop over in_vec in blocks of blockN - for (int i = 0; i < n_iter; ++i) { - load_unsafe(in_vec, v_coeff, bn); - - // Per thread work loop - int mat_offset = 0; - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_unsafe(mat, inter, mat_offset + bn); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - - mat_offset += matrix_ld; - } - - bn += blockN; - } - - if (leftover > 0) { - load_safe(in_vec, v_coeff, bn, in_size); - - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { - result[tm] += simd_shuffle_down(result[tm], sn); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; - if (thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - tgp_results[tm] = result[tm]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgn = 1; sgn < BN; sgn++) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] += tgp_results[sgn * (blockM + TM) + tm]; - } - } - } - } - } - - // Write outputs - if (simdN == 0 && thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - if (kDoAxpby) { - out_vec[out_row + tm] = - static_cast(alpha) * static_cast(result[tm]) + - static_cast(beta) * bias[(out_row + tm) * bias_stride]; - } else { - out_vec[out_row + tm] = static_cast(result[tm]); - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ - typename AccT = typename DefaultAccT::type> -struct GEMVTKernel { - using acc_type = AccT; - - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then accumulates its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated BN * TN outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& bias_stride [[buffer(14)]], - threadgroup AccT* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - AccT result[TN] = {0}; - T inter[TN]; - AccT v_coeff[TM]; - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = SM * sgM; - const int simdN = SN * sgN; - - int cm = (simdM + thrM); - int cn = (simdN + thrN); - - int bm = cm * TM; - int bn = cn * TN; - - int out_col = tid.x * blockN + bn; - - constexpr const uniform loop_stride = make_uniform(blockM); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Edgecase handling - if (out_col < out_vec_size) { - out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; - - // Per thread accumulation main loop - for (int i = 0; i < n_iter; ++i) { - // Adding a threadgroup_barrier improves performance slightly - // This is possibly it may help exploit cache better - threadgroup_barrier(mem_flags::mem_none); - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - } - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - auto vc = static_cast(v_coeff[tm]); - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[size_t(bm + tm) * marix_ld + out_col + tn]; - } - for (int tn = 0; tn < TN; tn++) { - result[tn] += vc * inter[tn]; - } - } - - bm += blockM; - } - - if (leftover > 0) { - for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[size_t(bm + tm) * marix_ld + out_col + tn]; - } - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { - result[tn] += simd_shuffle_down(result[tn], SN * sm); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; - if (thrM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - tgp_results[tn] = result[tn]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgm = 1; sgm < BM; sgm++) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += tgp_results[sgm * (blockN + TN) + tn]; - } - } - } - } - } - - // Threadgroup accumulation and writing out results - if (cm == 0 && out_col < out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - if (kDoAxpby) { - out_vec[out_col + j] = - static_cast(alpha) * static_cast(result[j]) + - static_cast(beta) * bias[(out_col + j) * bias_stride]; - } else { - out_vec[out_col + j] = static_cast(result[j]); - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch, /* Batch ndim > 1 */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const constant int64_t* bias_batch_stride [[buffer(13)]], - const constant int& bias_stride [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (kDoAxpby) { - bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (kDoAxpby) { - bias += tid.z * bias_batch_stride[0]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - bias_stride, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - #define instantiate_gemv_helper( \ name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ instantiate_kernel( \ @@ -545,96 +45,6 @@ instantiate_gemv_blocks(float16, half); instantiate_gemv_blocks(bfloat16, bfloat16_t); instantiate_gemv_blocks(complex64, complex64_t); -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* index_batch_strides [[buffer(11)]], - const constant int& vector_batch_ndim [[buffer(12)]], - const constant int* vector_batch_shape [[buffer(13)]], - const constant int64_t* vector_batch_stride [[buffer(14)]], - const constant int& matrix_batch_ndim [[buffer(15)]], - const constant int* matrix_batch_shape [[buffer(16)]], - const constant int64_t* matrix_batch_stride [[buffer(17)]], - const constant uint32_t* vec_indices [[buffer(18)]], - const constant uint32_t* mat_indices [[buffer(19)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - uint32_t indx_vec; - uint32_t indx_mat; - - // Update batch offsets - if (batch_ndim > 1) { - const constant auto* veci_bstrides = index_batch_strides; - const constant auto* mati_bstrides = index_batch_strides + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); - - indx_vec = vec_indices[batch_offsets.x]; - indx_mat = mat_indices[batch_offsets.y]; - - } else { - indx_vec = vec_indices[index_batch_strides[0] * tid.z]; - indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; - } - - if (vector_batch_ndim > 1) { - in_vec += elem_to_loc( - indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); - } else { - in_vec += indx_vec * vector_batch_stride[0]; - } - - if (matrix_batch_ndim > 1) { - mat += elem_to_loc( - indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); - } else { - mat += indx_mat * matrix_batch_stride[0]; - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - batch_ndim, // Not used - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - // clang-format off #define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ instantiate_kernel( \ @@ -652,82 +62,6 @@ instantiate_gemv_bs_blocks(float16, half); instantiate_gemv_bs_blocks(bfloat16, bfloat16_t); instantiate_gemv_bs_blocks(complex64, complex64_t); -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch, /* Batch ndim > 1 */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const constant int64_t* bias_batch_stride [[buffer(13)]], - const constant int& bias_stride [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (kDoAxpby) { - bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (kDoAxpby) { - bias += tid.z * bias_batch_stride[0]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - bias_stride, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - // clang-format off #define instantiate_gemv_t_helper( \ name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ @@ -756,96 +90,6 @@ instantiate_gemv_t_blocks(float16, half); instantiate_gemv_t_blocks(bfloat16, bfloat16_t); instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* index_batch_strides [[buffer(11)]], - const constant int& vector_batch_ndim [[buffer(12)]], - const constant int* vector_batch_shape [[buffer(13)]], - const constant int64_t* vector_batch_stride [[buffer(14)]], - const constant int& matrix_batch_ndim [[buffer(15)]], - const constant int* matrix_batch_shape [[buffer(16)]], - const constant int64_t* matrix_batch_stride [[buffer(17)]], - const constant uint32_t* vec_indices [[buffer(18)]], - const constant uint32_t* mat_indices [[buffer(19)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - uint32_t indx_vec; - uint32_t indx_mat; - - // Update batch offsets - if (batch_ndim > 1) { - const constant auto* veci_bstrides = index_batch_strides; - const constant auto* mati_bstrides = index_batch_strides + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); - - indx_vec = vec_indices[batch_offsets.x]; - indx_mat = mat_indices[batch_offsets.y]; - - } else { - indx_vec = vec_indices[index_batch_strides[0] * tid.z]; - indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; - } - - if (vector_batch_ndim > 1) { - in_vec += elem_to_loc( - indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); - } else { - in_vec += indx_vec * vector_batch_stride[0]; - } - - if (matrix_batch_ndim > 1) { - mat += elem_to_loc( - indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); - } else { - mat += indx_mat * matrix_batch_stride[0]; - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - batch_ndim, // Not used, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - // clang-format off #define instantiate_gemv_t_bs_helper( \ nm, itype, bm, bn, sm, sn, tm, tn) \ diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 5948bf58b3..01f7185fb7 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1142,7 +1142,19 @@ void gemv_axbpy( // Encode and dispatch kernel auto& compute_encoder = metal::get_command_encoder(s); - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_gemv_kernel( + d, + kname.str(), + out, + transpose_mat, + bm, + bn, + sm, + sn, + tm, + tn, + !contiguous_kernel, + do_axpby); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; @@ -2214,7 +2226,8 @@ void gather_mv( << tm << "_tn" << tn; // Encode and dispatch kernel - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_gemv_gather_kernel( + d, kname.str(), out, transpose_mat, bm, bn, sm, sn, tm, tn); compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 2ed74f470a..64cfa39ed5 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -226,6 +226,36 @@ MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( return d.get_kernel(kernel_name, hash_name, func_consts); } +MTL::ComputePipelineState* get_gemv_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + bool, + int, + int, + int, + int, + int, + int, + bool, + bool) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_gemv_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + bool, + int, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name,