diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 76adf8afc8..4217fbe392 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -65,6 +65,8 @@ Operations cummin cumprod cumsum + cumulative_prod + cumulative_sum degrees depends dequantize diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 4bd8a24446..ae443b73ed 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3424,32 +3424,52 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, + std::optional dtype, + bool include_initial, mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; if (axis) { - return mx::cumsum(a, *axis, reverse, inclusive, s); + ax = *axis; } else { - return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumsum(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::zeros(init_shape, out.dtype(), s); + out = reverse + ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); } + return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), + "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative sum of the elements along the given axis. Args: - a (array): Input array - axis (int, optional): Optional axis to compute the cumulative sum - over. If unspecified the cumulative sum of the flattened array is - returned. - reverse (bool): Perform the cumulative sum in reverse. + a (array): Input array. + axis (int, optional): Axis to compute over. If unspecified the + cumulative sum of the flattened array is returned. + reverse (bool): Perform the cumulative sum in reverse. Default: ``False``. inclusive (bool): The i-th element of the output includes the i-th - element of the input. + element of the input. Default: ``True``. + dtype (Dtype, optional): Cast the input to this type before summing. + include_initial (bool): Prepend the identity element (0) so the + output has one extra element along ``axis``. Default: ``False``. Returns: array: The output array. @@ -3460,32 +3480,52 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, + std::optional dtype, + bool include_initial, mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; if (axis) { - return mx::cumprod(a, *axis, reverse, inclusive, s); + ax = *axis; } else { - return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumprod(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::ones(init_shape, out.dtype(), s); + out = reverse + ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); } + return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), + "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative product of the elements along the given axis. Args: - a (array): Input array - axis (int, optional): Optional axis to compute the cumulative product - over. If unspecified the cumulative product of the flattened array is - returned. - reverse (bool): Perform the cumulative product in reverse. + a (array): Input array. + axis (int, optional): Axis to compute over. If unspecified the + cumulative product of the flattened array is returned. + reverse (bool): Perform the cumulative product in reverse. Default: ``False``. inclusive (bool): The i-th element of the output includes the i-th - element of the input. + element of the input. Default: ``True``. + dtype (Dtype, optional): Cast the input to this type before multiplying. + include_initial (bool): Prepend the identity element (1) so the + output has one extra element along ``axis``. Default: ``False``. Returns: array: The output array. @@ -5874,4 +5914,8 @@ void init_ops(nb::module_& m) { m.attr("empty_like") = m.attr("zeros_like"); m.attr("matrix_transpose") = m.attr("transpose"); m.attr("pow") = m.attr("power"); + // Array API aliases — cumulative_sum/cumulative_prod are pure aliases of + // cumsum/cumprod, which now support dtype and include_initial. + m.attr("cumulative_sum") = m.attr("cumsum"); + m.attr("cumulative_prod") = m.attr("cumprod"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 660ea76c8d..3dcc1efce1 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3446,6 +3446,28 @@ def test_to_from_fp8(self): self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals)) self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals)) + def test_cumulative_sum_prod(self): + a = mx.array([1, 2, 3, 4]) + self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10]) + self.assertEqual( + mx.cumulative_sum(a, include_initial=True).tolist(), [0, 1, 3, 6, 10] + ) + self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24]) + self.assertEqual( + mx.cumulative_prod(a, include_initial=True).tolist(), [1, 1, 2, 6, 24] + ) + + m = mx.array([[1, 2], [3, 4]]) + self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]]) + self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]]) + self.assertEqual( + mx.cumulative_sum(m, axis=1, include_initial=True).tolist(), + [[0, 1, 3], [0, 3, 7]], + ) + # axis=None flattens. + self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10]) + self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32) + if __name__ == "__main__": mlx_tests.MLXTestRunner()