Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/src/usage/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,32 @@ keyword arguments when calling the imported function.
out, = imported_fun(x, z=y)


Saving Metadata
---------------

You can save additional metadata, such as a model configuration, alongside an
exported function. The metadata is a dictionary of string keys and string
values:

.. code-block:: python

def fun(x, y):
return x + y

metadata = {"description": "adds two arrays", "version": "1.0"}
mx.export_function("add.mlxfn", fun, x, y, metadata=metadata)

Read the metadata back when importing the function by passing
``return_metadata=True``:

.. code-block:: python

imported_fun, metadata = mx.import_function("add.mlxfn", return_metadata=True)

# Prints: {'description': 'adds two arrays', 'version': '1.0'}
print(metadata)


Exporting Modules
-----------------

Expand Down
53 changes: 36 additions & 17 deletions mlx/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,15 @@ struct PrimitiveFactory {
};
};

void write_header(Writer& os, int count, bool shapeless) {
void write_header(
Writer& os,
int count,
bool shapeless,
const std::unordered_map<std::string, std::string>& metadata) {
serialize(os, std::string(version()));
serialize(os, count);
serialize(os, shapeless);
serialize(os, metadata);
}

// A struct to hold and retrieve the graphs that are exported / imported
Expand Down Expand Up @@ -673,14 +678,16 @@ FunctionTable::Function* FunctionTable::find(
FunctionExporter::FunctionExporter(
const std::string& file,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless)
bool shapeless,
std::unordered_map<std::string, std::string> metadata)
: os(file),
fun(std::move(fun)),
ftable(std::make_shared<FunctionTable>(shapeless)) {
ftable(std::make_shared<FunctionTable>(shapeless)),
metadata_(std::move(metadata)) {
if (!os.is_open()) {
throw std::runtime_error("[export_function] Failed to open " + file);
}
write_header(os, count, shapeless);
write_header(os, count, shapeless, metadata_);
}

FunctionExporter::FunctionExporter(
Expand Down Expand Up @@ -812,7 +819,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
// Update the header
auto pos = os.tell();
os.seek(0);
write_header(os, count, ftable->shapeless);
write_header(os, count, ftable->shapeless, metadata_);
os.seek(pos);
serialize(os, kwarg_keys);

Expand Down Expand Up @@ -891,53 +898,61 @@ void FunctionExporter::operator()(const Args& args, const Kwargs& kwargs) {
FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless /* = false */) {
bool shapeless /* = false */,
const std::unordered_map<std::string, std::string>& metadata /* = {} */) {
return FunctionExporter{
file,
[fun](const Args& args, const Kwargs&) { return fun(args); },
shapeless};
shapeless,
metadata};
}

FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless /* = false */) {
bool shapeless /* = false */,
const std::unordered_map<std::string, std::string>& metadata /* = {} */) {
return exporter(
file,
[fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },
shapeless);
shapeless,
metadata);
}

FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless /* = false */) {
return FunctionExporter{file, fun, shapeless};
bool shapeless /* = false */,
const std::unordered_map<std::string, std::string>& metadata /* = {} */) {
return FunctionExporter{file, fun, shapeless, metadata};
}

void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless /* = false */) {
exporter(file, fun, shapeless)(args);
bool shapeless /* = false */,
const std::unordered_map<std::string, std::string>& metadata /* = {} */) {
exporter(file, fun, shapeless, metadata)(args);
}

void export_function(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless /* = false */) {
exporter(file, fun, shapeless)(kwargs);
bool shapeless /* = false */,
const std::unordered_map<std::string, std::string>& metadata /* = {} */) {
exporter(file, fun, shapeless, metadata)(kwargs);
}

void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless /* = false */) {
exporter(file, fun, shapeless)(args, kwargs);
bool shapeless /* = false */,
const std::unordered_map<std::string, std::string>& metadata /* = {} */) {
exporter(file, fun, shapeless, metadata)(args, kwargs);
}

FunctionExporter exporter(
Expand Down Expand Up @@ -1047,6 +1062,10 @@ ImportedFunction::ImportedFunction(const std::string& file)
auto mlx_version = deserialize<std::string>(is);
auto function_count = deserialize<int>(is);
ftable->shapeless = deserialize<bool>(is);
auto metadata_pairs =
deserialize<std::vector<std::pair<std::string, std::string>>>(is);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would happen when importing functions exported from previous versions without metadata support? (There is no need to be compatible but I'm curious about the behavior)

metadata_ = std::unordered_map<std::string, std::string>(
metadata_pairs.begin(), metadata_pairs.end());
std::unordered_map<std::uintptr_t, array> constants;

auto import_one = [&]() {
Expand Down
18 changes: 12 additions & 6 deletions mlx/export.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,20 @@ struct FunctionExporter;
MLX_API FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless = false);
bool shapeless = false,
const std::unordered_map<std::string, std::string>& metadata = {});

MLX_API FunctionExporter exporter(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless = false);
bool shapeless = false,
const std::unordered_map<std::string, std::string>& metadata = {});

MLX_API FunctionExporter exporter(
const std::string& path,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless = false);
bool shapeless = false,
const std::unordered_map<std::string, std::string>& metadata = {});

/**
* Export a function to a file.
Expand All @@ -69,20 +72,23 @@ MLX_API void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless = false);
bool shapeless = false,
const std::unordered_map<std::string, std::string>& metadata = {});

MLX_API void export_function(
const std::string& file,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless = false);
bool shapeless = false,
const std::unordered_map<std::string, std::string>& metadata = {});

MLX_API void export_function(
const std::string& file,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless = false);
bool shapeless = false,
const std::unordered_map<std::string, std::string>& metadata = {});

struct ImportedFunction;

Expand Down
19 changes: 15 additions & 4 deletions mlx/export_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@ struct MLX_API FunctionExporter {
friend MLX_API FunctionExporter exporter(
const std::string&,
const std::function<std::vector<array>(const Args&)>&,
bool shapeless);
bool shapeless,
const std::unordered_map<std::string, std::string>& metadata);

friend MLX_API FunctionExporter exporter(
const std::string&,
const std::function<std::vector<array>(const Kwargs&)>&,
bool shapeless);
bool shapeless,
const std::unordered_map<std::string, std::string>& metadata);

friend MLX_API FunctionExporter exporter(
const std::string&,
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
bool shapeless);
bool shapeless,
const std::unordered_map<std::string, std::string>& metadata);

friend MLX_API FunctionExporter exporter(
const ExportCallback&,
Expand All @@ -57,7 +60,8 @@ struct MLX_API FunctionExporter {
FunctionExporter(
const std::string& file,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless);
bool shapeless,
std::unordered_map<std::string, std::string> metadata);

FunctionExporter(
const ExportCallback& callback,
Expand All @@ -77,6 +81,7 @@ struct MLX_API FunctionExporter {
int count{0};
bool closed{false};
std::shared_ptr<FunctionTable> ftable;
std::unordered_map<std::string, std::string> metadata_;
};

struct MLX_API ImportedFunction {
Expand All @@ -88,12 +93,18 @@ struct MLX_API ImportedFunction {
std::vector<array> operator()(const Kwargs& kwargs) const;
std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;

// The metadata stored alongside the function when it was exported.
const std::unordered_map<std::string, std::string>& metadata() const {
return metadata_;
}

private:
ImportedFunction(const std::string& file);
friend MLX_API ImportedFunction import_function(const std::string&);
ImportedFunction();

std::shared_ptr<FunctionTable> ftable;
std::unordered_map<std::string, std::string> metadata_;
};

} // namespace mlx::core
Loading
Loading