Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ if(MLX_METAL_JIT)
make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
kernels/fp4.h)
make_jit_source(gemv)
make_jit_source(gemv_masked)

make_jit_source(steel/attn/kernels/steel_attention)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/jit/includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const char* conv();
const char* steel_conv();
const char* steel_conv_3d();
const char* steel_conv_general();
const char* gemv();
const char* gemv_masked();
const char* steel_attention();

Expand Down
64 changes: 64 additions & 0 deletions mlx/backend/metal/jit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,70 @@ MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

MTL::ComputePipelineState* get_gemv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool nc,
bool axpby) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::gemv()
<< get_template_definition(
lib_name,
transpose_mat ? "gemv_t" : "gemv",
get_type_string(out.dtype()),
bm,
bn,
sm,
sn,
tm,
tn,
nc ? 1 : 0,
axpby ? 1 : 0);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_gemv_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::gemv()
<< get_template_definition(
lib_name,
transpose_mat ? "gemv_t_gather" : "gemv_gather",
get_type_string(out.dtype()),
bm,
bn,
sm,
sn,
tm,
tn);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
26 changes: 26 additions & 0 deletions mlx/backend/metal/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,32 @@ MTL::ComputePipelineState* get_steel_conv_3d_kernel(
int wn,
bool small_filter);

MTL::ComputePipelineState* get_gemv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool nc,
bool axpby);

MTL::ComputePipelineState* get_gemv_gather_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn);

MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ endfunction(build_kernel)

build_kernel(arg_reduce)
build_kernel(conv steel/conv/params.h)
build_kernel(gemv steel/utils.h)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
Expand Down Expand Up @@ -152,6 +151,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
build_kernel(gemv steel/utils.h)
build_kernel(gemv_masked steel/utils.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})

Expand Down
Loading
Loading