Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions mlx/backend/cuda/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
array in = inputs[0];

// Make sure no identity reductions trickle down here.
assert(!axes_.empty());
assert(out.size() != in.size());

auto& s = stream();
auto& encoder = cu::get_command_encoder(s);

// When all the reduced axes have size 1 at runtime, which can happen with
// shapeless compilation, the reduction is the identity so just cast-copy
// the input to the output.
if (out.size() == in.size()) {
CopyType ctype =
in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, s);
return;
}

if (in.size() == 0) {
init_reduce(encoder, in, out, reduce_type_);
return;
Expand Down
12 changes: 10 additions & 2 deletions mlx/backend/metal/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,17 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
array in = inputs[0];

// Make sure no identity reductions trickle down here
assert(!axes_.empty());
assert(out.size() != in.size());

// When all the reduced axes have size 1 at runtime, which can happen with
// shapeless compilation, the reduction is the identity so just cast-copy
// the input to the output.
if (out.size() == in.size()) {
CopyType ctype =
in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, stream());
return;
}

// Continue with reduction operation
// Minimum of 4 bytes since we use size 4 structs for all reduce
Expand Down
6 changes: 6 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(
is_noop &= (out_shape.back() == shape[i]);
}
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
// Dimensions that are size 1 at trace time can be dynamic when compiling
// shapeless, so the reduction cannot be elided from the graph. Reductions
// over no axes stay no-ops since they are shape independent.
if (is_noop && !sorted_axes.empty() && detail::in_dynamic_tracing()) {
is_noop = false;
}
return {out_shape, sorted_axes, is_noop};
}

Expand Down
39 changes: 39 additions & 0 deletions python/tests/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,45 @@ def fun(args):
self.assertEqual(out[0].shape, (3, 1, 4, 2))
self.assertEqual(out[1].shape, (2, 2, 5))

def test_shapeless_compile_reduce_after_gather(self):
# Reductions over dimensions that happen to have size 1 at trace time
# used to be elided from the graph, so replays with larger dynamic
# shapes returned stale values (issue #3201)
buf = mx.array([10.0, 20.0, 30.0, 40.0, 50.0])
reductions = [
mx.sum,
mx.mean,
mx.prod,
mx.min,
mx.max,
mx.all,
mx.any,
mx.argmin,
mx.argmax,
]
for reduction in reductions:

def fun(buf, idx):
return reduction(mx.take(buf, idx, axis=0))

# Trace with a size-1 reduction and replay with larger sizes
cfun = mx.compile(fun, shapeless=True)
for n in [1, 2, 3, 4]:
idx = mx.arange(n)
self.assertTrue(
mx.array_equal(cfun(buf, idx), fun(buf, idx)),
f"{reduction.__name__} failed for n={n}",
)

# Replay with size 1 so the reduction is an identity at runtime
cfun = mx.compile(fun, shapeless=True)
for n in [2, 1]:
idx = mx.arange(n)
self.assertTrue(
mx.array_equal(cfun(buf, idx), fun(buf, idx)),
f"{reduction.__name__} failed for replay with n={n}",
)

def test_leaks(self):
gc.collect()
if mx.metal.is_available():
Expand Down
Loading