Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 94 additions & 66 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,85 +221,113 @@ array eval_impl(std::vector<array> outputs, bool async) {
}

std::set<Stream> open_streams;
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();

auto stream = arr.primitive().stream();
open_streams.insert(stream);

if (async) {
// Lookup corresponding event
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
try {
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();

auto stream = arr.primitive().stream();
open_streams.insert(stream);

if (async) {
// Lookup corresponding event
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
}
}
}

for (auto& in : arr.inputs()) {
if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
// Use fence to wait within a single eval
// Get the input array's stream fence and wait on the
// output arrays stream
fences[it->second.first].wait(stream, in);
} else if (in.event().valid()) {
if (in.event().is_signaled()) {
in.detach_event();
} else if (in.event().stream() != stream) {
// Use event to wait across async eval
in.event().wait(stream);
for (auto& in : arr.inputs()) {
if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
// Use fence to wait within a single eval
// Get the input array's stream fence and wait on the
// output arrays stream
fences[it->second.first].wait(stream, in);
} else if (in.event().valid()) {
if (in.event().is_signaled()) {
in.detach_event();
} else if (in.event().stream() != stream) {
// Use event to wait across async eval
in.event().wait(stream);
}
}
}
}

if (arr.primitive().device() == Device::gpu) {
gpu::eval(arr);
} else {
cpu::eval(arr);
}
if (arr.primitive().device() == Device::gpu) {
gpu::eval(arr);
} else {
cpu::eval(arr);
}

if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
(get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0)) {
// Commit any open streams
for (auto& s : open_streams) {
if (s.device == Device::gpu) {
gpu::finalize(s);
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
(get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0)) {
// Commit any open streams
for (auto& s : open_streams) {
if (s.device == Device::gpu) {
gpu::finalize(s);
}
}
}
scheduler::wait_for_one();
while (get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
}
}

auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
auto it = fences.find(stream.index);
if (it == fences.end()) {
it = fences.emplace(stream.index, Fence{stream}).first;
while (get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
}
it->second.update(stream, a, nf->second.second);
}
};

arr.set_status(array::Status::evaluated);
// TODO Maybe always want the fence coherent kernel in the same cbuf
// as the other kernels?
maybe_update_fence(arr);
for (auto& sib : arr.siblings()) {
sib.set_status(array::Status::evaluated);
maybe_update_fence(sib);
auto maybe_update_fence =
[&fences, &needs_fence, stream](const array& a) {
if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
auto it = fences.find(stream.index);
if (it == fences.end()) {
it = fences.emplace(stream.index, Fence{stream}).first;
}
it->second.update(stream, a, nf->second.second);
}
};

arr.set_status(array::Status::evaluated);
// TODO Maybe always want the fence coherent kernel in the same cbuf
// as the other kernels?
maybe_update_fence(arr);
for (auto& sib : arr.siblings()) {
sib.set_status(array::Status::evaluated);
maybe_update_fence(sib);
}
if (!arr.is_tracer()) {
arr.detach();
}
}
if (!arr.is_tracer()) {
arr.detach();
} catch (...) {
// A primitive threw from inside its eval (e.g. argument validation in
// eval_gpu, or a JIT compile failure). Arrays evaluated earlier in this
// tape are already marked evaluated, but their kernels sit in pending
// command buffers that only the epilogue below would commit, and events
// attached during this eval would never be signaled. Left that way, a
// later read of an affected array returns an unwritten buffer or blocks
// forever. Signal the events and flush the touched streams, then let the
// exception propagate.
for (auto& [idx, e] : events) {
try {
auto es = e.stream();
e.signal(es);
open_streams.insert(es);
} catch (...) {
}
}
for (auto& s : open_streams) {
try {
synchronize(s);
} catch (...) {
// Preserve the original exception.
}
}
throw;
}

// Signal the event in its stream
Expand Down
32 changes: 32 additions & 0 deletions python/tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,38 @@ def test_multistream_deadlock(self):
mx.eval(z)
mx.set_memory_limit(old_limit)

@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_eval_exception_does_not_corrupt_state(self):
# An exception thrown from inside a primitive's eval (here a Metal
# compile error raised lazily at eval time) must not corrupt arrays
# evaluated earlier in the same batch: they are already marked
# evaluated, so their pending command buffers must still be
# committed before the exception propagates.
a = mx.full((1024,), 3.0)
b = a * 2.0 # encoded in the same eval batch as the failing kernel

kernel = mx.fast.metal_kernel(
name="test_eval_exception_bad_kernel",
input_names=["inp"],
output_names=["out"],
source="this is not metal code {",
)
with self.assertRaises(Exception):
(y,) = kernel(
inputs=[b],
output_shapes=[b.shape],
output_dtypes=[b.dtype],
grid=(1, 1, 1),
threadgroup=(1, 1, 1),
)
mx.eval(y)

self.assertTrue(mx.all(b == 6.0).item())

# Fresh computations after the failure stay correct.
x = mx.full((512,), 2.0)
self.assertEqual((x + 1.0).sum().item(), 512.0 * 3.0)


if __name__ == "__main__":
mlx_tests.MLXTestRunner()
Loading