diff --git a/docs/src/usage/export.rst b/docs/src/usage/export.rst index ac1ec218bb..c349437cff 100644 --- a/docs/src/usage/export.rst +++ b/docs/src/usage/export.rst @@ -109,6 +109,32 @@ keyword arguments when calling the imported function. out, = imported_fun(x, z=y) +Saving Metadata +--------------- + +You can save additional metadata, such as a model configuration, alongside an +exported function. The metadata is a dictionary of string keys and string +values: + +.. code-block:: python + + def fun(x, y): + return x + y + + metadata = {"description": "adds two arrays", "version": "1.0"} + mx.export_function("add.mlxfn", fun, x, y, metadata=metadata) + +Read the metadata back when importing the function by passing +``return_metadata=True``: + +.. code-block:: python + + imported_fun, metadata = mx.import_function("add.mlxfn", return_metadata=True) + + # Prints: {'description': 'adds two arrays', 'version': '1.0'} + print(metadata) + + Exporting Modules ----------------- diff --git a/mlx/export.cpp b/mlx/export.cpp index 21d996bd5a..8e6f895452 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -520,10 +520,15 @@ struct PrimitiveFactory { }; }; -void write_header(Writer& os, int count, bool shapeless) { +void write_header( + Writer& os, + int count, + bool shapeless, + const std::unordered_map& metadata) { serialize(os, std::string(version())); serialize(os, count); serialize(os, shapeless); + serialize(os, metadata); } // A struct to hold and retrieve the graphs that are exported / imported @@ -673,14 +678,16 @@ FunctionTable::Function* FunctionTable::find( FunctionExporter::FunctionExporter( const std::string& file, std::function(const Args&, const Kwargs&)> fun, - bool shapeless) + bool shapeless, + std::unordered_map metadata) : os(file), fun(std::move(fun)), - ftable(std::make_shared(shapeless)) { + ftable(std::make_shared(shapeless)), + metadata_(std::move(metadata)) { if (!os.is_open()) { throw std::runtime_error("[export_function] Failed to open " + file); } - write_header(os, count, shapeless); + write_header(os, count, shapeless, metadata_); } FunctionExporter::FunctionExporter( @@ -812,7 +819,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { // Update the header auto pos = os.tell(); os.seek(0); - write_header(os, count, ftable->shapeless); + write_header(os, count, ftable->shapeless, metadata_); os.seek(pos); serialize(os, kwarg_keys); @@ -891,44 +898,51 @@ void FunctionExporter::operator()(const Args& args, const Kwargs& kwargs) { FunctionExporter exporter( const std::string& file, const std::function(const Args&)>& fun, - bool shapeless /* = false */) { + bool shapeless /* = false */, + const std::unordered_map& metadata /* = {} */) { return FunctionExporter{ file, [fun](const Args& args, const Kwargs&) { return fun(args); }, - shapeless}; + shapeless, + metadata}; } FunctionExporter exporter( const std::string& file, const std::function(const Kwargs&)>& fun, - bool shapeless /* = false */) { + bool shapeless /* = false */, + const std::unordered_map& metadata /* = {} */) { return exporter( file, [fun](const Args&, const Kwargs kwargs) { return fun(kwargs); }, - shapeless); + shapeless, + metadata); } FunctionExporter exporter( const std::string& file, const std::function(const Args&, const Kwargs&)>& fun, - bool shapeless /* = false */) { - return FunctionExporter{file, fun, shapeless}; + bool shapeless /* = false */, + const std::unordered_map& metadata /* = {} */) { + return FunctionExporter{file, fun, shapeless, metadata}; } void export_function( const std::string& file, const std::function(const Args&)>& fun, const Args& args, - bool shapeless /* = false */) { - exporter(file, fun, shapeless)(args); + bool shapeless /* = false */, + const std::unordered_map& metadata /* = {} */) { + exporter(file, fun, shapeless, metadata)(args); } void export_function( const std::string& file, const std::function(const Kwargs&)>& fun, const Kwargs& kwargs, - bool shapeless /* = false */) { - exporter(file, fun, shapeless)(kwargs); + bool shapeless /* = false */, + const std::unordered_map& metadata /* = {} */) { + exporter(file, fun, shapeless, metadata)(kwargs); } void export_function( @@ -936,8 +950,9 @@ void export_function( const std::function(const Args&, const Kwargs&)>& fun, const Args& args, const Kwargs& kwargs, - bool shapeless /* = false */) { - exporter(file, fun, shapeless)(args, kwargs); + bool shapeless /* = false */, + const std::unordered_map& metadata /* = {} */) { + exporter(file, fun, shapeless, metadata)(args, kwargs); } FunctionExporter exporter( @@ -1047,6 +1062,10 @@ ImportedFunction::ImportedFunction(const std::string& file) auto mlx_version = deserialize(is); auto function_count = deserialize(is); ftable->shapeless = deserialize(is); + auto metadata_pairs = + deserialize>>(is); + metadata_ = std::unordered_map( + metadata_pairs.begin(), metadata_pairs.end()); std::unordered_map constants; auto import_one = [&]() { diff --git a/mlx/export.h b/mlx/export.h index 5532f7c818..89bded4047 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -50,17 +50,20 @@ struct FunctionExporter; MLX_API FunctionExporter exporter( const std::string& file, const std::function(const Args&)>& fun, - bool shapeless = false); + bool shapeless = false, + const std::unordered_map& metadata = {}); MLX_API FunctionExporter exporter( const std::string& file, const std::function(const Kwargs&)>& fun, - bool shapeless = false); + bool shapeless = false, + const std::unordered_map& metadata = {}); MLX_API FunctionExporter exporter( const std::string& path, const std::function(const Args&, const Kwargs&)>& fun, - bool shapeless = false); + bool shapeless = false, + const std::unordered_map& metadata = {}); /** * Export a function to a file. @@ -69,20 +72,23 @@ MLX_API void export_function( const std::string& file, const std::function(const Args&)>& fun, const Args& args, - bool shapeless = false); + bool shapeless = false, + const std::unordered_map& metadata = {}); MLX_API void export_function( const std::string& file, const std::function(const Kwargs&)>& fun, const Kwargs& kwargs, - bool shapeless = false); + bool shapeless = false, + const std::unordered_map& metadata = {}); MLX_API void export_function( const std::string& file, const std::function(const Args&, const Kwargs&)>& fun, const Args& args, const Kwargs& kwargs, - bool shapeless = false); + bool shapeless = false, + const std::unordered_map& metadata = {}); struct ImportedFunction; diff --git a/mlx/export_impl.h b/mlx/export_impl.h index 467a5f0d6c..c0fa2421aa 100644 --- a/mlx/export_impl.h +++ b/mlx/export_impl.h @@ -27,17 +27,20 @@ struct MLX_API FunctionExporter { friend MLX_API FunctionExporter exporter( const std::string&, const std::function(const Args&)>&, - bool shapeless); + bool shapeless, + const std::unordered_map& metadata); friend MLX_API FunctionExporter exporter( const std::string&, const std::function(const Kwargs&)>&, - bool shapeless); + bool shapeless, + const std::unordered_map& metadata); friend MLX_API FunctionExporter exporter( const std::string&, const std::function(const Args&, const Kwargs&)>&, - bool shapeless); + bool shapeless, + const std::unordered_map& metadata); friend MLX_API FunctionExporter exporter( const ExportCallback&, @@ -57,7 +60,8 @@ struct MLX_API FunctionExporter { FunctionExporter( const std::string& file, std::function(const Args&, const Kwargs&)> fun, - bool shapeless); + bool shapeless, + std::unordered_map metadata); FunctionExporter( const ExportCallback& callback, @@ -77,6 +81,7 @@ struct MLX_API FunctionExporter { int count{0}; bool closed{false}; std::shared_ptr ftable; + std::unordered_map metadata_; }; struct MLX_API ImportedFunction { @@ -88,12 +93,18 @@ struct MLX_API ImportedFunction { std::vector operator()(const Kwargs& kwargs) const; std::vector operator()(const Args& args, const Kwargs& kwargs) const; + // The metadata stored alongside the function when it was exported. + const std::unordered_map& metadata() const { + return metadata_; + } + private: ImportedFunction(const std::string& file); friend MLX_API ImportedFunction import_function(const std::string&); ImportedFunction(); std::shared_ptr ftable; + std::unordered_map metadata_; }; } // namespace mlx::core diff --git a/python/src/export.cpp b/python/src/export.cpp index 51c8a8965c..d857c086ff 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -138,6 +138,8 @@ void init_export(nb::module_& m) { const nb::callable& fun, const nb::args& args, bool shapeless, + const std::optional>& + metadata, const nb::kwargs& kwargs) { auto [args_, kwargs_] = validate_and_extract_inputs(args, kwargs, "[export_function]"); @@ -147,8 +149,15 @@ void init_export(nb::module_& m) { wrap_export_function(fun), args_, kwargs_, - shapeless); + shapeless, + metadata.value_or( + std::unordered_map{})); } else { + if (metadata && !metadata->empty()) { + throw std::invalid_argument( + "[export_function] The metadata argument is only supported " + "when exporting to a file, not when using a callback."); + } auto callback = nb::cast(file_or_callback); auto wrapped_callback = [callback](const mx::ExportCallbackInput& input) { @@ -163,9 +172,10 @@ void init_export(nb::module_& m) { "args"_a, nb::kw_only(), "shapeless"_a = false, + "metadata"_a = nb::none(), "kwargs"_a, nb::sig( - "def export_function(file_or_callback: Union[str, Callable], fun: Callable, *args, shapeless: bool = False, **kwargs) -> None"), + "def export_function(file_or_callback: Union[str, Callable], fun: Callable, *args, shapeless: bool = False, metadata: Optional[dict[str, str]] = None, **kwargs) -> None"), R"pbdoc( Export an MLX function. @@ -187,6 +197,10 @@ void init_export(nb::module_& m) { *args (array): Example array inputs to the function. shapeless (bool, optional): Whether or not the function allows inputs with variable shapes. Default: ``False``. + metadata (dict, optional): A dictionary of string keys and string + values to save alongside the function. Only supported when + exporting to a file. The metadata can be read back with + :func:`import_function`. Default: ``None``. **kwargs (array): Additional example keyword array inputs to the function. @@ -203,17 +217,26 @@ void init_export(nb::module_& m) { )pbdoc"); m.def( "import_function", - [](const std::string& file) { - return nb::cpp_function( - [fn = mx::import_function(file)]( + [](const std::string& file, bool return_metadata) -> nb::object { + auto imported = mx::import_function(file); + auto metadata = imported.metadata(); + auto fn = nb::cpp_function( + [imported = std::move(imported)]( const nb::args& args, const nb::kwargs& kwargs) { auto [args_, kwargs_] = validate_and_extract_inputs( args, kwargs, "[import_function::call]"); - return nb::tuple(nb::cast(fn(args_, kwargs_))); + return nb::tuple(nb::cast(imported(args_, kwargs_))); }); + if (return_metadata) { + return nb::make_tuple(fn, nb::cast(metadata)); + } + return fn; }, "file"_a, - nb::sig("def import_function(file: str) -> Callable"), + nb::kw_only(), + "return_metadata"_a = false, + nb::sig( + "def import_function(file: str, *, return_metadata: bool = False) -> Union[Callable, tuple[Callable, dict[str, str]]]"), R"pbdoc( Import a function from a file. @@ -230,9 +253,14 @@ void init_export(nb::module_& m) { Args: file (str): The file path to import the function from. + return_metadata (bool, optional): If ``True`` also return the + metadata that was saved with the function as a dictionary of + string keys and values. Default: ``False``. Returns: - Callable: The imported function. + Callable: The imported function. If ``return_metadata`` is + ``True`` a tuple of the imported function and a dictionary of + metadata is returned instead. Example: >>> fn = mx.import_function("function.mlxfn") @@ -274,14 +302,25 @@ void init_export(nb::module_& m) { m.def( "exporter", - [](const std::string& file, nb::callable fun, bool shapeless) { + [](const std::string& file, + nb::callable fun, + bool shapeless, + const std::optional>& + metadata) { return PyFunctionExporter{ - mx::exporter(file, wrap_export_function(fun), shapeless), fun}; + mx::exporter( + file, + wrap_export_function(fun), + shapeless, + metadata.value_or( + std::unordered_map{})), + fun}; }, "file"_a, "fun"_a, nb::kw_only(), "shapeless"_a = false, + "metadata"_a = nb::none(), R"pbdoc( Make a callable object to export multiple traces of a function to a file. @@ -295,6 +334,9 @@ void init_export(nb::module_& m) { file (str): File path to export the function to. shapeless (bool, optional): Whether or not the function allows inputs with variable shapes. Default: ``False``. + metadata (dict, optional): A dictionary of string keys and string + values to save alongside the function. The metadata can be read + back with :func:`import_function`. Default: ``None``. Example: diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 6f2b2ca496..e54ed6d7d7 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -615,6 +615,46 @@ def fun(x, y, z): imported = mx.import_function(path) self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z))) + def test_export_import_metadata(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(x): + return mx.abs(x) + + x = mx.array([1.0, -2.0, 3.0]) + metadata = {"model": "test", "version": "1.0"} + + mx.export_function(path, fun, x, metadata=metadata) + + # By default no metadata is returned and the callable still works. + imported = mx.import_function(path) + self.assertTrue(mx.array_equal(imported(x)[0], fun(x))) + + # Request the metadata back. + imported, imported_metadata = mx.import_function(path, return_metadata=True) + self.assertEqual(imported_metadata, metadata) + self.assertTrue(mx.array_equal(imported(x)[0], fun(x))) + + # No metadata gives back an empty dictionary. + mx.export_function(path, fun, x) + _, imported_metadata = mx.import_function(path, return_metadata=True) + self.assertEqual(imported_metadata, {}) + + # Metadata also works with the exporter context manager across + # multiple traces (the header is rewritten on each trace). + with mx.exporter(path, fun, metadata=metadata) as exporter: + exporter(mx.array([1.0])) + exporter(mx.array([1.0, 2.0])) + exporter(mx.array([1.0, 2.0, 3.0])) + imported, imported_metadata = mx.import_function(path, return_metadata=True) + self.assertEqual(imported_metadata, metadata) + for y in (mx.array([1.0]), mx.array([1.0, 2.0, 3.0])): + self.assertTrue(mx.array_equal(imported(y)[0], fun(y))) + + # Metadata is not supported with a callback. + with self.assertRaises(ValueError): + mx.export_function(lambda x: None, fun, x, metadata=metadata) + if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/tests/export_import_tests.cpp b/tests/export_import_tests.cpp index ef6a18e199..187d993e54 100644 --- a/tests/export_import_tests.cpp +++ b/tests/export_import_tests.cpp @@ -161,3 +161,24 @@ TEST_CASE("test export function on different stream") { // Should make a new stream that we can run computation on eval(import_function(file_path)({array({0, 1, 2})})); } + +TEST_CASE("test export import with metadata") { + std::string file_path = get_temp_file("model.mlxfn"); + + auto fun = [](const std::vector& args) -> std::vector { + return {abs(args[0])}; + }; + + std::unordered_map metadata = { + {"model", "test"}, {"version", "1.0"}}; + + export_function(file_path, fun, {array({0, 1, 2})}, false, metadata); + + auto imported_fun = import_function(file_path); + CHECK(imported_fun.metadata() == metadata); + eval(imported_fun({array({0, 1, 2})})); + + // With no metadata the imported map is empty. + export_function(file_path, fun, {array({0, 1, 2})}); + CHECK(import_function(file_path).metadata().empty()); +}