diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index ec4090172f..d668b56adb 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -66,43 +66,18 @@ void gather( array& out, const std::vector& axes, const Shape& slice_sizes) { - // If the array is row contiguous then we can do a contiguous copy given - // two conditions on the slice size: - // - Any number of leading ones in the slice sizes are allowed - // - All other slice sizes match the corresponding dimension except the - // first non-singleton slice size - // If the array is col contiguous then the reverse is the case: - // - Any number of trailing ones in the slice sizes are allowed - // - All other slice sizes match the corresponding dimension except the - // first non-singleton slice size from the end - - bool can_copy = false; - if (src.flags().row_contiguous) { - can_copy = true; - - // Ignore leading 1s - int i = 0; - for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i) - ; - - // Check the remaining - i++; - for (; i < src.ndim() && can_copy; ++i) { - can_copy = (src.shape(i) == slice_sizes[i]); - } - } else if (src.flags().col_contiguous) { - can_copy = true; - - // Ignore trailing 1s - int i = slice_sizes.size() - 1; - for (; i >= 0 && slice_sizes[i] == 1; --i) - ; - - // Skip the next slice size and check the remaining - i--; - for (; i >= 0 && can_copy; --i) { - can_copy = (src.shape(i) == slice_sizes[i]); + // Each gathered slice is written into the (row-contiguous) output as a + // sequential block. We can therefore replace the per-element strided read + // with a single contiguous copy only when the slice is itself laid out + // contiguously in row-major order within the source. + bool can_copy = true; + int64_t expected_stride = 1; + for (int i = src.ndim() - 1; i >= 0 && can_copy; --i) { + if (slice_sizes[i] == 1) { + continue; } + can_copy = (src.strides()[i] == expected_stride); + expected_stride *= slice_sizes[i]; } size_t slice_size = 1; for (auto s : slice_sizes) { diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c2230f77cf..0fcf912f93 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2264,6 +2264,58 @@ TEST_CASE("test take") { CHECK_THROWS(take(a, zeros({2, 3, 2}), 0)); } +TEST_CASE("test gather contiguity") { + // Regression test for a CPU-backend bug where the gather "fast copy" path + // copied a multi-dimensional slice from a column-contiguous source as a raw + // (column-major) memory block, producing a transposed/wrong-stride result. + // The bug only showed up on the CPU backend and is exercised by: + // - chained takes through a size-1 axis (which produce a col-contiguous + // intermediate), and + // - a direct take from a transposed (col-contiguous) source. + + // Chained gather through size-1 axes (issue repro). + { + auto u = reshape(array({1.0f, 2.0f}), {2, 1, 1}); + auto g = take(u, array({0, 1}, int32), 0, Device::cpu); + g = take(g, array({0, 0, 0}, int32), 1, Device::cpu); + g = take(g, array({0, 0, 0}, int32), 2, Device::cpu); + CHECK_EQ(g.shape(), Shape{2, 3, 3}); + // Each batch must be uniform: batch 0 -> 1.0, batch 1 -> 2.0. + auto expected = array( + {1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f}, + {2, 3, 3}); + CHECK(array_equal(g, expected).item()); + } + + // Direct take from a column-contiguous source with a multi-dim slice. + { + auto base = astype(reshape(arange(24), {4, 3, 2}), int32); + auto a = transpose(base, {2, 1, 0}); // [2, 3, 4], col-contiguous + auto t = take(a, array({0, 1}, int32), 2, Device::cpu); + CHECK_EQ(t.shape(), Shape{2, 3, 2}); + auto expected = + array({0, 6, 2, 8, 4, 10, 1, 7, 3, 9, 5, 11}, {2, 3, 2}, int32); + CHECK(array_equal(t, expected).item()); + } +} + TEST_CASE("test take along axis") { // No zero dim arrays auto a = array(1);