From e4c46c613c2305e53effc55e5ea94ea562a9dd3d Mon Sep 17 00:00:00 2001 From: ran Date: Fri, 12 Jun 2026 14:39:54 -0500 Subject: [PATCH 1/2] Contain exceptions thrown from inside eval An exception escaping a primitive's eval_cpu/eval_gpu (e.g. argument validation or a lazy Metal JIT compile error) unwinds eval_impl past the per-stream epilogue, with two consequences: 1. Arrays evaluated earlier in the same batch are already marked Status::evaluated, but their kernels sit in a command buffer that is never committed. A later read returns unwritten memory (stale buffer pool bytes). 2. Events attached during the failed eval (async_eval, synchronizer) are never signaled, so a later read of an affected array blocks forever in array::wait. Both reproduce on an unmodified build: a bystander array in the failed eval batch reads back wrong (sync), or a later read deadlocks on an unsignaled event (async). Fix: on exception, signal every event created for this eval and synchronize every touched stream (committing pending command buffers and draining CPU queues) before rethrowing. The exception still propagates to the caller; it just no longer poisons unrelated state. Signed-off-by: ran --- mlx/transforms.cpp | 160 ++++++++++++++++++++++---------------- python/tests/test_eval.py | 32 ++++++++ 2 files changed, 126 insertions(+), 66 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 80608c1ff0..f682f03b96 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -221,85 +221,113 @@ array eval_impl(std::vector outputs, bool async) { } std::set open_streams; - while (!tape.empty()) { - auto arr = std::move(tape.back()); - tape.pop_back(); - - auto stream = arr.primitive().stream(); - open_streams.insert(stream); - - if (async) { - // Lookup corresponding event - auto e = events.find(stream.index); - if (e == events.end()) { - e = events.emplace(stream.index, Event{stream}).first; - } - e->second.set_value(1); - arr.attach_event(e->second); - for (auto& s : arr.siblings()) { - s.attach_event(e->second); + try { + while (!tape.empty()) { + auto arr = std::move(tape.back()); + tape.pop_back(); + + auto stream = arr.primitive().stream(); + open_streams.insert(stream); + + if (async) { + // Lookup corresponding event + auto e = events.find(stream.index); + if (e == events.end()) { + e = events.emplace(stream.index, Event{stream}).first; + } + e->second.set_value(1); + arr.attach_event(e->second); + for (auto& s : arr.siblings()) { + s.attach_event(e->second); + } } - } - for (auto& in : arr.inputs()) { - if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) { - // Use fence to wait within a single eval - // Get the input array's stream fence and wait on the - // output arrays stream - fences[it->second.first].wait(stream, in); - } else if (in.event().valid()) { - if (in.event().is_signaled()) { - in.detach_event(); - } else if (in.event().stream() != stream) { - // Use event to wait across async eval - in.event().wait(stream); + for (auto& in : arr.inputs()) { + if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) { + // Use fence to wait within a single eval + // Get the input array's stream fence and wait on the + // output arrays stream + fences[it->second.first].wait(stream, in); + } else if (in.event().valid()) { + if (in.event().is_signaled()) { + in.detach_event(); + } else if (in.event().stream() != stream) { + // Use event to wait across async eval + in.event().wait(stream); + } } } - } - if (arr.primitive().device() == Device::gpu) { - gpu::eval(arr); - } else { - cpu::eval(arr); - } + if (arr.primitive().device() == Device::gpu) { + gpu::eval(arr); + } else { + cpu::eval(arr); + } - if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS || - (get_active_memory() > get_memory_limit() && - scheduler::n_active_tasks() > 0)) { - // Commit any open streams - for (auto& s : open_streams) { - if (s.device == Device::gpu) { - gpu::finalize(s); + if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS || + (get_active_memory() > get_memory_limit() && + scheduler::n_active_tasks() > 0)) { + // Commit any open streams + for (auto& s : open_streams) { + if (s.device == Device::gpu) { + gpu::finalize(s); + } } - } - scheduler::wait_for_one(); - while (get_active_memory() > get_memory_limit() && - scheduler::n_active_tasks() > 0) { scheduler::wait_for_one(); - } - } - - auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) { - if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) { - auto it = fences.find(stream.index); - if (it == fences.end()) { - it = fences.emplace(stream.index, Fence{stream}).first; + while (get_active_memory() > get_memory_limit() && + scheduler::n_active_tasks() > 0) { + scheduler::wait_for_one(); } - it->second.update(stream, a, nf->second.second); } - }; - arr.set_status(array::Status::evaluated); - // TODO Maybe always want the fence coherent kernel in the same cbuf - // as the other kernels? - maybe_update_fence(arr); - for (auto& sib : arr.siblings()) { - sib.set_status(array::Status::evaluated); - maybe_update_fence(sib); + auto maybe_update_fence = + [&fences, &needs_fence, stream](const array& a) { + if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) { + auto it = fences.find(stream.index); + if (it == fences.end()) { + it = fences.emplace(stream.index, Fence{stream}).first; + } + it->second.update(stream, a, nf->second.second); + } + }; + + arr.set_status(array::Status::evaluated); + // TODO Maybe always want the fence coherent kernel in the same cbuf + // as the other kernels? + maybe_update_fence(arr); + for (auto& sib : arr.siblings()) { + sib.set_status(array::Status::evaluated); + maybe_update_fence(sib); + } + if (!arr.is_tracer()) { + arr.detach(); + } } - if (!arr.is_tracer()) { - arr.detach(); + } catch (...) { + // A primitive threw from inside its eval (e.g. argument validation in + // eval_gpu, or a JIT compile failure). Arrays evaluated earlier in this + // tape are already marked evaluated, but their kernels sit in pending + // command buffers that only the epilogue below would commit, and events + // attached during this eval would never be signaled. Left that way, a + // later read of an affected array returns an unwritten buffer or blocks + // forever. Signal the events and flush the touched streams, then let the + // exception propagate. + for (auto& [idx, e] : events) { + try { + auto es = e.stream(); + e.signal(es); + open_streams.insert(es); + } catch (...) { + } + } + for (auto& s : open_streams) { + try { + synchronize(s); + } catch (...) { + // Preserve the original exception. + } } + throw; } // Signal the event in its stream diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 5d6daaec21..7f24dbd18e 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -195,6 +195,38 @@ def test_multistream_deadlock(self): mx.eval(z) mx.set_memory_limit(old_limit) + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") + def test_eval_exception_does_not_corrupt_state(self): + # An exception thrown from inside a primitive's eval (here a Metal + # compile error raised lazily at eval time) must not corrupt arrays + # evaluated earlier in the same batch: they are already marked + # evaluated, so their pending command buffers must still be + # committed before the exception propagates. + a = mx.full((1024,), 3.0) + b = a * 2.0 # encoded in the same eval batch as the failing kernel + + kernel = mx.fast.metal_kernel( + name="test_eval_exception_bad_kernel", + input_names=["inp"], + output_names=["out"], + source="this is not metal code {", + ) + with self.assertRaises(Exception): + (y,) = kernel( + inputs=[b], + output_shapes=[b.shape], + output_dtypes=[b.dtype], + grid=(1, 1, 1), + threadgroup=(1, 1, 1), + ) + mx.eval(y) + + self.assertTrue(mx.all(b == 6.0).item()) + + # Fresh computations after the failure stay correct. + x = mx.full((512,), 2.0) + self.assertEqual((x + 1.0).sum().item(), 512.0 * 3.0) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From 3d077bff6ad0140d1289dffb84c0cb9e2e13dade Mon Sep 17 00:00:00 2001 From: Ranran Date: Sat, 13 Jun 2026 19:46:30 -0500 Subject: [PATCH 2/2] Update python/tests/test_eval.py Co-authored-by: Cheng --- python/tests/test_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 7f24dbd18e..da7f0ea8df 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -195,7 +195,7 @@ def test_multistream_deadlock(self): mx.eval(z) mx.set_memory_limit(old_limit) - @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_eval_exception_does_not_corrupt_state(self): # An exception thrown from inside a primitive's eval (here a Metal # compile error raised lazily at eval time) must not corrupt arrays