diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 12d7397be4..c58f39ac46 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -13,6 +13,7 @@ #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" #include "mlx/graph_utils.h" +#include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" @@ -1035,6 +1036,22 @@ std::vector compile_replace( } auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); }; + auto has_negative_strides = [](const array& a) { + return std::any_of( + a.strides().begin(), a.strides().end(), [](auto s) { return s < 0; }); + }; + auto is_negative_strided_slice = [](const array& a) { + if (!a.has_primitive()) { + return false; + } + const auto& prim = a.primitive(); + if (typeid(prim) != typeid(Slice)) { + return false; + } + const auto& strides = std::get<2>(static_cast(prim).state()); + return std::any_of( + strides.begin(), strides.end(), [](auto s) { return s < 0; }); + }; for (auto& a : tape) { // Arrays in the tape without primitives are either: @@ -1050,6 +1067,14 @@ std::vector compile_replace( for (auto& in : a.inputs()) { real_inputs.push_back(trace_to_real.at(in.id())); } + const auto& prim = a.primitive(); + if (typeid(prim) == typeid(Compiled)) { + for (auto& in : real_inputs) { + if (has_negative_strides(in) || is_negative_strided_slice(in)) { + in = contiguous(in, false, a.primitive().stream()); + } + } + } if (a.siblings().empty()) { auto shape = shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape(); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 632b34119a..2de818fdbe 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -71,6 +71,47 @@ def test_compile_nonfinite_constants(self): self.assertEqual(out[0].item(), 1.0) self.assertEqual(out[1].item(), float("-inf")) + def test_compile_negative_strided_slice_update_expr(self): + x = mx.arange(6, dtype=mx.float32) + expected = mx.array([0, 2, 4, 6, 8, 10], dtype=mx.float32) + + def add_update(x): + base = mx.zeros_like(x) + base[::-1] += 2.0 * x[::-1] + return base + + def set_update(x): + base = mx.zeros_like(x) + base[::-1] = 2.0 * x[::-1] + return base + + def bare_view_update(x): + base = mx.zeros_like(x) + base[::-1] += x[::-1] + return base + + def positive_strided_update(x): + base = mx.zeros_like(x) + base[::2] += 2.0 * x[::2] + return base + + self.assertTrue(mx.array_equal(mx.compile(lambda x: x[::-1] + 0)(x), x[::-1])) + self.assertTrue( + mx.array_equal(mx.compile(lambda x: x[::-1] * 2)(x), 2 * x[::-1]) + ) + + for fn in (add_update, set_update): + self.assertTrue(mx.array_equal(fn(x), expected)) + self.assertTrue(mx.array_equal(mx.compile(fn)(x), expected)) + + self.assertTrue(mx.array_equal(mx.compile(bare_view_update)(x), x)) + self.assertTrue( + mx.array_equal( + mx.compile(positive_strided_update)(x), + mx.array([0, 0, 4, 0, 8, 0], dtype=mx.float32), + ) + ) + def test_compile_tuple_output_in_thread(self): @mx.compile def fun(x):