diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu index 269efc034b..1769cdf2f8 100644 --- a/mlx/backend/cuda/reduce.cu +++ b/mlx/backend/cuda/reduce.cu @@ -15,13 +15,21 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); array in = inputs[0]; - // Make sure no identity reductions trickle down here. assert(!axes_.empty()); - assert(out.size() != in.size()); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); + // When all the reduced axes have size 1 at runtime, which can happen with + // shapeless compilation, the reduction is the identity so just cast-copy + // the input to the output. + if (out.size() == in.size()) { + CopyType ctype = + in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, s); + return; + } + if (in.size() == 0) { init_reduce(encoder, in, out, reduce_type_); return; diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 644af5d218..562f966297 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -952,9 +952,17 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); array in = inputs[0]; - // Make sure no identity reductions trickle down here assert(!axes_.empty()); - assert(out.size() != in.size()); + + // When all the reduced axes have size 1 at runtime, which can happen with + // shapeless compilation, the reduction is the identity so just cast-copy + // the input to the output. + if (out.size() == in.size()) { + CopyType ctype = + in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, stream()); + return; + } // Continue with reduction operation // Minimum of 4 bytes since we use size 4 structs for all reduce diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e4ce3d750f..462c6ece78 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -51,6 +51,12 @@ std::tuple, bool> compute_reduce_shape( is_noop &= (out_shape.back() == shape[i]); } std::vector sorted_axes(axes_set.begin(), axes_set.end()); + // Dimensions that are size 1 at trace time can be dynamic when compiling + // shapeless, so the reduction cannot be elided from the graph. Reductions + // over no axes stay no-ops since they are shape independent. + if (is_noop && !sorted_axes.empty() && detail::in_dynamic_tracing()) { + is_noop = false; + } return {out_shape, sorted_axes, is_noop}; } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 632b34119a..4b7643dadd 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1110,6 +1110,45 @@ def fun(args): self.assertEqual(out[0].shape, (3, 1, 4, 2)) self.assertEqual(out[1].shape, (2, 2, 5)) + def test_shapeless_compile_reduce_after_gather(self): + # Reductions over dimensions that happen to have size 1 at trace time + # used to be elided from the graph, so replays with larger dynamic + # shapes returned stale values (issue #3201) + buf = mx.array([10.0, 20.0, 30.0, 40.0, 50.0]) + reductions = [ + mx.sum, + mx.mean, + mx.prod, + mx.min, + mx.max, + mx.all, + mx.any, + mx.argmin, + mx.argmax, + ] + for reduction in reductions: + + def fun(buf, idx): + return reduction(mx.take(buf, idx, axis=0)) + + # Trace with a size-1 reduction and replay with larger sizes + cfun = mx.compile(fun, shapeless=True) + for n in [1, 2, 3, 4]: + idx = mx.arange(n) + self.assertTrue( + mx.array_equal(cfun(buf, idx), fun(buf, idx)), + f"{reduction.__name__} failed for n={n}", + ) + + # Replay with size 1 so the reduction is an identity at runtime + cfun = mx.compile(fun, shapeless=True) + for n in [2, 1]: + idx = mx.arange(n) + self.assertTrue( + mx.array_equal(cfun(buf, idx), fun(buf, idx)), + f"{reduction.__name__} failed for replay with n={n}", + ) + def test_leaks(self): gc.collect() if mx.metal.is_available():