Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/python/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Operations
cummin
cumprod
cumsum
cumulative_prod
cumulative_sum
degrees
depends
dequantize
Expand Down
80 changes: 62 additions & 18 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3424,32 +3424,52 @@ void init_ops(nb::module_& m) {
std::optional<int> axis,
bool reverse,
bool inclusive,
std::optional<mx::Dtype> dtype,
bool include_initial,
mx::StreamOrDevice s) {
mx::array x = dtype ? mx::astype(a, *dtype, s) : a;
int ax;
if (axis) {
return mx::cumsum(a, *axis, reverse, inclusive, s);
ax = *axis;
} else {
return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
x = mx::reshape(x, {-1}, s);
ax = 0;
}
auto out = mx::cumsum(x, ax, reverse, inclusive, s);
if (include_initial) {
int a2 = ax < 0 ? ax + static_cast<int>(out.ndim()) : ax;
mx::Shape init_shape = out.shape();
init_shape[a2] = 1;
auto init = mx::zeros(init_shape, out.dtype(), s);
out = reverse
? mx::concatenate({out, init}, a2, s)
: mx::concatenate({init, out}, a2, s);
}
return out;
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"dtype"_a = nb::none(),
"include_initial"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative sum of the elements along the given axis.

Args:
a (array): Input array
axis (int, optional): Optional axis to compute the cumulative sum
over. If unspecified the cumulative sum of the flattened array is
returned.
reverse (bool): Perform the cumulative sum in reverse.
a (array): Input array.
axis (int, optional): Axis to compute over. If unspecified the
cumulative sum of the flattened array is returned.
reverse (bool): Perform the cumulative sum in reverse. Default: ``False``.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
element of the input. Default: ``True``.
dtype (Dtype, optional): Cast the input to this type before summing.
include_initial (bool): Prepend the identity element (0) so the
output has one extra element along ``axis``. Default: ``False``.

Returns:
array: The output array.
Expand All @@ -3460,32 +3480,52 @@ void init_ops(nb::module_& m) {
std::optional<int> axis,
bool reverse,
bool inclusive,
std::optional<mx::Dtype> dtype,
bool include_initial,
mx::StreamOrDevice s) {
mx::array x = dtype ? mx::astype(a, *dtype, s) : a;
int ax;
if (axis) {
return mx::cumprod(a, *axis, reverse, inclusive, s);
ax = *axis;
} else {
return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
x = mx::reshape(x, {-1}, s);
ax = 0;
}
auto out = mx::cumprod(x, ax, reverse, inclusive, s);
if (include_initial) {
int a2 = ax < 0 ? ax + static_cast<int>(out.ndim()) : ax;
mx::Shape init_shape = out.shape();
init_shape[a2] = 1;
auto init = mx::ones(init_shape, out.dtype(), s);
out = reverse
? mx::concatenate({out, init}, a2, s)
: mx::concatenate({init, out}, a2, s);
}
return out;
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"dtype"_a = nb::none(),
"include_initial"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative product of the elements along the given axis.

Args:
a (array): Input array
axis (int, optional): Optional axis to compute the cumulative product
over. If unspecified the cumulative product of the flattened array is
returned.
reverse (bool): Perform the cumulative product in reverse.
a (array): Input array.
axis (int, optional): Axis to compute over. If unspecified the
cumulative product of the flattened array is returned.
reverse (bool): Perform the cumulative product in reverse. Default: ``False``.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
element of the input. Default: ``True``.
dtype (Dtype, optional): Cast the input to this type before multiplying.
include_initial (bool): Prepend the identity element (1) so the
output has one extra element along ``axis``. Default: ``False``.

Returns:
array: The output array.
Expand Down Expand Up @@ -5874,4 +5914,8 @@ void init_ops(nb::module_& m) {
m.attr("empty_like") = m.attr("zeros_like");
m.attr("matrix_transpose") = m.attr("transpose");
m.attr("pow") = m.attr("power");
// Array API aliases — cumulative_sum/cumulative_prod are pure aliases of
// cumsum/cumprod, which now support dtype and include_initial.
m.attr("cumulative_sum") = m.attr("cumsum");
m.attr("cumulative_prod") = m.attr("cumprod");
}
22 changes: 22 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3446,6 +3446,28 @@ def test_to_from_fp8(self):
self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals))
self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals))

def test_cumulative_sum_prod(self):
a = mx.array([1, 2, 3, 4])
self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10])
self.assertEqual(
mx.cumulative_sum(a, include_initial=True).tolist(), [0, 1, 3, 6, 10]
)
self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24])
self.assertEqual(
mx.cumulative_prod(a, include_initial=True).tolist(), [1, 1, 2, 6, 24]
)

m = mx.array([[1, 2], [3, 4]])
self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]])
self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]])
self.assertEqual(
mx.cumulative_sum(m, axis=1, include_initial=True).tolist(),
[[0, 1, 3], [0, 3, 7]],
)
# axis=None flattens.
self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10])
self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32)


if __name__ == "__main__":
mlx_tests.MLXTestRunner()