diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 30ee8e383d..76adf8afc8 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -29,6 +29,7 @@ Operations array_equal asarray as_strided + astype atleast_1d atleast_2d atleast_3d diff --git a/python/src/ops.cpp b/python/src/ops.cpp index b3d010dda3..4bd8a24446 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3399,6 +3399,25 @@ void init_ops(nb::module_& m) { Returns: array: The output array which is the strided view of the input. )pbdoc"); + m.def( + "astype", + &mx::astype, + nb::arg(), + "dtype"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def astype(a: array, dtype: Dtype, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Cast the array to a specified type. + + Args: + a (array): Input array. + dtype (Dtype): Type to which the array is cast. + + Returns: + array: The array with type ``dtype``. + )pbdoc"); m.def( "cumsum", [](const mx::array& a, @@ -5849,7 +5868,10 @@ void init_ops(nb::module_& m) { m.attr("atan") = m.attr("arctan"); m.attr("atanh") = m.attr("arctanh"); m.attr("atan2") = m.attr("arctan2"); - m.attr("pow") = m.attr("power"); m.attr("bitwise_left_shift") = m.attr("left_shift"); m.attr("bitwise_right_shift") = m.attr("right_shift"); + m.attr("empty") = m.attr("zeros"); + m.attr("empty_like") = m.attr("zeros_like"); + m.attr("matrix_transpose") = m.attr("transpose"); + m.attr("pow") = m.attr("power"); }