From 5a07f426568759a4f63029e9a7871f210f7434df Mon Sep 17 00:00:00 2001 From: katlun-lgtm Date: Sun, 14 Jun 2026 13:31:40 -0400 Subject: [PATCH] feat: add more array API functions Adds array API functions toward #3484, all built on existing primitives (no core changes): - Elementwise / utility: positive, logical_xor, trunc, count_nonzero, diff - Creation: full_like, empty, empty_like (empty / empty_like return zeros since MLX does not expose uninitialized memory) - Free functions: astype, matrix_transpose, cumulative_sum, cumulative_prod - Inspection: __array_namespace_info__ (capabilities, default_device, default_dtypes, devices, dtypes) Adds them to the ops docs and tests in test_ops.py / test_array.py. --- docs/src/python/ops.rst | 12 ++ python/src/array.cpp | 128 ++++++++++++ python/src/ops.cpp | 391 +++++++++++++++++++++++++++++++++++++ python/tests/test_array.py | 34 ++++ python/tests/test_ops.py | 93 +++++++++ 5 files changed, 658 insertions(+) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index de40c548a2..a7b40c948c 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -29,6 +29,7 @@ Operations array_equal asarray as_strided + astype atleast_1d atleast_2d atleast_3d @@ -59,19 +60,25 @@ Operations conv_general cos cosh + count_nonzero cummax cummin cumprod cumsum + cumulative_prod + cumulative_sum degrees depends dequantize diag diagonal + diff divide divmod einsum einsum_path + empty + empty_like equal erf erfinv @@ -83,6 +90,7 @@ Operations floor floor_divide full + full_like from_fp8 gather_mm gather_qmm @@ -116,8 +124,10 @@ Operations logical_not logical_and logical_or + logical_xor logsumexp matmul + matrix_transpose max maximum mean @@ -136,6 +146,7 @@ Operations partition pad permute_dims + positive power prod put_along_axis @@ -189,6 +200,7 @@ Operations tri tril triu + trunc unflatten var view diff --git a/python/src/array.cpp b/python/src/array.cpp index 28c12f622c..2bce77458e 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -97,6 +97,53 @@ class ArrayPythonIterator { std::vector splits_; }; +// Returned by ``__array_namespace_info__()``; exposes array API inspection +// information about the namespace. +struct ArrayNamespaceInfo {}; + +namespace { + +// Whether ``dtype`` matches an array API dtype "kind" (a dtype, a kind string, +// or a tuple of those). A null/None kind matches everything. +bool dtype_matches_kind(const mx::Dtype& dtype, const nb::handle& kind) { + if (kind.is_none()) { + return true; + } + if (nb::isinstance(kind)) { + for (auto k : nb::cast(kind)) { + if (dtype_matches_kind(dtype, k)) { + return true; + } + } + return false; + } + if (nb::isinstance(kind)) { + return dtype == nb::cast(kind); + } + auto s = nb::cast(kind); + if (s == "bool") { + return dtype == mx::bool_; + } else if (s == "signed integer") { + return mx::issubdtype(dtype, mx::signedinteger); + } else if (s == "unsigned integer") { + return mx::issubdtype(dtype, mx::unsignedinteger); + } else if (s == "integral") { + return mx::issubdtype(dtype, mx::integer); + } else if (s == "real floating") { + return mx::issubdtype(dtype, mx::floating); + } else if (s == "complex floating") { + return mx::issubdtype(dtype, mx::complexfloating); + } else if (s == "numeric") { + return mx::issubdtype(dtype, mx::number); + } + std::ostringstream msg; + msg << "[__array_namespace_info__.dtypes] Unknown data type kind: '" << s + << "'."; + throw std::invalid_argument(msg.str()); +} + +} // namespace + void init_array(nb::module_& m) { // Types nb::class_( @@ -248,6 +295,87 @@ void init_array(nb::module_& m) { return os.str(); }); + nb::class_( + m, + "__array_namespace_info__", + R"pbdoc( + Array API namespace inspection utilities. + + Returned by ``array.__array_namespace__().__array_namespace_info__()``. + See the `array API `_ for + details. + )pbdoc") + .def(nb::init<>()) + .def( + "capabilities", + [](const ArrayNamespaceInfo&) { + nb::dict d; + d["boolean indexing"] = true; + d["data-dependent shapes"] = false; + d["max dimensions"] = nb::none(); + return d; + }, + R"pbdoc(The capabilities of the namespace.)pbdoc") + .def( + "default_device", + [](const ArrayNamespaceInfo&) { return mx::default_device(); }, + R"pbdoc(The default device.)pbdoc") + .def( + "default_dtypes", + [](const ArrayNamespaceInfo&, const nb::object&) { + nb::dict d; + d["real floating"] = nb::cast(mx::float32); + d["complex floating"] = nb::cast(mx::complex64); + d["integral"] = nb::cast(mx::int32); + d["indexing"] = nb::cast(mx::int32); + return d; + }, + "device"_a = nb::none(), + R"pbdoc(The default data types of the namespace.)pbdoc") + .def( + "devices", + [](const ArrayNamespaceInfo&) { + nb::list l; + l.append(mx::Device(mx::Device::cpu)); + if (mx::is_available(mx::Device(mx::Device::gpu))) { + l.append(mx::Device(mx::Device::gpu)); + } + return l; + }, + R"pbdoc(The devices supported by the namespace.)pbdoc") + .def( + "dtypes", + [](const ArrayNamespaceInfo&, + const nb::object&, + const nb::object& kind) { + const std::pair all[] = { + {"bool", mx::bool_}, + {"int8", mx::int8}, + {"int16", mx::int16}, + {"int32", mx::int32}, + {"int64", mx::int64}, + {"uint8", mx::uint8}, + {"uint16", mx::uint16}, + {"uint32", mx::uint32}, + {"uint64", mx::uint64}, + {"float16", mx::float16}, + {"bfloat16", mx::bfloat16}, + {"float32", mx::float32}, + {"float64", mx::float64}, + {"complex64", mx::complex64}, + }; + nb::dict d; + for (const auto& [name, dtype] : all) { + if (dtype_matches_kind(dtype, kind)) { + d[name] = nb::cast(dtype); + } + } + return d; + }, + "device"_a = nb::none(), + "kind"_a = nb::none(), + R"pbdoc(The data types supported by the namespace.)pbdoc"); + nb::class_( m, "ArrayAt", diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 198a4861d9..02cc5ddc5e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5660,4 +5660,395 @@ void init_ops(nb::module_& m) { Returns: array: The array converted to fp8 with type ``uint8``. )pbdoc"); + // Array API elementwise and utility functions. + m.def( + "positive", + [](const mx::array& a, mx::StreamOrDevice s) { + return mx::astype(a, a.dtype(), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def positive(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise unary plus. Returns the input unchanged. + + Args: + a (array): Input array. + + Returns: + array: ``a`` unchanged. + )pbdoc"); + m.def( + "logical_xor", + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return mx::not_equal( + mx::astype(a, mx::bool_, s), mx::astype(b, mx::bool_, s), s); + }, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def logical_xor(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise logical exclusive or. + + Args: + a (array): First input array or scalar. + b (array): Second input array or scalar. + + Returns: + array: The boolean array containing the logical xor of ``a`` and ``b``. + )pbdoc"); + m.def( + "trunc", + [](const ScalarOrArray& a_, mx::StreamOrDevice s) { + auto a = to_array(a_); + auto zero = mx::array(0, a.dtype()); + return mx::where( + mx::less(a, zero, s), mx::ceil(a, s), mx::floor(a, s), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def trunc(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise truncation towards zero. + + Args: + a (array): Input array. + + Returns: + array: The truncated array. + )pbdoc"); + m.def( + "count_nonzero", + [](const mx::array& a, + const IntOrVec& axis, + bool keepdims, + mx::StreamOrDevice s) { + auto nz = mx::astype( + mx::not_equal(a, mx::array(0, a.dtype()), s), mx::int32, s); + if (std::holds_alternative(axis)) { + return mx::sum(nz, keepdims, s); + } else if (auto pv = std::get_if(&axis); pv) { + return mx::sum(nz, *pv, keepdims, s); + } else { + return mx::sum(nz, std::get>(axis), keepdims, s); + } + }, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "keepdims"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def count_nonzero(a: array, /, *, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Count the number of non-zero elements along the given axis. + + Args: + a (array): Input array. + axis (int or tuple(int), optional): Axis or axes to count over. + Defaults to ``None`` in which case the whole array is counted. + keepdims (bool, optional): Keep the reduced axes as size one. + Default: ``False``. + + Returns: + array: The counts as an ``int32`` array. + )pbdoc"); + m.def( + "diff", + [](const mx::array& a, + int n, + int axis, + const std::optional& prepend, + const std::optional& append, + mx::StreamOrDevice s) { + int ndim = static_cast(a.ndim()); + int ax = axis < 0 ? axis + ndim : axis; + if (ax < 0 || ax >= ndim) { + throw std::invalid_argument( + "[diff] Axis is out of bounds for the array."); + } + if (n < 0) { + throw std::invalid_argument("[diff] Order `n` must be non-negative."); + } + mx::array x = a; + if (prepend || append) { + std::vector parts; + if (prepend) { + parts.push_back(*prepend); + } + parts.push_back(x); + if (append) { + parts.push_back(*append); + } + x = mx::concatenate(parts, ax, s); + } + for (int i = 0; i < n; ++i) { + mx::Shape upper_start(x.ndim(), 0); + mx::Shape lower_stop = x.shape(); + mx::Shape strides(x.ndim(), 1); + upper_start[ax] = 1; + lower_stop[ax] = x.shape(ax) - 1; + auto upper = mx::slice(x, upper_start, x.shape(), strides, s); + auto lower = + mx::slice(x, mx::Shape(x.ndim(), 0), lower_stop, strides, s); + x = mx::subtract(upper, lower, s); + } + return x; + }, + nb::arg(), + "n"_a = 1, + "axis"_a = -1, + nb::kw_only(), + "prepend"_a = nb::none(), + "append"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def diff(a: array, /, n: int = 1, axis: int = -1, *, prepend: Optional[array] = None, append: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + The n-th discrete difference along the given axis. + + Args: + a (array): Input array. + n (int, optional): The number of times to difference. Default: ``1``. + axis (int, optional): The axis along which to difference. + Default: ``-1``. + prepend (array, optional): Values to prepend along ``axis`` before + differencing. + append (array, optional): Values to append along ``axis`` before + differencing. + + Returns: + array: The n-th differences. + )pbdoc"); + // Array API creation functions. + m.def( + "full_like", + [](const mx::array& a, + const ScalarOrArray& vals, + std::optional dtype, + mx::StreamOrDevice s) { + auto t = dtype.value_or(a.dtype()); + return mx::full(a.shape(), to_array(vals, t), s); + }, + nb::arg(), + "vals"_a, + "dtype"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def full_like(a: array, vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + An array filled with ``vals`` with the same shape as the input. + + Args: + a (array): The input to take the shape from. + vals (float or int or array): Values to fill the array with. + dtype (Dtype, optional): Data type of the output array. If + unspecified the type of the input is used. + + Returns: + array: The output array. + )pbdoc"); + m.def( + "empty", + [](const nb::object& shape, + std::optional dtype, + mx::StreamOrDevice s) { + return mx::zeros(to_shape(shape), dtype.value_or(mx::float32), s); + }, + "shape"_a, + "dtype"_a.none() = mx::float32, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def empty(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Construct an uninitialized array. + + MLX does not expose uninitialized memory, so the contents are zero. + + Args: + shape (int or list(int)): The shape of the output array. + dtype (Dtype, optional): Data type of the output array. Default: + ``float32``. + + Returns: + array: The output array. + )pbdoc"); + m.def( + "empty_like", + [](const mx::array& a, + std::optional dtype, + mx::StreamOrDevice s) { + return mx::zeros(a.shape(), dtype.value_or(a.dtype()), s); + }, + nb::arg(), + "dtype"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def empty_like(a: array, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + An uninitialized array with the same shape as the input. + + MLX does not expose uninitialized memory, so the contents are zero. + + Args: + a (array): The input to take the shape and type from. + dtype (Dtype, optional): Data type of the output array. If + unspecified the type of the input is used. + + Returns: + array: The output array. + )pbdoc"); + // Array API free-function wrappers. + m.def( + "astype", + [](const mx::array& a, mx::Dtype dtype, mx::StreamOrDevice s) { + return mx::astype(a, dtype, s); + }, + nb::arg(), + "dtype"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def astype(a: array, dtype: Dtype, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Cast the array to the given type. See also :meth:`array.astype`. + + Args: + a (array): Input array. + dtype (Dtype): The type to cast to. + + Returns: + array: The array cast to ``dtype``. + )pbdoc"); + m.def( + "matrix_transpose", + [](const mx::array& a, mx::StreamOrDevice s) { + if (a.ndim() < 2) { + throw std::invalid_argument( + "[matrix_transpose] Input must have at least 2 dimensions."); + } + return mx::swapaxes(a, -2, -1, s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def matrix_transpose(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Transpose the last two dimensions of an array. + + Args: + a (array): Input array with at least two dimensions. + + Returns: + array: The array with its last two dimensions transposed. + )pbdoc"); + m.def( + "cumulative_sum", + [](const mx::array& a, + std::optional axis, + std::optional dtype, + bool include_initial, + mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; + if (axis) { + ax = *axis; + } else { + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumsum(x, ax, false, true, s); + if (include_initial) { + int a2 = ax < 0 ? ax + out.ndim() : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + out = mx::concatenate( + {mx::zeros(init_shape, out.dtype(), s), out}, a2, s); + } + return out; + }, + nb::arg(), + nb::kw_only(), + "axis"_a = nb::none(), + "dtype"_a = nb::none(), + "include_initial"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def cumulative_sum(a: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + The cumulative sum of the elements along the given axis. + + Args: + a (array): Input array. + axis (int, optional): The axis to sum over. If unspecified the + cumulative sum of the flattened array is returned. + dtype (Dtype, optional): Cast the input to this type before summing. + include_initial (bool, optional): Prepend the identity (0) so the + output has one extra element along ``axis``. Default: ``False``. + + Returns: + array: The cumulative sum. + )pbdoc"); + m.def( + "cumulative_prod", + [](const mx::array& a, + std::optional axis, + std::optional dtype, + bool include_initial, + mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; + if (axis) { + ax = *axis; + } else { + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumprod(x, ax, false, true, s); + if (include_initial) { + int a2 = ax < 0 ? ax + out.ndim() : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + out = mx::concatenate( + {mx::ones(init_shape, out.dtype(), s), out}, a2, s); + } + return out; + }, + nb::arg(), + nb::kw_only(), + "axis"_a = nb::none(), + "dtype"_a = nb::none(), + "include_initial"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def cumulative_prod(a: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + The cumulative product of the elements along the given axis. + + Args: + a (array): Input array. + axis (int, optional): The axis to take the product over. If + unspecified the cumulative product of the flattened array is + returned. + dtype (Dtype, optional): Cast the input to this type first. + include_initial (bool, optional): Prepend the identity (1) so the + output has one extra element along ``axis``. Default: ``False``. + + Returns: + array: The cumulative product. + )pbdoc"); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 7905cd36c1..f9f4457249 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2159,6 +2159,40 @@ def test_array_namespace(self): self.assertTrue(hasattr(api, "array")) self.assertTrue(hasattr(api, "add")) + def test_array_namespace_info(self): + xp = mx.array(1.0).__array_namespace__() + info = xp.__array_namespace_info__() + + caps = info.capabilities() + self.assertIn("boolean indexing", caps) + self.assertIn("data-dependent shapes", caps) + self.assertIn("max dimensions", caps) + + self.assertEqual(info.default_device(), mx.default_device()) + + dd = info.default_dtypes() + self.assertEqual(dd["real floating"], mx.float32) + self.assertEqual(dd["complex floating"], mx.complex64) + self.assertEqual(dd["integral"], mx.int32) + self.assertEqual(dd["indexing"], mx.int32) + + devices = info.devices() + self.assertGreaterEqual(len(devices), 1) + self.assertIn(mx.default_device(), devices) + + all_dtypes = info.dtypes() + self.assertEqual(all_dtypes["float32"], mx.float32) + self.assertEqual(all_dtypes["bool"], mx.bool_) + + floats = info.dtypes(kind="real floating") + self.assertIn("float32", floats) + self.assertNotIn("int32", floats) + + ints = info.dtypes(kind=("signed integer", "unsigned integer")) + self.assertIn("int8", ints) + self.assertIn("uint8", ints) + self.assertNotIn("float32", ints) + def test_array_namespace_asarray(self): xp = mx.array(1.0).__array_namespace__() self.assertTrue(hasattr(xp, "asarray")) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 484ae47c3a..233d8ed6ce 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3407,6 +3407,99 @@ def test_to_from_fp8(self): self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals)) self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals)) + def test_array_api_elementwise(self): + a = mx.array([-1.5, -0.5, 0.0, 0.5, 2.7]) + self.assertEqual(mx.positive(a).tolist(), a.tolist()) + self.assertEqual(mx.trunc(a).tolist(), [-1.0, 0.0, 0.0, 0.0, 2.0]) + + x = mx.array([True, True, False, False]) + y = mx.array([True, False, True, False]) + self.assertEqual(mx.logical_xor(x, y).tolist(), [False, True, True, False]) + + c = mx.array([[0, 1, 0], [2, 3, 0]]) + self.assertEqual(mx.count_nonzero(c).item(), 3) + self.assertEqual(mx.count_nonzero(c, axis=0).tolist(), [1, 2, 0]) + self.assertEqual(mx.count_nonzero(c, axis=1).tolist(), [1, 2]) + self.assertEqual(mx.count_nonzero(c).dtype, mx.int32) + + def test_diff(self): + a = mx.array([1, 2, 4, 7, 0]) + self.assertEqual(mx.diff(a).tolist(), [1, 2, 3, -7]) + self.assertEqual(mx.diff(a, n=2).tolist(), [1, 1, -10]) + self.assertEqual(mx.diff(a, n=0).tolist(), a.tolist()) + + m = mx.array([[1, 3, 6], [0, 5, 6]]) + self.assertEqual(mx.diff(m, axis=0).tolist(), [[-1, 2, 0]]) + self.assertEqual(mx.diff(m, axis=1).tolist(), [[2, 3], [5, 1]]) + + # prepend / append. + self.assertEqual( + mx.diff(mx.array([2, 4, 7]), prepend=mx.array([0])).tolist(), + [2, 2, 3], + ) + self.assertEqual( + mx.diff(mx.array([2, 4, 7]), append=mx.array([10])).tolist(), + [2, 3, 3], + ) + + with self.assertRaises(ValueError): + mx.diff(a, axis=1) + + def test_array_api_creation(self): + a = mx.arange(6, dtype=mx.int16).reshape(2, 3) + + fl = mx.full_like(a, 7) + self.assertEqual(fl.shape, (2, 3)) + self.assertEqual(fl.dtype, mx.int16) + self.assertTrue(mx.all(fl == 7).item()) + self.assertEqual(mx.full_like(a, 1.5, dtype=mx.float32).dtype, mx.float32) + + e = mx.empty((2, 3)) + self.assertEqual(e.shape, (2, 3)) + self.assertEqual(e.dtype, mx.float32) + self.assertEqual(mx.empty((4,), dtype=mx.int32).dtype, mx.int32) + + el = mx.empty_like(a) + self.assertEqual(el.shape, (2, 3)) + self.assertEqual(el.dtype, mx.int16) + self.assertEqual(mx.empty_like(a, dtype=mx.float32).dtype, mx.float32) + + def test_astype_and_matrix_transpose(self): + a = mx.array([1, 2, 3], dtype=mx.int32) + self.assertEqual(mx.astype(a, mx.float32).dtype, mx.float32) + self.assertTrue(mx.array_equal(mx.astype(a, mx.float32), a.astype(mx.float32))) + + m = mx.arange(6).reshape(2, 3) + self.assertEqual(mx.matrix_transpose(m).shape, (3, 2)) + self.assertTrue(mx.array_equal(mx.matrix_transpose(m), mx.swapaxes(m, -2, -1))) + # Batched. + b = mx.arange(24).reshape(2, 3, 4) + self.assertEqual(mx.matrix_transpose(b).shape, (2, 4, 3)) + with self.assertRaises(ValueError): + mx.matrix_transpose(mx.array([1, 2, 3])) + + def test_cumulative_sum_prod(self): + a = mx.array([1, 2, 3, 4]) + self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10]) + self.assertEqual( + mx.cumulative_sum(a, include_initial=True).tolist(), [0, 1, 3, 6, 10] + ) + self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24]) + self.assertEqual( + mx.cumulative_prod(a, include_initial=True).tolist(), [1, 1, 2, 6, 24] + ) + + m = mx.array([[1, 2], [3, 4]]) + self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]]) + self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]]) + self.assertEqual( + mx.cumulative_sum(m, axis=1, include_initial=True).tolist(), + [[0, 1, 3], [0, 3, 7]], + ) + # axis=None flattens. + self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10]) + self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32) + if __name__ == "__main__": mlx_tests.MLXTestRunner()