diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index f5881b5c17..ce5880952f 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -50,6 +50,29 @@ JIT compiled. To reduce the overhead from that, build the kernel once with Only pass the body of the Metal kernel in ``source``. The function signature is generated automatically. +Math Mode +--------- + +By default :func:`fast.metal_kernel` compiles kernels with +``compile_options={"math_mode": "safe"}`` so special values follow IEEE +behavior, for example ``exp(-inf) == 0``. This is important for kernels such as +masked softmax where causal or sliding-window masks depend on exponentiating +``-inf``. + +If your kernel does not rely on these edge cases, you can opt in to less strict +math with ``compile_options={"math_mode": "relaxed"}`` or +``compile_options={"math_mode": "fast"}``: + +.. code-block:: python + + kernel = mx.fast.metal_kernel( + name="my_kernel", + input_names=["x"], + output_names=["y"], + source=source, + compile_options={"math_mode": "relaxed"}, + ) + The full function signature will be generated using: * The shapes/dtypes of ``inputs`` diff --git a/mlx/backend/common/metal_kernel.cpp b/mlx/backend/common/metal_kernel.cpp index 691feb554a..edbf75c46e 100644 --- a/mlx/backend/common/metal_kernel.cpp +++ b/mlx/backend/common/metal_kernel.cpp @@ -206,7 +206,8 @@ CustomKernelFunction metal_kernel( const std::string& source, const std::string& header /* = "" */, bool ensure_row_contiguous /* = true */, - bool atomic_outputs /* = false */) { + bool atomic_outputs /* = false */, + CompileOptions compile_options /* = {} */) { if (output_names.empty()) { throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); @@ -360,7 +361,8 @@ CustomKernelFunction metal_kernel( init_value, std::vector{}, false, - 0), + 0, + static_cast(compile_options.math_mode)), std::move(inputs)); }; } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 0ba491f4ff..1e47428088 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -3,12 +3,14 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" namespace mlx::core::fast { struct CustomKernelCache { - std::unordered_map libraries; + std::unordered_map>> + libraries; }; static CustomKernelCache& cache() { @@ -16,6 +18,22 @@ static CustomKernelCache& cache() { return cache_; }; +std::optional to_mtl_math_mode(std::optional math_mode) { + if (!math_mode) { + return std::nullopt; + } + switch (*math_mode) { + case static_cast(MetalKernelMathMode::Safe): + return MTL::MathModeSafe; + case static_cast(MetalKernelMathMode::Relaxed): + return MTL::MathModeRelaxed; + case static_cast(MetalKernelMathMode::Fast): + return MTL::MathModeFast; + default: + throw std::invalid_argument("[metal_kernel] Invalid Metal math mode."); + } +} + void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -58,17 +76,22 @@ void CustomKernel::eval_gpu( auto& kernel_cache = cache(); if (auto it = kernel_cache.libraries.find(name_); it != kernel_cache.libraries.end()) { - if (it->second != source_) { + if (it->second.first != source_ || + it->second.second != metal_math_mode_) { auto& d = metal::device(s.device); d.clear_library(name_); - it->second = source_; + it->second = {source_, metal_math_mode_}; } } else { - kernel_cache.libraries.emplace(name_, source_); + kernel_cache.libraries.emplace( + name_, std::make_pair(source_, metal_math_mode_)); } } - auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); + auto lib = d.get_library( + name_, + [this] { return metal::utils() + source_; }, + metal::CompileOptions{to_mtl_math_mode(metal_math_mode_)}); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7658ce5f5c..de7ad6d159 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -606,7 +606,8 @@ MTL::Library* Device::get_library( } NS::SharedPtr Device::build_library_( - const std::string& source_string) { + const std::string& source_string, + CompileOptions compile_options) { auto pool = new_scoped_memory_pool(); auto ns_code = @@ -614,7 +615,21 @@ NS::SharedPtr Device::build_library_( NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init()->autorelease(); - options->setFastMathEnabled(false); + if (compile_options.math_mode) { + auto math_mode = *compile_options.math_mode; + if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { + options->setMathMode(math_mode); + } else { + if (math_mode == MTL::MathModeRelaxed) { + throw std::runtime_error( + "[metal::Device] Metal math mode `relaxed` requires macOS 15, " + "iOS 18, tvOS 18, or visionOS 2."); + } + options->setFastMathEnabled(math_mode == MTL::MathModeFast); + } + } else { + options->setFastMathEnabled(false); + } options->setLanguageVersion(get_metal_version()); #ifndef NDEBUG if (options->languageVersion() >= MTL::LanguageVersion3_2) { @@ -756,6 +771,13 @@ NS::SharedPtr Device::get_kernel_( MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { + return get_library(name, builder, {}); +} + +MTL::Library* Device::get_library( + const std::string& name, + const std::function& builder, + CompileOptions compile_options) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { @@ -768,7 +790,7 @@ MTL::Library* Device::get_library( return it->second.get(); } - auto mtl_lib = build_library_(builder()); + auto mtl_lib = build_library_(builder(), compile_options); library_map_.insert({name, mtl_lib}); return mtl_lib.get(); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bed0cd636e..e04cb24228 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -19,6 +20,10 @@ namespace mlx::core::metal { using MTLFCList = std::vector>; +struct CompileOptions { + std::optional math_mode = std::nullopt; +}; + class Device; class EventImpl; @@ -170,6 +175,11 @@ class MLX_API Device { const std::string& name, const std::function& builder); + MTL::Library* get_library( + const std::string& name, + const std::function& builder, + CompileOptions compile_options); + void clear_library(const std::string& name); MTL::ComputePipelineState* get_kernel( @@ -190,7 +200,9 @@ class MLX_API Device { } private: - NS::SharedPtr build_library_(const std::string& source_string); + NS::SharedPtr build_library_( + const std::string& source_string, + CompileOptions compile_options = {}); NS::SharedPtr get_function_( const std::string& name, diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..2e70ba8ac5 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -57,6 +57,16 @@ MLX_API array scaled_dot_product_attention( using TemplateArg = std::variant; using ScalarArg = std::variant; +enum class MetalKernelMathMode { + Safe = 0, + Relaxed = 1, + Fast = 2, +}; + +struct CompileOptions { + MetalKernelMathMode math_mode = MetalKernelMathMode::Safe; +}; + using CustomKernelFunction = std::function( const std::vector&, const std::vector&, @@ -75,7 +85,8 @@ MLX_API CustomKernelFunction metal_kernel( const std::string& source, const std::string& header = "", bool ensure_row_contiguous = true, - bool atomic_outputs = false); + bool atomic_outputs = false, + CompileOptions compile_options = {}); MLX_API CustomKernelFunction cuda_kernel( const std::string& name, diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..4ffcfd3d56 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,7 +375,8 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory) + int shared_memory, + std::optional metal_math_mode = std::nullopt) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -386,7 +387,8 @@ class CustomKernel : public Primitive { init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory) {} + shared_memory_(shared_memory), + metal_math_mode_(metal_math_mode) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -408,7 +410,8 @@ class CustomKernel : public Primitive { init_value_, scalar_arguments_, is_precompiled_, - shared_memory_); + shared_memory_, + metal_math_mode_); } private: @@ -422,6 +425,7 @@ class CustomKernel : public Primitive { std::vector scalar_arguments_; bool is_precompiled_; int shared_memory_; + std::optional metal_math_mode_; }; } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..af764ad33f 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -75,6 +75,46 @@ struct PyCustomKernelFunction { const char* tag_; }; +mx::fast::MetalKernelMathMode parse_metal_math_mode( + const std::string& math_mode) { + if (math_mode == "safe") { + return mx::fast::MetalKernelMathMode::Safe; + } else if (math_mode == "relaxed") { + return mx::fast::MetalKernelMathMode::Relaxed; + } else if (math_mode == "fast") { + return mx::fast::MetalKernelMathMode::Fast; + } + throw std::invalid_argument( + "[metal_kernel] Expected math_mode to be 'safe', 'relaxed', or 'fast'."); +} + +mx::fast::CompileOptions parse_compile_options( + const nb::object& compile_options) { + mx::fast::CompileOptions options; + + if (compile_options.is_none()) { + return options; + } + + if (!nb::isinstance(compile_options)) { + throw std::invalid_argument( + "[metal_kernel] Expected `compile_options` to be a dict."); + } + + nb::dict dict = nb::cast(compile_options); + for (auto [key, value] : dict) { + auto key_str = nb::cast(key); + if (key_str == "math_mode") { + options.math_mode = parse_metal_math_mode(nb::cast(value)); + } else { + std::ostringstream msg; + msg << "[metal_kernel] Unknown compile option `" << key_str << "`."; + throw std::invalid_argument(msg.str()); + } + } + return options; +} + } // namespace void init_fast(nb::module_& parent_module) { @@ -304,7 +344,8 @@ void init_fast(nb::module_& parent_module) { const std::string& source, const std::string& header, bool ensure_row_contiguous, - bool atomic_outputs) { + bool atomic_outputs, + const nb::object& compile_options) { auto kernel = mx::fast::metal_kernel( name, input_names, @@ -312,7 +353,8 @@ void init_fast(nb::module_& parent_module) { source, header, ensure_row_contiguous, - atomic_outputs); + atomic_outputs, + parse_compile_options(compile_options)); return nb::cpp_function( PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), nb::kw_only(), @@ -356,6 +398,7 @@ void init_fast(nb::module_& parent_module) { "header"_a = "", "ensure_row_contiguous"_a = true, "atomic_outputs"_a = false, + "compile_options"_a = nb::none(), R"pbdoc( A jit-compiled custom Metal kernel defined from a source string. @@ -376,6 +419,12 @@ void init_fast(nb::module_& parent_module) { before the kernel runs. Default: ``True``. atomic_outputs (bool): Whether to use atomic outputs in the function signature e.g. ``device atomic``. Default: ``False``. + compile_options (dict, optional): Options to compile the Metal kernel + with. Supported options: + + * ``"math_mode"``: The Metal math mode: ``"safe"``, ``"relaxed"``, + or ``"fast"``. ``"safe"`` preserves IEEE behavior for special + values such as ``exp(-inf) == 0``. Default: ``"safe"``. Returns: Callable ``metal_kernel``. diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 87e7b31ced..c45dd4a606 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -638,6 +638,38 @@ def call_cpu(a): with self.assertRaisesRegex(RuntimeError, "No Metal back-end"): mx.eval(mx.compile(call)(a)) + def test_export_custom_metal_kernel_with_math_mode(self): + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = metal::exp(a[elem]); + """ + kernel = mx.fast.metal_kernel( + name="math_mode_export", + input_names=["a"], + output_names=["out"], + source=source, + compile_options={"math_mode": "safe"}, + ) + + def call(a): + return kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(min(a.size, 256), 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + + a = mx.array([-float("inf"), 0.0]) + path = os.path.join(self.test_dir, "metal_kernel_math_mode.mlxfn") + mx.export_function(path, call, a) + self.assertTrue(os.path.exists(path)) + + if mx.metal.is_available(): + imported = mx.import_function(path) + self.assertTrue(mx.array_equal(imported(a)[0], call(a))) + def test_export_import_multi_with_constants(self): path = os.path.join(self.test_dir, "fn.mlxfn") diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index b1c84c987d..e2baf93f95 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -1026,6 +1026,49 @@ def call_kernel(a: mx.array, source): out = call_kernel(a, source) self.assertTrue(mx.array_equal(out, mx.ones_like(out))) + def test_custom_metal_kernel_invalid_math_mode(self): + with self.assertRaises(ValueError): + mx.fast.metal_kernel( + name="invalid_math_mode", + input_names=["inp"], + output_names=["out"], + source="out[0] = inp[0];", + compile_options={"math_mode": "precise"}, + ) + + def test_custom_metal_kernel_invalid_compile_options(self): + with self.assertRaises(ValueError): + mx.fast.metal_kernel( + name="invalid_compile_options", + input_names=["inp"], + output_names=["out"], + source="out[0] = inp[0];", + compile_options={"unknown": "value"}, + ) + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_metal_kernel_safe_math_mode(self): + kernel = mx.fast.metal_kernel( + name="safe_math_mode", + input_names=["inp"], + output_names=["out"], + source=""" + uint elem = thread_position_in_grid.x; + out[elem] = metal::exp(inp[elem]); + """, + compile_options={"math_mode": "safe"}, + ) + a = mx.array([-float("inf"), 0.0], dtype=mx.float32) + out = kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + self.assertTrue(mx.array_equal(out, mx.array([0.0, 1.0]))) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_mixed_dtypes(self): # Calling the same kernel with different input dtypes in a single