Make the eval_impl tape loop exception-safe.

The loop that drains the tape is wrapped in try/catch. When a primitive
throws from inside its eval, the handler signals every event held in
`events`, synchronizes every stream recorded in `open_streams`, and then
rethrows. Exceptions raised by that cleanup are swallowed on purpose so the
original exception is the one that reaches the caller.

A Metal-only regression test is added: a custom kernel with invalid source
must make mx.eval raise, an array evaluated earlier in the same batch must
still read back correctly, and a fresh computation afterwards must succeed.

NOTE(review): `std::vector outputs` and `std::set open_streams` below look
like they lost their template arguments in this copy; confirm against the
repository. Indentation was reconstructed (2-space C++, 4-space Python) and
the `index` hashes predate the review tweaks, so run `git apply --check`
before applying.

diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 80608c1ff0..f682f03b96 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -221,85 +221,114 @@ array eval_impl(std::vector outputs, bool async) {
   }
 
   std::set open_streams;
-  while (!tape.empty()) {
-    auto arr = std::move(tape.back());
-    tape.pop_back();
-
-    auto stream = arr.primitive().stream();
-    open_streams.insert(stream);
-
-    if (async) {
-      // Lookup corresponding event
-      auto e = events.find(stream.index);
-      if (e == events.end()) {
-        e = events.emplace(stream.index, Event{stream}).first;
-      }
-      e->second.set_value(1);
-      arr.attach_event(e->second);
-      for (auto& s : arr.siblings()) {
-        s.attach_event(e->second);
+  try {
+    while (!tape.empty()) {
+      auto arr = std::move(tape.back());
+      tape.pop_back();
+
+      auto stream = arr.primitive().stream();
+      open_streams.insert(stream);
+
+      if (async) {
+        // Lookup corresponding event
+        auto e = events.find(stream.index);
+        if (e == events.end()) {
+          e = events.emplace(stream.index, Event{stream}).first;
+        }
+        e->second.set_value(1);
+        arr.attach_event(e->second);
+        for (auto& s : arr.siblings()) {
+          s.attach_event(e->second);
+        }
       }
-    }
 
-    for (auto& in : arr.inputs()) {
-      if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
-        // Use fence to wait within a single eval
-        // Get the input array's stream fence and wait on the
-        // output arrays stream
-        fences[it->second.first].wait(stream, in);
-      } else if (in.event().valid()) {
-        if (in.event().is_signaled()) {
-          in.detach_event();
-        } else if (in.event().stream() != stream) {
-          // Use event to wait across async eval
-          in.event().wait(stream);
+      for (auto& in : arr.inputs()) {
+        if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
+          // Use fence to wait within a single eval
+          // Get the input array's stream fence and wait on the
+          // output arrays stream
+          fences[it->second.first].wait(stream, in);
+        } else if (in.event().valid()) {
+          if (in.event().is_signaled()) {
+            in.detach_event();
+          } else if (in.event().stream() != stream) {
+            // Use event to wait across async eval
+            in.event().wait(stream);
+          }
         }
       }
-    }
 
-    if (arr.primitive().device() == Device::gpu) {
-      gpu::eval(arr);
-    } else {
-      cpu::eval(arr);
-    }
+      if (arr.primitive().device() == Device::gpu) {
+        gpu::eval(arr);
+      } else {
+        cpu::eval(arr);
+      }
 
-    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
-        (get_active_memory() > get_memory_limit() &&
-         scheduler::n_active_tasks() > 0)) {
-      // Commit any open streams
-      for (auto& s : open_streams) {
-        if (s.device == Device::gpu) {
-          gpu::finalize(s);
+      if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
+          (get_active_memory() > get_memory_limit() &&
+           scheduler::n_active_tasks() > 0)) {
+        // Commit any open streams
+        for (auto& s : open_streams) {
+          if (s.device == Device::gpu) {
+            gpu::finalize(s);
+          }
         }
-      }
-      scheduler::wait_for_one();
-      while (get_active_memory() > get_memory_limit() &&
-             scheduler::n_active_tasks() > 0) {
         scheduler::wait_for_one();
-      }
-    }
-
-    auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
-      if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
-        auto it = fences.find(stream.index);
-        if (it == fences.end()) {
-          it = fences.emplace(stream.index, Fence{stream}).first;
+        while (get_active_memory() > get_memory_limit() &&
+               scheduler::n_active_tasks() > 0) {
+          scheduler::wait_for_one();
         }
-        it->second.update(stream, a, nf->second.second);
       }
-    };
 
-    arr.set_status(array::Status::evaluated);
-    // TODO Maybe always want the fence coherent kernel in the same cbuf
-    // as the other kernels?
-    maybe_update_fence(arr);
-    for (auto& sib : arr.siblings()) {
-      sib.set_status(array::Status::evaluated);
-      maybe_update_fence(sib);
+      auto maybe_update_fence =
+          [&fences, &needs_fence, stream](const array& a) {
+            if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
+              auto it = fences.find(stream.index);
+              if (it == fences.end()) {
+                it = fences.emplace(stream.index, Fence{stream}).first;
+              }
+              it->second.update(stream, a, nf->second.second);
+            }
+          };
+
+      arr.set_status(array::Status::evaluated);
+      // TODO Maybe always want the fence coherent kernel in the same cbuf
+      // as the other kernels?
+      maybe_update_fence(arr);
+      for (auto& sib : arr.siblings()) {
+        sib.set_status(array::Status::evaluated);
+        maybe_update_fence(sib);
+      }
+      if (!arr.is_tracer()) {
+        arr.detach();
+      }
     }
-    if (!arr.is_tracer()) {
-      arr.detach();
+  } catch (...) {
+    // A primitive threw from inside its eval (e.g. argument validation in
+    // eval_gpu, or a JIT compile failure). Arrays evaluated earlier in this
+    // tape are already marked evaluated, but their kernels sit in pending
+    // command buffers that only the epilogue below would commit, and events
+    // attached during this eval would never be signaled. Left that way, a
+    // later read of an affected array returns an unwritten buffer or blocks
+    // forever. Signal the events and flush the touched streams, then let the
+    // exception propagate.
+    for (auto& [idx, e] : events) {
+      try {
+        auto es = e.stream();
+        e.signal(es);
+        open_streams.insert(es);
+      } catch (...) {
+        // Best effort: keep signaling the remaining events.
+      }
+    }
+    for (auto& s : open_streams) {
+      try {
+        synchronize(s);
+      } catch (...) {
+        // Preserve the original exception.
+      }
     }
+    throw;
   }
 
   // Signal the event in its stream
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 5d6daaec21..da7f0ea8df 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -195,6 +195,39 @@ def test_multistream_deadlock(self):
         mx.eval(z)
         mx.set_memory_limit(old_limit)
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_eval_exception_does_not_corrupt_state(self):
+        # An exception thrown from inside a primitive's eval (here a Metal
+        # compile error raised lazily at eval time) must not corrupt arrays
+        # evaluated earlier in the same batch: they are already marked
+        # evaluated, so their pending command buffers must still be
+        # committed before the exception propagates.
+        a = mx.full((1024,), 3.0)
+        b = a * 2.0  # encoded in the same eval batch as the failing kernel
+
+        kernel = mx.fast.metal_kernel(
+            name="test_eval_exception_bad_kernel",
+            input_names=["inp"],
+            output_names=["out"],
+            source="this is not metal code {",
+        )
+        # The compile error is expected at eval time, so only mx.eval is guarded.
+        (y,) = kernel(
+            inputs=[b],
+            output_shapes=[b.shape],
+            output_dtypes=[b.dtype],
+            grid=(1, 1, 1),
+            threadgroup=(1, 1, 1),
+        )
+        with self.assertRaises(Exception):
+            mx.eval(y)
+
+        self.assertTrue(mx.all(b == 6.0).item())
+
+        # Fresh computations after the failure stay correct.
+        x = mx.full((512,), 2.0)
+        self.assertEqual((x + 1.0).sum().item(), 512.0 * 3.0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()