diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp index f9ff22677a..86a1505e13 100644 --- a/mlx/backend/cpu/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -70,12 +70,16 @@ void copy_general_general( dynamic_i_offset ? dynamic_i_offset->data() : nullptr; auto o_offset_ptr = dynamic_o_offset ? dynamic_o_offset->data() : nullptr; - auto size = src.size(); if (data_shape.empty()) { auto val = static_cast(*src_ptr); *dst_ptr = val; return; } + auto size = std::accumulate( + data_shape.begin(), + data_shape.end(), + int64_t{1}, + std::multiplies()); auto [shape, strides] = collapse_contiguous_dims(data_shape, {i_strides, o_strides}); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 660ea76c8d..6cd0c49497 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3209,6 +3209,12 @@ def test_dynamic_slicing(self): out = mx.slice(x, mx.array([1, 2, 3]), (0, 1, 2), (3, 2, 1)) self.assertTrue(mx.array_equal(expected, out)) + with mx.stream(mx.cpu): + x = mx.arange(5 * 6 * 7 * 8).reshape(5, 6, 7, 8) + expected = x[1:3, 2:4, 3:5, 4:6] + out = mx.slice(x, mx.array([1, 2, 3, 4]), (0, 1, 2, 3), (2, 2, 2, 2)) + self.assertTrue(mx.array_equal(expected, out)) + x = mx.zeros(shape=(4, 4, 4)) update = mx.random.randint(0, 100, shape=(3, 2, 1)) out = mx.slice_update(x, update, mx.array([1, 2, 3]), (0, 1, 2)) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 81b10578cb..c2650e19a7 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -482,6 +482,29 @@ TEST_CASE("test dynamic slice") { out = slice(src, array({1, 1}), {0, 1}, {1, 2}); expected = array({4, 5}, {1, 2}); CHECK(array_equal(out, expected).item()); + + src = reshape(arange(5 * 6 * 7 * 8, Device::cpu), {5, 6, 7, 8}, Device::cpu); + out = + slice(src, array({1, 2, 3, 4}), {0, 1, 2, 3}, {2, 2, 2, 2}, Device::cpu); + expected = array( + {476, + 477, + 484, + 485, + 532, + 533, + 540, + 541, + 812, + 813, + 820, + 821, + 868, + 869, + 876, + 877}, + {2, 2, 2, 2}); + CHECK(array_equal(out, expected).item()); } TEST_CASE("test dynamic slice update") {