From 2a554bc9ffb74f9bae0d5fa9acac3db96dd63e64 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Fri, 19 Jun 2026 20:24:16 -0400 Subject: [PATCH 1/2] feat(array-api): add empty, empty_like, astype, matrix_transpose MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit empty/empty_like are pure aliases of zeros/zeros_like via m.attr("empty") = m.attr("zeros"), matching the pattern from #3678. MLX does not expose uninitialized memory so zeros is the correct semantic match. astype exposes mx::astype as a free function (Array API §2.0). matrix_transpose transposes the last two dimensions and validates ndim >= 2. Docs and tests included. Part of the array API split from #3684. --- docs/src/python/ops.rst | 4 ++++ python/src/ops.cpp | 50 ++++++++++++++++++++++++++++++++++++++++ python/tests/test_ops.py | 27 ++++++++++++++++++++++ 3 files changed, 81 insertions(+) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 30ee8e383d..5db1e0e89f 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -29,6 +29,7 @@ Operations array_equal asarray as_strided + astype atleast_1d atleast_2d atleast_3d @@ -73,6 +74,8 @@ Operations divmod einsum einsum_path + empty + empty_like equal erf erfinv @@ -121,6 +124,7 @@ Operations logical_or logsumexp matmul + matrix_transpose max maximum mean diff --git a/python/src/ops.cpp b/python/src/ops.cpp index b3d010dda3..007f026434 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5852,4 +5852,54 @@ void init_ops(nb::module_& m) { m.attr("pow") = m.attr("power"); m.attr("bitwise_left_shift") = m.attr("left_shift"); m.attr("bitwise_right_shift") = m.attr("right_shift"); + // Array API: empty / empty_like — pure aliases of zeros / zeros_like. + // MLX does not expose uninitialized memory, so zeros are a correct + // semantic match. + m.attr("empty") = m.attr("zeros"); + m.attr("empty_like") = m.attr("zeros_like"); + // Array API free-function wrappers. + m.def( + "astype", + [](const mx::array& a, mx::Dtype dtype, mx::StreamOrDevice s) { + return mx::astype(a, dtype, s); + }, + nb::arg(), + "dtype"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def astype(a: array, dtype: Dtype, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Cast the array to the given type. See also :meth:`array.astype`. + + Args: + a (array): Input array. + dtype (Dtype): The type to cast to. + + Returns: + array: The array cast to ``dtype``. + )pbdoc"); + m.def( + "matrix_transpose", + [](const mx::array& a, mx::StreamOrDevice s) { + if (a.ndim() < 2) { + throw std::invalid_argument( + "[matrix_transpose] Input must have at least 2 dimensions."); + } + return mx::swapaxes(a, -2, -1, s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def matrix_transpose(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Transpose the last two dimensions of an array. + + Args: + a (array): Input array with at least two dimensions. + + Returns: + array: The array with its last two dimensions transposed. + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 660ea76c8d..ca9749b135 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3446,6 +3446,33 @@ def test_to_from_fp8(self): self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals)) self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals)) + def test_array_api_creation(self): + a = mx.arange(6, dtype=mx.int16).reshape(2, 3) + + e = mx.empty((2, 3)) + self.assertEqual(e.shape, (2, 3)) + self.assertEqual(e.dtype, mx.float32) + self.assertEqual(mx.empty((4,), dtype=mx.int32).dtype, mx.int32) + + el = mx.empty_like(a) + self.assertEqual(el.shape, (2, 3)) + self.assertEqual(el.dtype, mx.int16) + self.assertEqual(mx.empty_like(a, dtype=mx.float32).dtype, mx.float32) + + def test_astype_and_matrix_transpose(self): + a = mx.array([1, 2, 3], dtype=mx.int32) + self.assertEqual(mx.astype(a, mx.float32).dtype, mx.float32) + self.assertTrue(mx.array_equal(mx.astype(a, mx.float32), a.astype(mx.float32))) + + m = mx.arange(6).reshape(2, 3) + self.assertEqual(mx.matrix_transpose(m).shape, (3, 2)) + self.assertTrue(mx.array_equal(mx.matrix_transpose(m), mx.swapaxes(m, -2, -1))) + # Batched. + b = mx.arange(24).reshape(2, 3, 4) + self.assertEqual(mx.matrix_transpose(b).shape, (2, 4, 3)) + with self.assertRaises(ValueError): + mx.matrix_transpose(mx.array([1, 2, 3])) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From b4c72abb325c142363850814469850200519a997 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 21 Jun 2026 18:02:04 +0900 Subject: [PATCH 2/2] remove garbage --- docs/src/python/ops.rst | 3 -- python/src/ops.cpp | 70 ++++++++++++---------------------------- python/tests/test_ops.py | 27 ---------------- 3 files changed, 21 insertions(+), 79 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 5db1e0e89f..76adf8afc8 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -74,8 +74,6 @@ Operations divmod einsum einsum_path - empty - empty_like equal erf erfinv @@ -124,7 +122,6 @@ Operations logical_or logsumexp matmul - matrix_transpose max maximum mean diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 007f026434..4bd8a24446 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3399,6 +3399,25 @@ void init_ops(nb::module_& m) { Returns: array: The output array which is the strided view of the input. )pbdoc"); + m.def( + "astype", + &mx::astype, + nb::arg(), + "dtype"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def astype(a: array, dtype: Dtype, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Cast the array to a specified type. + + Args: + a (array): Input array. + dtype (Dtype): Type to which the array is cast. + + Returns: + array: The array with type ``dtype``. + )pbdoc"); m.def( "cumsum", [](const mx::array& a, @@ -5849,57 +5868,10 @@ void init_ops(nb::module_& m) { m.attr("atan") = m.attr("arctan"); m.attr("atanh") = m.attr("arctanh"); m.attr("atan2") = m.attr("arctan2"); - m.attr("pow") = m.attr("power"); m.attr("bitwise_left_shift") = m.attr("left_shift"); m.attr("bitwise_right_shift") = m.attr("right_shift"); - // Array API: empty / empty_like — pure aliases of zeros / zeros_like. - // MLX does not expose uninitialized memory, so zeros are a correct - // semantic match. m.attr("empty") = m.attr("zeros"); m.attr("empty_like") = m.attr("zeros_like"); - // Array API free-function wrappers. - m.def( - "astype", - [](const mx::array& a, mx::Dtype dtype, mx::StreamOrDevice s) { - return mx::astype(a, dtype, s); - }, - nb::arg(), - "dtype"_a, - nb::kw_only(), - "stream"_a = nb::none(), - nb::sig( - "def astype(a: array, dtype: Dtype, /, *, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - Cast the array to the given type. See also :meth:`array.astype`. - - Args: - a (array): Input array. - dtype (Dtype): The type to cast to. - - Returns: - array: The array cast to ``dtype``. - )pbdoc"); - m.def( - "matrix_transpose", - [](const mx::array& a, mx::StreamOrDevice s) { - if (a.ndim() < 2) { - throw std::invalid_argument( - "[matrix_transpose] Input must have at least 2 dimensions."); - } - return mx::swapaxes(a, -2, -1, s); - }, - nb::arg(), - nb::kw_only(), - "stream"_a = nb::none(), - nb::sig( - "def matrix_transpose(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - Transpose the last two dimensions of an array. - - Args: - a (array): Input array with at least two dimensions. - - Returns: - array: The array with its last two dimensions transposed. - )pbdoc"); + m.attr("matrix_transpose") = m.attr("transpose"); + m.attr("pow") = m.attr("power"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index ca9749b135..660ea76c8d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3446,33 +3446,6 @@ def test_to_from_fp8(self): self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals)) self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals)) - def test_array_api_creation(self): - a = mx.arange(6, dtype=mx.int16).reshape(2, 3) - - e = mx.empty((2, 3)) - self.assertEqual(e.shape, (2, 3)) - self.assertEqual(e.dtype, mx.float32) - self.assertEqual(mx.empty((4,), dtype=mx.int32).dtype, mx.int32) - - el = mx.empty_like(a) - self.assertEqual(el.shape, (2, 3)) - self.assertEqual(el.dtype, mx.int16) - self.assertEqual(mx.empty_like(a, dtype=mx.float32).dtype, mx.float32) - - def test_astype_and_matrix_transpose(self): - a = mx.array([1, 2, 3], dtype=mx.int32) - self.assertEqual(mx.astype(a, mx.float32).dtype, mx.float32) - self.assertTrue(mx.array_equal(mx.astype(a, mx.float32), a.astype(mx.float32))) - - m = mx.arange(6).reshape(2, 3) - self.assertEqual(mx.matrix_transpose(m).shape, (3, 2)) - self.assertTrue(mx.array_equal(mx.matrix_transpose(m), mx.swapaxes(m, -2, -1))) - # Batched. - b = mx.arange(24).reshape(2, 3, 4) - self.assertEqual(mx.matrix_transpose(b).shape, (2, 4, 3)) - with self.assertRaises(ValueError): - mx.matrix_transpose(mx.array([1, 2, 3])) - if __name__ == "__main__": mlx_tests.MLXTestRunner()