From ac40f529c08170a22ec76cf77633c8dded8f5376 Mon Sep 17 00:00:00 2001 From: Noah Lyons Date: Fri, 19 Jun 2026 00:37:37 -0400 Subject: [PATCH] Fix compiled kernel correctness for negative-strided inputs The compiled (fused) kernel path produced wrong results when inputs had negative strides (e.g. x[::-1]). Two issues: 1. compiled_check_contiguity used the broad contiguous flag for single inputs, which is true for negative-strided arrays (no data gaps). Changed to require row_contiguous or col_contiguous, matching the multi-input path. 2. Metal/CUDA strided compiled kernels used unsigned index arithmetic (elem_to_loc_1), wrapping negative strides. Force int64_t indices when any input has negative strides. Also generate the _large (int64_t) strided kernel variant for ndim=1. The CPU compiled path uses signed pointer arithmetic and only needed the contiguity check fix. Fixes #3716. Co-Authored-By: Claude Opus 4.6 (1M context) --- mlx/backend/common/compiled.cpp | 18 ++++++++++--- mlx/backend/common/compiled.h | 5 +++- mlx/backend/cuda/compiled.cpp | 5 ++-- mlx/backend/metal/compiled.cpp | 36 ++++++++++++------------- python/tests/test_compile.py | 47 +++++++++++++++++++++++++++++++++ 5 files changed, 87 insertions(+), 24 deletions(-) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index aceeb1f7fd..a5c9934bd8 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -102,7 +102,7 @@ bool compiled_check_contiguity( } if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) { contiguous = false; - } else if (non_scalar_inputs == 1 && !all_contig) { + } else if (non_scalar_inputs == 1 && !(all_row_contig || all_col_contig)) { contiguous = false; } else if (non_scalar_inputs == 0 && !shape.empty()) { contiguous = false; @@ -224,7 +224,8 @@ std::tuple> compiled_collapse_contiguous_dims( bool compiled_use_large_index( const std::vector& inputs, const std::vector& outputs, - bool contiguous) { + bool contiguous, + const std::vector& strides) { if (contiguous) { size_t max_size = 0; for (const auto& in : inputs) { @@ -236,7 +237,18 @@ bool compiled_use_large_index( for (const auto& o : outputs) { max_size = std::max(max_size, o.size()); } - return max_size > UINT32_MAX; + if (max_size > UINT32_MAX) { + return true; + } + // Check for negative strides in inputs (strides[0] is the output). + for (size_t i = 1; i < strides.size(); ++i) { + for (auto v : strides[i]) { + if (v < 0) { + return true; + } + } + } + return false; } } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index 84a3460459..7a01642e77 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -87,9 +87,12 @@ std::tuple> compiled_collapse_contiguous_dims( const std::function& is_constant); // Return whether the kernel should use large index. +// Also returns true when any non-contiguous input has negative strides, +// since unsigned index arithmetic wraps negative stride values. bool compiled_use_large_index( const std::vector& inputs, const std::vector& outputs, - bool contiguous); + bool contiguous, + const std::vector& strides = {}); } // namespace mlx::core diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 4ffb959ac2..1389038951 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -288,8 +288,9 @@ void Compiled::eval_gpu( auto [contiguous, shape, strides_vec] = compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); - // Whether to use large index. - bool large = compiled_use_large_index(inputs, outputs, contiguous); + // Whether to use large index (also true for negative strides). + bool large = + compiled_use_large_index(inputs, outputs, contiguous, strides_vec); cu::KernelArgs args; // Put inputs. diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index cdb0a471be..6076136455 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -139,8 +139,8 @@ inline void build_kernel( os += fmt::format(" {0} index_{1} = ", idx_type, xname); if (ndim == 1) { int offset = i * ndim; - os += - fmt::format("elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset); + os += fmt::format( + "elem_to_loc_1<{0}>(pos.x, in_strides[{1}]);\n", idx_type, offset); } else if (ndim == 2) { int offset = i * ndim; os += fmt::format( @@ -316,20 +316,20 @@ void Compiled::eval_gpu( /* dynamic_dims = */ false, /* use_big_index = */ false, /* work_per_thread = */ i > 3 ? 2 : 1); - if (i > 1) { - build_kernel( - kernel, - kernel_lib_ + "_strided_" + std::to_string(i) + "_large", - inputs_, - outputs_, - tape_, - is_constant_, - /* contiguous = */ false, - /* ndim = */ i, - /* dynamic_dims = */ false, - /* use_big_index = */ true, - /* work_per_thread = */ i > 3 ? 4 : 1); - } + // Generate int64_t index variant for all ndim, including ndim=1. + // Negative strides force large mode even for small arrays. + build_kernel( + kernel, + kernel_lib_ + "_strided_" + std::to_string(i) + "_large", + inputs_, + outputs_, + tape_, + is_constant_, + /* contiguous = */ false, + /* ndim = */ i, + /* dynamic_dims = */ false, + /* use_big_index = */ true, + /* work_per_thread = */ i > 3 ? 4 : 1); } build_kernel( kernel, @@ -363,8 +363,8 @@ void Compiled::eval_gpu( auto [contiguous, shape, strides] = compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); - // Whether to use large index. - bool large = compiled_use_large_index(inputs, outputs, contiguous); + // Whether to use large index (also true for negative strides). + bool large = compiled_use_large_index(inputs, outputs, contiguous, strides); // Get the kernel from the lib int ndim = shape.size(); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 632b34119a..65dd2b9dcd 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1415,6 +1415,53 @@ def fun(x): np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr ) + def test_compile_negative_strides(self): + # 1D negative stride with elementwise expression + @mx.compile + def f(x): + return 2.0 * x[::-1] + + x = mx.arange(8, dtype=mx.float32) + expected = 2.0 * x[::-1] + self.assertTrue(mx.array_equal(f(x), expected)) + + # 1D negative stride with slice update + def g_eager(x): + base = mx.zeros_like(x) + base[::-1] += 2.0 * x[::-1] + return base + + g_compiled = mx.compile(g_eager) + expected = g_eager(x) + self.assertTrue(mx.array_equal(g_compiled(x), expected)) + + # 2D negative stride + @mx.compile + def h(x): + return x[::-1] + 1.0 + + y = mx.arange(12, dtype=mx.float32).reshape(3, 4) + expected = y[::-1] + 1.0 + self.assertTrue(mx.array_equal(h(y), expected)) + + # Mixed positive and negative strides + @mx.compile + def m(x): + return x[::-1, ::2] * 3.0 + + z = mx.arange(24, dtype=mx.float32).reshape(4, 6) + expected = z[::-1, ::2] * 3.0 + self.assertTrue(mx.array_equal(m(z), expected)) + + # 4D negative stride (exercises work_per_thread > 1 path) + @mx.compile + def p(x): + return x + 1.0 + + w = mx.arange(120, dtype=mx.float32).reshape(2, 3, 4, 5) + expected = w[::-1, :, ::-1, :] + 1.0 + self.assertTrue(mx.array_equal(p(w[::-1, :, ::-1, :]), expected)) + if __name__ == "__main__": mlx_tests.MLXTestRunner()