Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/src/dev/custom_metal_kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ JIT compiled. To reduce the overhead from that, build the kernel once with
Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.

Math Mode
---------

By default :func:`fast.metal_kernel` compiles kernels with
``compile_options={"math_mode": "safe"}`` so special values follow IEEE
behavior, for example ``exp(-inf) == 0``. This is important for kernels such as
masked softmax where causal or sliding-window masks depend on exponentiating
``-inf``.

If your kernel does not rely on these edge cases, you can opt in to less strict
math with ``compile_options={"math_mode": "relaxed"}`` or
``compile_options={"math_mode": "fast"}``:

.. code-block:: python

kernel = mx.fast.metal_kernel(
name="my_kernel",
input_names=["x"],
output_names=["y"],
source=source,
compile_options={"math_mode": "relaxed"},
)

The full function signature will be generated using:

* The shapes/dtypes of ``inputs``
Expand Down
6 changes: 4 additions & 2 deletions mlx/backend/common/metal_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ CustomKernelFunction metal_kernel(
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
bool atomic_outputs /* = false */,
CompileOptions compile_options /* = {} */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
Expand Down Expand Up @@ -360,7 +361,8 @@ CustomKernelFunction metal_kernel(
init_value,
std::vector<ScalarArg>{},
false,
0),
0,
static_cast<int>(compile_options.math_mode)),
std::move(inputs));
};
}
Expand Down
33 changes: 28 additions & 5 deletions mlx/backend/metal/custom_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,37 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

struct CustomKernelCache {
std::unordered_map<std::string, std::string> libraries;
std::unordered_map<std::string, std::pair<std::string, std::optional<int>>>
libraries;
};

static CustomKernelCache& cache() {
static CustomKernelCache cache_;
return cache_;
};

std::optional<MTL::MathMode> to_mtl_math_mode(std::optional<int> math_mode) {
if (!math_mode) {
return std::nullopt;
}
switch (*math_mode) {
case static_cast<int>(MetalKernelMathMode::Safe):
return MTL::MathModeSafe;
case static_cast<int>(MetalKernelMathMode::Relaxed):
return MTL::MathModeRelaxed;
case static_cast<int>(MetalKernelMathMode::Fast):
return MTL::MathModeFast;
default:
throw std::invalid_argument("[metal_kernel] Invalid Metal math mode.");
}
}

void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
Expand Down Expand Up @@ -58,17 +76,22 @@ void CustomKernel::eval_gpu(
auto& kernel_cache = cache();
if (auto it = kernel_cache.libraries.find(name_);
it != kernel_cache.libraries.end()) {
if (it->second != source_) {
if (it->second.first != source_ ||
it->second.second != metal_math_mode_) {
auto& d = metal::device(s.device);
d.clear_library(name_);
it->second = source_;
it->second = {source_, metal_math_mode_};
}
} else {
kernel_cache.libraries.emplace(name_, source_);
kernel_cache.libraries.emplace(
name_, std::make_pair(source_, metal_math_mode_));
}
}

auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
auto lib = d.get_library(
name_,
[this] { return metal::utils() + source_; },
metal::CompileOptions{to_mtl_math_mode(metal_math_mode_)});
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);
Expand Down
28 changes: 25 additions & 3 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,15 +606,30 @@ MTL::Library* Device::get_library(
}

NS::SharedPtr<MTL::Library> Device::build_library_(
const std::string& source_string) {
const std::string& source_string,
CompileOptions compile_options) {
auto pool = new_scoped_memory_pool();

auto ns_code =
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);

NS::Error* error = nullptr;
auto options = MTL::CompileOptions::alloc()->init()->autorelease();
options->setFastMathEnabled(false);
if (compile_options.math_mode) {
auto math_mode = *compile_options.math_mode;
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
options->setMathMode(math_mode);
} else {
if (math_mode == MTL::MathModeRelaxed) {
throw std::runtime_error(
"[metal::Device] Metal math mode `relaxed` requires macOS 15, "
"iOS 18, tvOS 18, or visionOS 2.");
}
options->setFastMathEnabled(math_mode == MTL::MathModeFast);
}
} else {
options->setFastMathEnabled(false);
}
options->setLanguageVersion(get_metal_version());
#ifndef NDEBUG
if (options->languageVersion() >= MTL::LanguageVersion3_2) {
Expand Down Expand Up @@ -756,6 +771,13 @@ NS::SharedPtr<MTL::ComputePipelineState> Device::get_kernel_(
MTL::Library* Device::get_library(
const std::string& name,
const std::function<std::string(void)>& builder) {
return get_library(name, builder, {});
}

MTL::Library* Device::get_library(
const std::string& name,
const std::function<std::string(void)>& builder,
CompileOptions compile_options) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
Expand All @@ -768,7 +790,7 @@ MTL::Library* Device::get_library(
return it->second.get();
}

auto mtl_lib = build_library_(builder());
auto mtl_lib = build_library_(builder(), compile_options);
library_map_.insert({name, mtl_lib});
return mtl_lib.get();
}
Expand Down
14 changes: 13 additions & 1 deletion mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <Metal/Metal.hpp>
#include <functional>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <string>
#include <unordered_map>
Expand All @@ -19,6 +20,10 @@ namespace mlx::core::metal {
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

struct CompileOptions {
std::optional<MTL::MathMode> math_mode = std::nullopt;
};

class Device;
class EventImpl;

Expand Down Expand Up @@ -170,6 +175,11 @@ class MLX_API Device {
const std::string& name,
const std::function<std::string(void)>& builder);

MTL::Library* get_library(
const std::string& name,
const std::function<std::string(void)>& builder,
CompileOptions compile_options);

void clear_library(const std::string& name);

MTL::ComputePipelineState* get_kernel(
Expand All @@ -190,7 +200,9 @@ class MLX_API Device {
}

private:
NS::SharedPtr<MTL::Library> build_library_(const std::string& source_string);
NS::SharedPtr<MTL::Library> build_library_(
const std::string& source_string,
CompileOptions compile_options = {});

NS::SharedPtr<MTL::Function> get_function_(
const std::string& name,
Expand Down
13 changes: 12 additions & 1 deletion mlx/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ MLX_API array scaled_dot_product_attention(
using TemplateArg = std::variant<int, bool, Dtype>;
using ScalarArg = std::variant<bool, int, float>;

enum class MetalKernelMathMode {
Safe = 0,
Relaxed = 1,
Fast = 2,
};

struct CompileOptions {
MetalKernelMathMode math_mode = MetalKernelMathMode::Safe;
};

using CustomKernelFunction = std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<Shape>&,
Expand All @@ -75,7 +85,8 @@ MLX_API CustomKernelFunction metal_kernel(
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
bool atomic_outputs = false);
bool atomic_outputs = false,
CompileOptions compile_options = {});

MLX_API CustomKernelFunction cuda_kernel(
const std::string& name,
Expand Down
10 changes: 7 additions & 3 deletions mlx/fast_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ class CustomKernel : public Primitive {
std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
int shared_memory)
int shared_memory,
std::optional<int> metal_math_mode = std::nullopt)
: Primitive(stream),
name_(std::move(name)),
source_(std::move(source)),
Expand All @@ -386,7 +387,8 @@ class CustomKernel : public Primitive {
init_value_(init_value),
scalar_arguments_(std::move(scalar_arguments)),
is_precompiled_(is_precompiled),
shared_memory_(shared_memory) {}
shared_memory_(shared_memory),
metal_math_mode_(metal_math_mode) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
Expand All @@ -408,7 +410,8 @@ class CustomKernel : public Primitive {
init_value_,
scalar_arguments_,
is_precompiled_,
shared_memory_);
shared_memory_,
metal_math_mode_);
}

private:
Expand All @@ -422,6 +425,7 @@ class CustomKernel : public Primitive {
std::vector<ScalarArg> scalar_arguments_;
bool is_precompiled_;
int shared_memory_;
std::optional<int> metal_math_mode_;
};

} // namespace mlx::core::fast
53 changes: 51 additions & 2 deletions python/src/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,46 @@ struct PyCustomKernelFunction {
const char* tag_;
};

mx::fast::MetalKernelMathMode parse_metal_math_mode(
const std::string& math_mode) {
if (math_mode == "safe") {
return mx::fast::MetalKernelMathMode::Safe;
} else if (math_mode == "relaxed") {
return mx::fast::MetalKernelMathMode::Relaxed;
} else if (math_mode == "fast") {
return mx::fast::MetalKernelMathMode::Fast;
}
throw std::invalid_argument(
"[metal_kernel] Expected math_mode to be 'safe', 'relaxed', or 'fast'.");
}

mx::fast::CompileOptions parse_compile_options(
const nb::object& compile_options) {
mx::fast::CompileOptions options;

if (compile_options.is_none()) {
return options;
}

if (!nb::isinstance<nb::dict>(compile_options)) {
throw std::invalid_argument(
"[metal_kernel] Expected `compile_options` to be a dict.");
}

nb::dict dict = nb::cast<nb::dict>(compile_options);
for (auto [key, value] : dict) {
auto key_str = nb::cast<std::string>(key);
if (key_str == "math_mode") {
options.math_mode = parse_metal_math_mode(nb::cast<std::string>(value));
} else {
std::ostringstream msg;
msg << "[metal_kernel] Unknown compile option `" << key_str << "`.";
throw std::invalid_argument(msg.str());
}
}
return options;
}

} // namespace

void init_fast(nb::module_& parent_module) {
Expand Down Expand Up @@ -304,15 +344,17 @@ void init_fast(nb::module_& parent_module) {
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
bool atomic_outputs) {
bool atomic_outputs,
const nb::object& compile_options) {
auto kernel = mx::fast::metal_kernel(
name,
input_names,
output_names,
source,
header,
ensure_row_contiguous,
atomic_outputs);
atomic_outputs,
parse_compile_options(compile_options));
return nb::cpp_function(
PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
nb::kw_only(),
Expand Down Expand Up @@ -356,6 +398,7 @@ void init_fast(nb::module_& parent_module) {
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
"compile_options"_a = nb::none(),
R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string.

Expand All @@ -376,6 +419,12 @@ void init_fast(nb::module_& parent_module) {
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
compile_options (dict, optional): Options to compile the Metal kernel
with. Supported options:

* ``"math_mode"``: The Metal math mode: ``"safe"``, ``"relaxed"``,
or ``"fast"``. ``"safe"`` preserves IEEE behavior for special
values such as ``exp(-inf) == 0``. Default: ``"safe"``.

Returns:
Callable ``metal_kernel``.
Expand Down
Loading
Loading