Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/src/python/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Operations
array_equal
asarray
as_strided
astype
atleast_1d
atleast_2d
atleast_3d
Expand Down Expand Up @@ -59,19 +60,25 @@ Operations
conv_general
cos
cosh
count_nonzero
cummax
cummin
cumprod
cumsum
cumulative_prod
cumulative_sum
degrees
depends
dequantize
diag
diagonal
diff
divide
divmod
einsum
einsum_path
empty
empty_like
equal
erf
erfinv
Expand All @@ -83,6 +90,7 @@ Operations
floor
floor_divide
full
full_like
from_fp8
gather_mm
gather_qmm
Expand Down Expand Up @@ -116,8 +124,10 @@ Operations
logical_not
logical_and
logical_or
logical_xor
logsumexp
matmul
matrix_transpose
max
maximum
mean
Expand All @@ -136,6 +146,7 @@ Operations
partition
pad
permute_dims
positive
power
prod
put_along_axis
Expand Down Expand Up @@ -189,6 +200,7 @@ Operations
tri
tril
triu
trunc
unflatten
var
view
Expand Down
128 changes: 128 additions & 0 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,53 @@ class ArrayPythonIterator {
std::vector<mx::array> splits_;
};

// Returned by ``__array_namespace_info__()``; exposes array API inspection
// information about the namespace.
struct ArrayNamespaceInfo {};

namespace {

// Whether ``dtype`` matches an array API dtype "kind" (a dtype, a kind string,
// or a tuple of those). A null/None kind matches everything.
bool dtype_matches_kind(const mx::Dtype& dtype, const nb::handle& kind) {
if (kind.is_none()) {
return true;
}
if (nb::isinstance<nb::tuple>(kind)) {
for (auto k : nb::cast<nb::tuple>(kind)) {
if (dtype_matches_kind(dtype, k)) {
return true;
}
}
return false;
}
if (nb::isinstance<mx::Dtype>(kind)) {
return dtype == nb::cast<mx::Dtype>(kind);
}
auto s = nb::cast<std::string>(kind);
if (s == "bool") {
return dtype == mx::bool_;
} else if (s == "signed integer") {
return mx::issubdtype(dtype, mx::signedinteger);
} else if (s == "unsigned integer") {
return mx::issubdtype(dtype, mx::unsignedinteger);
} else if (s == "integral") {
return mx::issubdtype(dtype, mx::integer);
} else if (s == "real floating") {
return mx::issubdtype(dtype, mx::floating);
} else if (s == "complex floating") {
return mx::issubdtype(dtype, mx::complexfloating);
} else if (s == "numeric") {
return mx::issubdtype(dtype, mx::number);
}
std::ostringstream msg;
msg << "[__array_namespace_info__.dtypes] Unknown data type kind: '" << s
<< "'.";
throw std::invalid_argument(msg.str());
}

} // namespace

void init_array(nb::module_& m) {
// Types
nb::class_<mx::Dtype>(
Expand Down Expand Up @@ -248,6 +295,87 @@ void init_array(nb::module_& m) {
return os.str();
});

nb::class_<ArrayNamespaceInfo>(
m,
"__array_namespace_info__",
R"pbdoc(
Array API namespace inspection utilities.

Returned by ``array.__array_namespace__().__array_namespace_info__()``.
See the `array API <https://data-apis.org/array-api/latest/>`_ for
details.
)pbdoc")
.def(nb::init<>())
.def(
"capabilities",
[](const ArrayNamespaceInfo&) {
nb::dict d;
d["boolean indexing"] = true;
d["data-dependent shapes"] = false;
d["max dimensions"] = nb::none();
return d;
},
R"pbdoc(The capabilities of the namespace.)pbdoc")
.def(
"default_device",
[](const ArrayNamespaceInfo&) { return mx::default_device(); },
R"pbdoc(The default device.)pbdoc")
.def(
"default_dtypes",
[](const ArrayNamespaceInfo&, const nb::object&) {
nb::dict d;
d["real floating"] = nb::cast(mx::float32);
d["complex floating"] = nb::cast(mx::complex64);
d["integral"] = nb::cast(mx::int32);
d["indexing"] = nb::cast(mx::int32);
return d;
},
"device"_a = nb::none(),
R"pbdoc(The default data types of the namespace.)pbdoc")
.def(
"devices",
[](const ArrayNamespaceInfo&) {
nb::list l;
l.append(mx::Device(mx::Device::cpu));
if (mx::is_available(mx::Device(mx::Device::gpu))) {
l.append(mx::Device(mx::Device::gpu));
}
return l;
},
R"pbdoc(The devices supported by the namespace.)pbdoc")
.def(
"dtypes",
[](const ArrayNamespaceInfo&,
const nb::object&,
const nb::object& kind) {
const std::pair<const char*, mx::Dtype> all[] = {
{"bool", mx::bool_},
{"int8", mx::int8},
{"int16", mx::int16},
{"int32", mx::int32},
{"int64", mx::int64},
{"uint8", mx::uint8},
{"uint16", mx::uint16},
{"uint32", mx::uint32},
{"uint64", mx::uint64},
{"float16", mx::float16},
{"bfloat16", mx::bfloat16},
{"float32", mx::float32},
{"float64", mx::float64},
{"complex64", mx::complex64},
};
nb::dict d;
for (const auto& [name, dtype] : all) {
if (dtype_matches_kind(dtype, kind)) {
d[name] = nb::cast(dtype);
}
}
return d;
},
"device"_a = nb::none(),
"kind"_a = nb::none(),
R"pbdoc(The data types supported by the namespace.)pbdoc");

nb::class_<ArrayAt>(
m,
"ArrayAt",
Expand Down
Loading