diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7658ce5f5c..2a8e15afd7 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -162,6 +162,20 @@ std::pair load_swiftpm_library( } MTL::Library* load_default_library(MTL::Device* device) { + // Check override path before automatic lookup + if (!get_metallib_path().empty()) { + auto [lib, error] = + load_library_from_path(device, get_metallib_path().c_str()); + if (!lib) { + throw std::runtime_error( + fmt::format( + "Can not load metallib from specified location \"{}\": {}.", + get_metallib_path(), + error->localizedDescription()->utf8String())); + } + return lib; + } + NS::Error* error[5]; MTL::Library* lib; // First try the colocated mlx.metallib diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 6bf3b895b9..e95b6f0b9e 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -7,6 +7,12 @@ namespace mlx::core::metal { +namespace { + +std::string g_metallib_path; + +} // namespace + bool is_available() { return true; } @@ -46,4 +52,12 @@ void stop_capture() { manager->stopCapture(); } +void set_metallib_path(const std::string& path) { + g_metallib_path = path; +} + +const std::string& get_metallib_path() { + return g_metallib_path; +} + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 6662e21ebd..d686c2971a 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -22,4 +22,9 @@ MLX_API const std::unordered_map>& device_info(); +/* Set a custom path to mlx.metallib. Must be called before any MLX operation. + */ +MLX_API void set_metallib_path(const std::string& path); +MLX_API const std::string& get_metallib_path(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp index 2d3414cc50..3b72ad005a 100644 --- a/mlx/backend/metal/no_metal.cpp +++ b/mlx/backend/metal/no_metal.cpp @@ -21,6 +21,13 @@ device_info() { "[metal::device_info] Cannot get device info without metal backend"); }; +void set_metallib_path(const std::string& path) {} + +const std::string& get_metallib_path() { + throw std::runtime_error( + "[metal::get_metallib_path] Cannot get metallib path without metal backend"); +} + } // namespace metal } // namespace mlx::core