From c654130065d40f0a8084bf38778820c559eb9c54 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Tue, 9 Jun 2026 10:34:40 -0400 Subject: [PATCH 1/2] Fix CPU gather transposing col-contiguous slices (#1) The gather "fast copy" path used the row/col contiguous flags to decide whether a per-index slice could be copied as a single contiguous block. For a column-contiguous source the slice is contiguous in memory but in column-major order, while the output is written in row-major order, so a multi-dimensional slice came out transposed. This surfaced on the CPU backend via chained `take` through size-1 axes (which yield a col-contiguous intermediate) and via a direct `take` from a transposed source; the GPU backend was correct. Replace the flag-based heuristic with a direct check that the slice is row-major contiguous within the source (each non-singleton slice dim's source stride equals the product of the inner slice sizes), which is the exact precondition for the contiguous copy. Falls back to the strided iterator otherwise. Co-Authored-By: Claude Opus 4.8 (1M context) --- mlx/backend/cpu/indexing.cpp | 55 +++++++++++++----------------------- tests/ops_tests.cpp | 36 +++++++++++++++++++++++ 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index ec4090172f..9101ae284e 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -66,43 +66,26 @@ void gather( array& out, const std::vector& axes, const Shape& slice_sizes) { - // If the array is row contiguous then we can do a contiguous copy given - // two conditions on the slice size: - // - Any number of leading ones in the slice sizes are allowed - // - All other slice sizes match the corresponding dimension except the - // first non-singleton slice size - // If the array is col contiguous then the reverse is the case: - // - Any number of trailing ones in the slice sizes are allowed - // - All other slice sizes match the corresponding dimension except the - // first non-singleton slice size from the end - - bool can_copy = false; - if (src.flags().row_contiguous) { - can_copy = true; - - // Ignore leading 1s - int i = 0; - for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i) - ; - - // Check the remaining - i++; - for (; i < src.ndim() && can_copy; ++i) { - can_copy = (src.shape(i) == slice_sizes[i]); - } - } else if (src.flags().col_contiguous) { - can_copy = true; - - // Ignore trailing 1s - int i = slice_sizes.size() - 1; - for (; i >= 0 && slice_sizes[i] == 1; --i) - ; - - // Skip the next slice size and check the remaining - i--; - for (; i >= 0 && can_copy; --i) { - can_copy = (src.shape(i) == slice_sizes[i]); + // Each gathered slice is written into the (row-contiguous) output as a + // sequential block. We can therefore replace the per-element strided read + // with a single contiguous copy only when the slice is itself laid out + // contiguously in row-major order within the source. That holds exactly when + // every non-singleton slice dimension has a source stride equal to the + // product of the slice sizes of the dimensions inside it. Size-1 dimensions + // are skipped since their stride is irrelevant. + // + // Checking the strides directly (rather than relying on the row/col + // contiguous flags) is important: a column-contiguous source is contiguous + // in memory but in column-major order, so copying a multi-dimensional slice + // as a raw block would transpose it. + bool can_copy = true; + int64_t expected_stride = 1; + for (int i = src.ndim() - 1; i >= 0 && can_copy; --i) { + if (slice_sizes[i] == 1) { + continue; } + can_copy = (src.strides()[i] == expected_stride); + expected_stride *= slice_sizes[i]; } size_t slice_size = 1; for (auto s : slice_sizes) { diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c2230f77cf..3f8f7b390a 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2264,6 +2264,42 @@ TEST_CASE("test take") { CHECK_THROWS(take(a, zeros({2, 3, 2}), 0)); } +TEST_CASE("test gather contiguity") { + // Regression test for a CPU-backend bug where the gather "fast copy" path + // copied a multi-dimensional slice from a column-contiguous source as a raw + // (column-major) memory block, producing a transposed/wrong-stride result. + // The bug only showed up on the CPU backend and is exercised by: + // - chained takes through a size-1 axis (which produce a col-contiguous + // intermediate), and + // - a direct take from a transposed (col-contiguous) source. + + // Chained gather through size-1 axes (issue repro). + { + auto u = reshape(array({1.0f, 2.0f}), {2, 1, 1}); + auto g = take(u, array({0, 1}, int32), 0, Device::cpu); + g = take(g, array({0, 0, 0}, int32), 1, Device::cpu); + g = take(g, array({0, 0, 0}, int32), 2, Device::cpu); + CHECK_EQ(g.shape(), Shape{2, 3, 3}); + // Each batch must be uniform: batch 0 -> 1.0, batch 1 -> 2.0. + auto expected = array( + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}, + {2, 3, 3}); + CHECK(array_equal(g, expected).item()); + } + + // Direct take from a column-contiguous source with a multi-dim slice. + { + auto base = astype(reshape(arange(24), {4, 3, 2}), int32); + auto a = transpose(base, {2, 1, 0}); // [2, 3, 4], col-contiguous + auto t = take(a, array({0, 1}, int32), 2, Device::cpu); + CHECK_EQ(t.shape(), Shape{2, 3, 2}); + auto expected = array( + {0, 6, 2, 8, 4, 10, 1, 7, 3, 9, 5, 11}, {2, 3, 2}, int32); + CHECK(array_equal(t, expected).item()); + } +} + TEST_CASE("test take along axis") { // No zero dim arrays auto a = array(1); From 8e20a28fe65570556c8d80441d84bac58b3f6ce2 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 11 Jun 2026 08:31:14 +0900 Subject: [PATCH 2/2] Run lint --- mlx/backend/cpu/indexing.cpp | 10 +--------- tests/ops_tests.cpp | 24 ++++++++++++++++++++---- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 9101ae284e..d668b56adb 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -69,15 +69,7 @@ void gather( // Each gathered slice is written into the (row-contiguous) output as a // sequential block. We can therefore replace the per-element strided read // with a single contiguous copy only when the slice is itself laid out - // contiguously in row-major order within the source. That holds exactly when - // every non-singleton slice dimension has a source stride equal to the - // product of the slice sizes of the dimensions inside it. Size-1 dimensions - // are skipped since their stride is irrelevant. - // - // Checking the strides directly (rather than relying on the row/col - // contiguous flags) is important: a column-contiguous source is contiguous - // in memory but in column-major order, so copying a multi-dimensional slice - // as a raw block would transpose it. + // contiguously in row-major order within the source. bool can_copy = true; int64_t expected_stride = 1; for (int i = src.ndim() - 1; i >= 0 && can_copy; --i) { diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 3f8f7b390a..0fcf912f93 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2282,8 +2282,24 @@ TEST_CASE("test gather contiguity") { CHECK_EQ(g.shape(), Shape{2, 3, 3}); // Each batch must be uniform: batch 0 -> 1.0, batch 1 -> 2.0. auto expected = array( - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}, + {1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f, + 2.0f}, {2, 3, 3}); CHECK(array_equal(g, expected).item()); } @@ -2294,8 +2310,8 @@ TEST_CASE("test gather contiguity") { auto a = transpose(base, {2, 1, 0}); // [2, 3, 4], col-contiguous auto t = take(a, array({0, 1}, int32), 2, Device::cpu); CHECK_EQ(t.shape(), Shape{2, 3, 2}); - auto expected = array( - {0, 6, 2, 8, 4, 10, 1, 7, 3, 9, 5, 11}, {2, 3, 2}, int32); + auto expected = + array({0, 6, 2, 8, 4, 10, 1, 7, 3, 9, 5, 11}, {2, 3, 2}, int32); CHECK(array_equal(t, expected).item()); } }