Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 11 additions & 36 deletions mlx/backend/cpu/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,43 +66,18 @@ void gather(
array& out,
const std::vector<int>& axes,
const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size
// If the array is col contiguous then the reverse is the case:
// - Any number of trailing ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size from the end

bool can_copy = false;
if (src.flags().row_contiguous) {
can_copy = true;

// Ignore leading 1s
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;

// Check the remaining
i++;
for (; i < src.ndim() && can_copy; ++i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
} else if (src.flags().col_contiguous) {
can_copy = true;

// Ignore trailing 1s
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;

// Skip the next slice size and check the remaining
i--;
for (; i >= 0 && can_copy; --i) {
can_copy = (src.shape(i) == slice_sizes[i]);
// Each gathered slice is written into the (row-contiguous) output as a
// sequential block. We can therefore replace the per-element strided read
// with a single contiguous copy only when the slice is itself laid out
// contiguously in row-major order within the source.
bool can_copy = true;
int64_t expected_stride = 1;
for (int i = src.ndim() - 1; i >= 0 && can_copy; --i) {
if (slice_sizes[i] == 1) {
continue;
}
can_copy = (src.strides()[i] == expected_stride);
expected_stride *= slice_sizes[i];
}
size_t slice_size = 1;
for (auto s : slice_sizes) {
Expand Down
52 changes: 52 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2264,6 +2264,58 @@ TEST_CASE("test take") {
CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
}

TEST_CASE("test gather contiguity") {
  // Guards against a CPU-only regression in gather's "fast copy" shortcut:
  // a multi-dimensional slice taken from a column-contiguous source used to be
  // copied as one raw column-major memory block, so the output came back
  // transposed / with the wrong strides. Two scenarios trigger it:
  //   1. takes chained through size-1 axes, whose intermediate result is
  //      column-contiguous;
  //   2. a single take applied directly to a transposed (column-contiguous)
  //      array.

  // Scenario 1: chained gathers through size-1 axes (the original issue repro).
  {
    auto column = reshape(array({1.0f, 2.0f}), {2, 1, 1});
    auto repeat_idx = array({0, 0, 0}, int32);

    auto gathered = take(column, array({0, 1}, int32), 0, Device::cpu);
    gathered = take(gathered, repeat_idx, 1, Device::cpu);
    gathered = take(gathered, repeat_idx, 2, Device::cpu);
    CHECK_EQ(gathered.shape(), Shape{2, 3, 3});

    // Every element of batch 0 must be 1.0 and every element of batch 1 must
    // be 2.0.
    // clang-format off
    auto want = array(
        {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,   // batch 0
         2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f},  // batch 1
        {2, 3, 3});
    // clang-format on
    CHECK(array_equal(gathered, want).item<bool>());
  }

  // Scenario 2: one take with a multi-dimensional slice, read straight from a
  // column-contiguous source.
  {
    auto row_major = astype(reshape(arange(24), {4, 3, 2}), int32);
    // Reversing the axes yields shape [2, 3, 4] in column-contiguous layout.
    auto transposed = transpose(row_major, {2, 1, 0});
    auto picked = take(transposed, array({0, 1}, int32), 2, Device::cpu);
    CHECK_EQ(picked.shape(), Shape{2, 3, 2});

    auto want =
        array({0, 6, 2, 8, 4, 10, 1, 7, 3, 9, 5, 11}, {2, 3, 2}, int32);
    CHECK(array_equal(picked, want).item<bool>());
  }
}

TEST_CASE("test take along axis") {
// No zero dim arrays
auto a = array(1);
Expand Down
Loading