From c77b70ae60eb941f9d648de5aa2366d48a2590df Mon Sep 17 00:00:00 2001 From: Shubh Date: Sat, 20 Jun 2026 04:30:24 +0530 Subject: [PATCH 1/3] Add math mode option for custom Metal kernels --- docs/src/dev/custom_metal_kernels.rst | 21 +++++++++++++++++ mlx/backend/common/metal_kernel.cpp | 6 +++-- mlx/backend/metal/custom_kernel.cpp | 32 ++++++++++++++++++++++---- mlx/backend/metal/device.cpp | 27 +++++++++++++++++++--- mlx/backend/metal/device.h | 10 +++++++- mlx/fast.h | 9 +++++++- mlx/fast_primitives.h | 10 +++++--- python/src/fast.cpp | 24 +++++++++++++++++-- python/tests/test_export_import.py | 32 ++++++++++++++++++++++++++ python/tests/test_fast.py | 33 +++++++++++++++++++++++++++ 10 files changed, 187 insertions(+), 17 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index f5881b5c17..9629475619 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -50,6 +50,27 @@ JIT compiled. To reduce the overhead from that, build the kernel once with Only pass the body of the Metal kernel in ``source``. The function signature is generated automatically. +Math Mode +--------- + +By default :func:`fast.metal_kernel` compiles kernels with ``math_mode="safe"`` +so special values follow IEEE behavior, for example ``exp(-inf) == 0``. This is +important for kernels such as masked softmax where causal or sliding-window +masks depend on exponentiating ``-inf``. + +If your kernel does not rely on these edge cases, you can opt in to less strict +math with ``math_mode="relaxed"`` or ``math_mode="fast"``: + +.. code-block:: python + + kernel = mx.fast.metal_kernel( + name="my_kernel", + input_names=["x"], + output_names=["y"], + source=source, + math_mode="relaxed", + ) + The full function signature will be generated using: * The shapes/dtypes of ``inputs`` diff --git a/mlx/backend/common/metal_kernel.cpp b/mlx/backend/common/metal_kernel.cpp index 691feb554a..af09cacd3b 100644 --- a/mlx/backend/common/metal_kernel.cpp +++ b/mlx/backend/common/metal_kernel.cpp @@ -206,7 +206,8 @@ CustomKernelFunction metal_kernel( const std::string& source, const std::string& header /* = "" */, bool ensure_row_contiguous /* = true */, - bool atomic_outputs /* = false */) { + bool atomic_outputs /* = false */, + MetalKernelMathMode math_mode /* = MetalKernelMathMode::Safe */) { if (output_names.empty()) { throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); @@ -360,7 +361,8 @@ CustomKernelFunction metal_kernel( init_value, std::vector{}, false, - 0), + 0, + static_cast(math_mode)), std::move(inputs)); }; } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 0ba491f4ff..23bb8891bd 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -8,7 +8,8 @@ namespace mlx::core::fast { struct CustomKernelCache { - std::unordered_map libraries; + std::unordered_map>> + libraries; }; static CustomKernelCache& cache() { @@ -16,6 +17,22 @@ static CustomKernelCache& cache() { return cache_; }; +std::optional to_mtl_math_mode(std::optional math_mode) { + if (!math_mode) { + return std::nullopt; + } + switch (*math_mode) { + case static_cast(MetalKernelMathMode::Safe): + return MTL::MathModeSafe; + case static_cast(MetalKernelMathMode::Relaxed): + return MTL::MathModeRelaxed; + case static_cast(MetalKernelMathMode::Fast): + return MTL::MathModeFast; + default: + throw std::invalid_argument("[metal_kernel] Invalid Metal math mode."); + } +} + void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -58,17 +75,22 @@ void CustomKernel::eval_gpu( auto& kernel_cache = cache(); if (auto it = kernel_cache.libraries.find(name_); it != kernel_cache.libraries.end()) { - if (it->second != source_) { + if (it->second.first != source_ || + it->second.second != metal_math_mode_) { auto& d = metal::device(s.device); d.clear_library(name_); - it->second = source_; + it->second = {source_, metal_math_mode_}; } } else { - kernel_cache.libraries.emplace(name_, source_); + kernel_cache.libraries.emplace( + name_, std::make_pair(source_, metal_math_mode_)); } } - auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); + auto lib = d.get_library( + name_, + [this] { return metal::utils() + source_; }, + to_mtl_math_mode(metal_math_mode_)); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7658ce5f5c..e915bbec8d 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -606,7 +606,8 @@ MTL::Library* Device::get_library( } NS::SharedPtr Device::build_library_( - const std::string& source_string) { + const std::string& source_string, + std::optional math_mode) { auto pool = new_scoped_memory_pool(); auto ns_code = @@ -614,7 +615,20 @@ NS::SharedPtr Device::build_library_( NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init()->autorelease(); - options->setFastMathEnabled(false); + if (math_mode) { + if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { + options->setMathMode(*math_mode); + } else { + if (*math_mode == MTL::MathModeRelaxed) { + throw std::runtime_error( + "[metal::Device] Metal math mode `relaxed` requires macOS 15, " + "iOS 18, tvOS 18, or visionOS 2."); + } + options->setFastMathEnabled(*math_mode == MTL::MathModeFast); + } + } else { + options->setFastMathEnabled(false); + } options->setLanguageVersion(get_metal_version()); #ifndef NDEBUG if (options->languageVersion() >= MTL::LanguageVersion3_2) { @@ -756,6 +770,13 @@ NS::SharedPtr Device::get_kernel_( MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { + return get_library(name, builder, std::nullopt); +} + +MTL::Library* Device::get_library( + const std::string& name, + const std::function& builder, + std::optional math_mode) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { @@ -768,7 +789,7 @@ MTL::Library* Device::get_library( return it->second.get(); } - auto mtl_lib = build_library_(builder()); + auto mtl_lib = build_library_(builder(), math_mode); library_map_.insert({name, mtl_lib}); return mtl_lib.get(); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bed0cd636e..d9cb910c4e 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -170,6 +171,11 @@ class MLX_API Device { const std::string& name, const std::function& builder); + MTL::Library* get_library( + const std::string& name, + const std::function& builder, + std::optional math_mode); + void clear_library(const std::string& name); MTL::ComputePipelineState* get_kernel( @@ -190,7 +196,9 @@ class MLX_API Device { } private: - NS::SharedPtr build_library_(const std::string& source_string); + NS::SharedPtr build_library_( + const std::string& source_string, + std::optional math_mode = std::nullopt); NS::SharedPtr get_function_( const std::string& name, diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..58b4ad8a7d 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -57,6 +57,12 @@ MLX_API array scaled_dot_product_attention( using TemplateArg = std::variant; using ScalarArg = std::variant; +enum class MetalKernelMathMode { + Safe = 0, + Relaxed = 1, + Fast = 2, +}; + using CustomKernelFunction = std::function( const std::vector&, const std::vector&, @@ -75,7 +81,8 @@ MLX_API CustomKernelFunction metal_kernel( const std::string& source, const std::string& header = "", bool ensure_row_contiguous = true, - bool atomic_outputs = false); + bool atomic_outputs = false, + MetalKernelMathMode math_mode = MetalKernelMathMode::Safe); MLX_API CustomKernelFunction cuda_kernel( const std::string& name, diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..4ffcfd3d56 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,7 +375,8 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory) + int shared_memory, + std::optional metal_math_mode = std::nullopt) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -386,7 +387,8 @@ class CustomKernel : public Primitive { init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory) {} + shared_memory_(shared_memory), + metal_math_mode_(metal_math_mode) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -408,7 +410,8 @@ class CustomKernel : public Primitive { init_value_, scalar_arguments_, is_precompiled_, - shared_memory_); + shared_memory_, + metal_math_mode_); } private: @@ -422,6 +425,7 @@ class CustomKernel : public Primitive { std::vector scalar_arguments_; bool is_precompiled_; int shared_memory_; + std::optional metal_math_mode_; }; } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..fc8af5a9da 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -75,6 +75,19 @@ struct PyCustomKernelFunction { const char* tag_; }; +mx::fast::MetalKernelMathMode parse_metal_math_mode( + const std::string& math_mode) { + if (math_mode == "safe") { + return mx::fast::MetalKernelMathMode::Safe; + } else if (math_mode == "relaxed") { + return mx::fast::MetalKernelMathMode::Relaxed; + } else if (math_mode == "fast") { + return mx::fast::MetalKernelMathMode::Fast; + } + throw std::invalid_argument( + "[metal_kernel] Expected math_mode to be 'safe', 'relaxed', or 'fast'."); +} + } // namespace void init_fast(nb::module_& parent_module) { @@ -304,7 +317,8 @@ void init_fast(nb::module_& parent_module) { const std::string& source, const std::string& header, bool ensure_row_contiguous, - bool atomic_outputs) { + bool atomic_outputs, + const std::string& math_mode) { auto kernel = mx::fast::metal_kernel( name, input_names, @@ -312,7 +326,8 @@ void init_fast(nb::module_& parent_module) { source, header, ensure_row_contiguous, - atomic_outputs); + atomic_outputs, + parse_metal_math_mode(math_mode)); return nb::cpp_function( PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), nb::kw_only(), @@ -356,6 +371,7 @@ void init_fast(nb::module_& parent_module) { "header"_a = "", "ensure_row_contiguous"_a = true, "atomic_outputs"_a = false, + "math_mode"_a = "safe", R"pbdoc( A jit-compiled custom Metal kernel defined from a source string. @@ -376,6 +392,10 @@ void init_fast(nb::module_& parent_module) { before the kernel runs. Default: ``True``. atomic_outputs (bool): Whether to use atomic outputs in the function signature e.g. ``device atomic``. Default: ``False``. + math_mode (str): The Metal math mode to compile the kernel with: + ``"safe"``, ``"relaxed"``, or ``"fast"``. ``"safe"`` preserves + IEEE behavior for special values such as ``exp(-inf) == 0``. + Default: ``"safe"``. Returns: Callable ``metal_kernel``. diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 87e7b31ced..971f8fc264 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -638,6 +638,38 @@ def call_cpu(a): with self.assertRaisesRegex(RuntimeError, "No Metal back-end"): mx.eval(mx.compile(call)(a)) + def test_export_custom_metal_kernel_with_math_mode(self): + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = metal::exp(a[elem]); + """ + kernel = mx.fast.metal_kernel( + name="math_mode_export", + input_names=["a"], + output_names=["out"], + source=source, + math_mode="safe", + ) + + def call(a): + return kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(min(a.size, 256), 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + + a = mx.array([-float("inf"), 0.0]) + path = os.path.join(self.test_dir, "metal_kernel_math_mode.mlxfn") + mx.export_function(path, call, a) + self.assertTrue(os.path.exists(path)) + + if mx.metal.is_available(): + imported = mx.import_function(path) + self.assertTrue(mx.array_equal(imported(a)[0], call(a))) + def test_export_import_multi_with_constants(self): path = os.path.join(self.test_dir, "fn.mlxfn") diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index b1c84c987d..69e07552e9 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -1026,6 +1026,39 @@ def call_kernel(a: mx.array, source): out = call_kernel(a, source) self.assertTrue(mx.array_equal(out, mx.ones_like(out))) + def test_custom_metal_kernel_invalid_math_mode(self): + with self.assertRaises(ValueError): + mx.fast.metal_kernel( + name="invalid_math_mode", + input_names=["inp"], + output_names=["out"], + source="out[0] = inp[0];", + math_mode="precise", + ) + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_metal_kernel_safe_math_mode(self): + kernel = mx.fast.metal_kernel( + name="safe_math_mode", + input_names=["inp"], + output_names=["out"], + source=""" + uint elem = thread_position_in_grid.x; + out[elem] = metal::exp(inp[elem]); + """, + math_mode="safe", + ) + a = mx.array([-float("inf"), 0.0], dtype=mx.float32) + out = kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + self.assertTrue(mx.array_equal(out, mx.array([0.0, 1.0]))) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_mixed_dtypes(self): # Calling the same kernel with different input dtypes in a single From 58cf13a410b6965e6f23439ba0e0188038840133 Mon Sep 17 00:00:00 2001 From: Shubh Date: Sat, 20 Jun 2026 13:09:06 +0530 Subject: [PATCH 2/3] Refactor Metal compile options --- docs/src/dev/custom_metal_kernels.rst | 14 +++++---- mlx/backend/common/metal_kernel.cpp | 4 +-- mlx/backend/metal/custom_kernel.cpp | 3 +- mlx/backend/metal/device.cpp | 17 ++++++----- mlx/backend/metal/device.h | 8 +++-- mlx/fast.h | 6 +++- python/src/fast.cpp | 44 ++++++++++++++++++++++----- python/tests/test_export_import.py | 2 +- python/tests/test_fast.py | 14 +++++++-- 9 files changed, 82 insertions(+), 30 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 9629475619..ce5880952f 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -53,13 +53,15 @@ JIT compiled. To reduce the overhead from that, build the kernel once with Math Mode --------- -By default :func:`fast.metal_kernel` compiles kernels with ``math_mode="safe"`` -so special values follow IEEE behavior, for example ``exp(-inf) == 0``. This is -important for kernels such as masked softmax where causal or sliding-window -masks depend on exponentiating ``-inf``. +By default :func:`fast.metal_kernel` compiles kernels with +``compile_options={"math_mode": "safe"}`` so special values follow IEEE +behavior, for example ``exp(-inf) == 0``. This is important for kernels such as +masked softmax where causal or sliding-window masks depend on exponentiating +``-inf``. If your kernel does not rely on these edge cases, you can opt in to less strict -math with ``math_mode="relaxed"`` or ``math_mode="fast"``: +math with ``compile_options={"math_mode": "relaxed"}`` or +``compile_options={"math_mode": "fast"}``: .. code-block:: python @@ -68,7 +70,7 @@ math with ``math_mode="relaxed"`` or ``math_mode="fast"``: input_names=["x"], output_names=["y"], source=source, - math_mode="relaxed", + compile_options={"math_mode": "relaxed"}, ) The full function signature will be generated using: diff --git a/mlx/backend/common/metal_kernel.cpp b/mlx/backend/common/metal_kernel.cpp index af09cacd3b..edbf75c46e 100644 --- a/mlx/backend/common/metal_kernel.cpp +++ b/mlx/backend/common/metal_kernel.cpp @@ -207,7 +207,7 @@ CustomKernelFunction metal_kernel( const std::string& header /* = "" */, bool ensure_row_contiguous /* = true */, bool atomic_outputs /* = false */, - MetalKernelMathMode math_mode /* = MetalKernelMathMode::Safe */) { + CompileOptions compile_options /* = {} */) { if (output_names.empty()) { throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); @@ -362,7 +362,7 @@ CustomKernelFunction metal_kernel( std::vector{}, false, 0, - static_cast(math_mode)), + static_cast(compile_options.math_mode)), std::move(inputs)); }; } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 23bb8891bd..1e47428088 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" namespace mlx::core::fast { @@ -90,7 +91,7 @@ void CustomKernel::eval_gpu( auto lib = d.get_library( name_, [this] { return metal::utils() + source_; }, - to_mtl_math_mode(metal_math_mode_)); + metal::CompileOptions{to_mtl_math_mode(metal_math_mode_)}); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e915bbec8d..de7ad6d159 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -607,7 +607,7 @@ MTL::Library* Device::get_library( NS::SharedPtr Device::build_library_( const std::string& source_string, - std::optional math_mode) { + CompileOptions compile_options) { auto pool = new_scoped_memory_pool(); auto ns_code = @@ -615,16 +615,17 @@ NS::SharedPtr Device::build_library_( NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init()->autorelease(); - if (math_mode) { + if (compile_options.math_mode) { + auto math_mode = *compile_options.math_mode; if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { - options->setMathMode(*math_mode); + options->setMathMode(math_mode); } else { - if (*math_mode == MTL::MathModeRelaxed) { + if (math_mode == MTL::MathModeRelaxed) { throw std::runtime_error( "[metal::Device] Metal math mode `relaxed` requires macOS 15, " "iOS 18, tvOS 18, or visionOS 2."); } - options->setFastMathEnabled(*math_mode == MTL::MathModeFast); + options->setFastMathEnabled(math_mode == MTL::MathModeFast); } } else { options->setFastMathEnabled(false); @@ -770,13 +771,13 @@ NS::SharedPtr Device::get_kernel_( MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { - return get_library(name, builder, std::nullopt); + return get_library(name, builder, {}); } MTL::Library* Device::get_library( const std::string& name, const std::function& builder, - std::optional math_mode) { + CompileOptions compile_options) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { @@ -789,7 +790,7 @@ MTL::Library* Device::get_library( return it->second.get(); } - auto mtl_lib = build_library_(builder(), math_mode); + auto mtl_lib = build_library_(builder(), compile_options); library_map_.insert({name, mtl_lib}); return mtl_lib.get(); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index d9cb910c4e..e04cb24228 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -20,6 +20,10 @@ namespace mlx::core::metal { using MTLFCList = std::vector>; +struct CompileOptions { + std::optional math_mode = std::nullopt; +}; + class Device; class EventImpl; @@ -174,7 +178,7 @@ class MLX_API Device { MTL::Library* get_library( const std::string& name, const std::function& builder, - std::optional math_mode); + CompileOptions compile_options); void clear_library(const std::string& name); @@ -198,7 +202,7 @@ class MLX_API Device { private: NS::SharedPtr build_library_( const std::string& source_string, - std::optional math_mode = std::nullopt); + CompileOptions compile_options = {}); NS::SharedPtr get_function_( const std::string& name, diff --git a/mlx/fast.h b/mlx/fast.h index 58b4ad8a7d..2e70ba8ac5 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -63,6 +63,10 @@ enum class MetalKernelMathMode { Fast = 2, }; +struct CompileOptions { + MetalKernelMathMode math_mode = MetalKernelMathMode::Safe; +}; + using CustomKernelFunction = std::function( const std::vector&, const std::vector&, @@ -82,7 +86,7 @@ MLX_API CustomKernelFunction metal_kernel( const std::string& header = "", bool ensure_row_contiguous = true, bool atomic_outputs = false, - MetalKernelMathMode math_mode = MetalKernelMathMode::Safe); + CompileOptions compile_options = {}); MLX_API CustomKernelFunction cuda_kernel( const std::string& name, diff --git a/python/src/fast.cpp b/python/src/fast.cpp index fc8af5a9da..b24732fc06 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -88,6 +88,34 @@ mx::fast::MetalKernelMathMode parse_metal_math_mode( "[metal_kernel] Expected math_mode to be 'safe', 'relaxed', or 'fast'."); } +mx::fast::CompileOptions parse_compile_options( + const nb::object& compile_options) { + mx::fast::CompileOptions options; + + if (compile_options.is_none()) { + return options; + } + + if (!nb::isinstance(compile_options)) { + throw std::invalid_argument( + "[metal_kernel] Expected `compile_options` to be a dict."); + } + + nb::dict dict = nb::cast(compile_options); + for (auto [key, value] : dict) { + auto key_str = nb::cast(key); + if (key_str == "math_mode") { + options.math_mode = + parse_metal_math_mode(nb::cast(value)); + } else { + std::ostringstream msg; + msg << "[metal_kernel] Unknown compile option `" << key_str << "`."; + throw std::invalid_argument(msg.str()); + } + } + return options; +} + } // namespace void init_fast(nb::module_& parent_module) { @@ -318,7 +346,7 @@ void init_fast(nb::module_& parent_module) { const std::string& header, bool ensure_row_contiguous, bool atomic_outputs, - const std::string& math_mode) { + const nb::object& compile_options) { auto kernel = mx::fast::metal_kernel( name, input_names, @@ -327,7 +355,7 @@ void init_fast(nb::module_& parent_module) { header, ensure_row_contiguous, atomic_outputs, - parse_metal_math_mode(math_mode)); + parse_compile_options(compile_options)); return nb::cpp_function( PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), nb::kw_only(), @@ -371,7 +399,7 @@ void init_fast(nb::module_& parent_module) { "header"_a = "", "ensure_row_contiguous"_a = true, "atomic_outputs"_a = false, - "math_mode"_a = "safe", + "compile_options"_a = nb::none(), R"pbdoc( A jit-compiled custom Metal kernel defined from a source string. @@ -392,10 +420,12 @@ void init_fast(nb::module_& parent_module) { before the kernel runs. Default: ``True``. atomic_outputs (bool): Whether to use atomic outputs in the function signature e.g. ``device atomic``. Default: ``False``. - math_mode (str): The Metal math mode to compile the kernel with: - ``"safe"``, ``"relaxed"``, or ``"fast"``. ``"safe"`` preserves - IEEE behavior for special values such as ``exp(-inf) == 0``. - Default: ``"safe"``. + compile_options (dict, optional): Options to compile the Metal kernel + with. Supported options: + + * ``"math_mode"``: The Metal math mode: ``"safe"``, ``"relaxed"``, + or ``"fast"``. ``"safe"`` preserves IEEE behavior for special + values such as ``exp(-inf) == 0``. Default: ``"safe"``. Returns: Callable ``metal_kernel``. diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 971f8fc264..c45dd4a606 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -648,7 +648,7 @@ def test_export_custom_metal_kernel_with_math_mode(self): input_names=["a"], output_names=["out"], source=source, - math_mode="safe", + compile_options={"math_mode": "safe"}, ) def call(a): diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 69e07552e9..e2baf93f95 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -1033,7 +1033,17 @@ def test_custom_metal_kernel_invalid_math_mode(self): input_names=["inp"], output_names=["out"], source="out[0] = inp[0];", - math_mode="precise", + compile_options={"math_mode": "precise"}, + ) + + def test_custom_metal_kernel_invalid_compile_options(self): + with self.assertRaises(ValueError): + mx.fast.metal_kernel( + name="invalid_compile_options", + input_names=["inp"], + output_names=["out"], + source="out[0] = inp[0];", + compile_options={"unknown": "value"}, ) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @@ -1046,7 +1056,7 @@ def test_custom_metal_kernel_safe_math_mode(self): uint elem = thread_position_in_grid.x; out[elem] = metal::exp(inp[elem]); """, - math_mode="safe", + compile_options={"math_mode": "safe"}, ) a = mx.array([-float("inf"), 0.0], dtype=mx.float32) out = kernel( From 7477e99343822b3dc245a7df86547e48c9dd8e37 Mon Sep 17 00:00:00 2001 From: Shubh Date: Sat, 20 Jun 2026 16:33:08 +0530 Subject: [PATCH 3/3] Apply clang-format to fast bindings --- python/src/fast.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/src/fast.cpp b/python/src/fast.cpp index b24732fc06..af764ad33f 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -105,8 +105,7 @@ mx::fast::CompileOptions parse_compile_options( for (auto [key, value] : dict) { auto key_str = nb::cast(key); if (key_str == "math_mode") { - options.math_mode = - parse_metal_math_mode(nb::cast(value)); + options.math_mode = parse_metal_math_mode(nb::cast(value)); } else { std::ostringstream msg; msg << "[metal_kernel] Unknown compile option `" << key_str << "`.";