From 0ce0380ef38101ce40ed528662128e164ce5548c Mon Sep 17 00:00:00 2001 From: Alessio Pollero Date: Sun, 21 Jun 2026 15:02:08 +0400 Subject: [PATCH] Propagate CPU load errors through events Route CPU read and task failures through eval events so waiters observe errors consistently instead of losing exceptions in scheduler tasks. Skip dependent CPU work once the active event is poisoned, while keeping event signals and fence updates on an unchecked dispatch path so cross-stream waiters are still released. Add sync, async, cross-stream, same-stream, and pending-dependent load regressions for the propagated error path. --- mlx/backend/common/load.cpp | 5 +- mlx/backend/cpu/encoder.h | 64 ++++++++++++++-- mlx/backend/cpu/eval.cpp | 27 +++++++ mlx/backend/cpu/eval.h | 3 + mlx/backend/cuda/event.cu | 83 +++++++++++++++++--- mlx/backend/cuda/fence.cpp | 46 ++++++++++- mlx/backend/metal/event.cpp | 83 +++++++++++++++++--- mlx/backend/metal/event.h | 8 ++ mlx/backend/metal/fence.cpp | 63 ++++++++++++--- mlx/backend/no_gpu/event.cpp | 44 +++++++++-- mlx/backend/no_gpu/fence.cpp | 40 ++++++---- mlx/event.h | 7 ++ mlx/transforms.cpp | 43 ++++++++++- python/tests/test_load.py | 29 +++++++ tests/load_tests.cpp | 144 +++++++++++++++++++++++++++++++++++ tests/scheduler_tests.cpp | 43 +++++++++++ 16 files changed, 666 insertions(+), 66 deletions(-) diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index ce41963de7..2f3fe8755b 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -3,8 +3,8 @@ #include #include +#include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" -#include "mlx/scheduler.h" namespace { @@ -51,7 +51,8 @@ void Load::eval_cpu(const std::vector& inputs, array& out) { } }; auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); - scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); }); + cpu::get_command_encoder(stream()).dispatch( + [fut = std::move(fut)]() { fut.get(); }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h index cd015623f6..df8e25647b 100644 --- a/mlx/backend/cpu/encoder.h +++ b/mlx/backend/cpu/encoder.h @@ -2,9 +2,11 @@ #pragma once +#include #include #include "mlx/array.h" +#include "mlx/event.h" #include "mlx/scheduler.h" namespace mlx::core::cpu { @@ -36,28 +38,80 @@ struct MLX_API CommandEncoder { std::make_move_iterator(arrays.end())); } + void set_error_event(Event event) { + event_ = std::move(event); + } + std::vector& temporaries() { return temporaries_; } template void dispatch(F&& f, Args&&... args) { + dispatch_impl(true, std::forward(f), std::forward(args)...); + } + + template + void dispatch_unchecked(F&& f, Args&&... args) { + dispatch_impl(false, std::forward(f), std::forward(args)...); + } + + private: + template + void dispatch_impl(bool skip_on_error, F&& f, Args&&... args) { num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK; auto task = std::bind(std::forward(f), std::forward(args)...); + auto event = event_; if (num_ops_ == 0) { scheduler::notify_new_task(stream_); - auto task_wrap = [s = stream_, task = std::move(task)]() mutable { - task(); - scheduler::notify_task_completion(s); + auto task_wrap = [s = stream_, + event = std::move(event), + skip_on_error, + task = std::move(task)]() mutable { + struct CompletionNotifier { + Stream stream; + ~CompletionNotifier() { + scheduler::notify_task_completion(stream); + } + } completion{s}; + if (skip_on_error && event.valid() && event.error()) { + return; + } + try { + task(); + } catch (...) { + if (event.valid()) { + event.set_error(std::current_exception()); + } else { + throw; + } + } }; scheduler::enqueue(stream_, std::move(task_wrap)); } else { - scheduler::enqueue(stream_, std::move(task)); + scheduler::enqueue( + stream_, + [event = std::move(event), + skip_on_error, + task = std::move(task)]() mutable { + if (skip_on_error && event.valid() && event.error()) { + return; + } + try { + task(); + } catch (...) { + if (event.valid()) { + event.set_error(std::current_exception()); + } else { + throw; + } + } + }); } } - private: Stream stream_; + Event event_; std::vector temporaries_; int num_ops_{0}; }; diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp index 354820f0fe..2212305f5e 100644 --- a/mlx/backend/cpu/eval.cpp +++ b/mlx/backend/cpu/eval.cpp @@ -7,6 +7,33 @@ namespace mlx::core::cpu { +void set_error_event(Stream s, Event event) { + get_command_encoder(s).set_error_event(std::move(event)); +} + +void clear_error_event(Stream s) { + auto& encoders = get_command_encoders(); + auto it = encoders.find(s.index); + if (it != encoders.end()) { + it->second.set_error_event(Event{}); + return; + } + + auto& global_encoders = get_global_command_encoders(); + it = global_encoders.find(s.index); + if (it != global_encoders.end()) { + it->second.set_error_event(Event{}); + } +} + +void check_error_event(Stream s, Event event) { + get_command_encoder(s).dispatch([event = std::move(event)]() { + if (auto error = event.error()) { + std::rethrow_exception(error); + } + }); +} + void new_stream(Stream s) { auto& encoders = get_command_encoders(); encoders.try_emplace(s.index, s); diff --git a/mlx/backend/cpu/eval.h b/mlx/backend/cpu/eval.h index 775f3e46f0..7bd75a87d1 100644 --- a/mlx/backend/cpu/eval.h +++ b/mlx/backend/cpu/eval.h @@ -7,6 +7,9 @@ namespace mlx::core::cpu { +void set_error_event(Stream s, Event event); +void clear_error_event(Stream s); +void check_error_event(Stream s, Event event); void new_stream(Stream s); void new_thread_unsafe_stream(Stream s); void eval(array& arr); diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index b73937ec38..6668c9307d 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -1,13 +1,14 @@ // Copyright © 2024 Apple Inc. +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" #include "mlx/backend/gpu/device_info.h" #include "mlx/event.h" -#include "mlx/scheduler.h" #include +#include #include #include @@ -131,7 +132,7 @@ class CopyableCudaEvent { void wait(Stream s) { if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [*this]() mutable { + cpu::get_command_encoder(s).dispatch([*this]() mutable { check_recorded(); event_->wait(); }); @@ -265,7 +266,8 @@ void AtomicEvent::wait(cudaStream_t stream, uint32_t value) { void AtomicEvent::wait(Stream s, uint32_t value) { nvtx3::scoped_range r("cu::AtomicEvent::wait(s)"); if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + cpu::get_command_encoder(s).dispatch( + [*this, value]() mutable { wait(value); }); } else { auto& encoder = get_command_encoder(s); encoder.commit(); @@ -292,8 +294,8 @@ void AtomicEvent::signal(Stream s, uint32_t value) { if (s.device == mlx::core::Device::cpu) { // Signal through a GPU stream so the atomic is updated in GPU - updating // the atomic in CPU sometimes does not get GPU notified. - scheduler::enqueue( - s, [*this, value]() mutable { signal(signal_stream(), value); }); + cpu::get_command_encoder(s).dispatch_unchecked( + [*this, value]() mutable { signal(signal_stream(), value); }); } else { auto& encoder = get_command_encoder(s); encoder.commit(); @@ -339,6 +341,8 @@ struct EventImpl { // 2. signal value other than 1 has been specified. std::unique_ptr cuda; std::unique_ptr atomic; + std::exception_ptr error; + std::mutex mtx; bool is_created() const { return cuda || atomic; @@ -356,6 +360,35 @@ struct EventImpl { cuda = std::make_unique(d); } } + + void set_error(std::exception_ptr err, uint64_t value) { + { + std::lock_guard lk(mtx); + if (!error) { + error = std::move(err); + } + } + if (atomic) { + atomic->signal(value); + } + } + + void check_error() { + std::exception_ptr err; + { + std::lock_guard lk(mtx); + err = std::move(error); + error = nullptr; + } + if (err) { + std::rethrow_exception(err); + } + } + + std::exception_ptr get_error() { + std::lock_guard lk(mtx); + return error; + } }; } // namespace @@ -367,6 +400,7 @@ Event::Event(Stream s) : stream_(s) { void Event::wait() { auto* event = static_cast(event_.get()); + event->check_error(); assert(event->is_created()); if (event->cuda) { assert(value() == 1); @@ -374,17 +408,35 @@ void Event::wait() { } else { event->atomic->wait(value()); } + event->check_error(); CHECK_CUDA_ERROR(cudaPeekAtLastError()); } void Event::wait(Stream s) { - auto* event = static_cast(event_.get()); - assert(event->is_created()); - if (event->cuda) { - assert(value() == 1); - event->cuda->wait(s); + if (s.device == mlx::core::Device::cpu) { + cpu::get_command_encoder(s).dispatch( + [event_ = event_, value = value()]() mutable { + auto* event = static_cast(event_.get()); + event->check_error(); + assert(event->is_created()); + if (event->cuda) { + assert(value == 1); + event->cuda->wait(); + } else { + event->atomic->wait(value); + } + event->check_error(); + }); } else { - event->atomic->wait(s, value()); + auto* event = static_cast(event_.get()); + event->check_error(); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->atomic->wait(s, value()); + } } } @@ -401,6 +453,7 @@ void Event::signal(Stream s) { bool Event::is_signaled() const { auto* event = static_cast(event_.get()); + event->check_error(); if (!event->is_created()) { return false; } @@ -412,4 +465,12 @@ bool Event::is_signaled() const { } } +void Event::set_error(std::exception_ptr error) { + static_cast(event_.get())->set_error(std::move(error), value()); +} + +std::exception_ptr Event::error() const { + return static_cast(event_.get())->get_error(); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp index c6a41f0e60..ceb5a87b19 100644 --- a/mlx/backend/cuda/fence.cpp +++ b/mlx/backend/cuda/fence.cpp @@ -1,30 +1,69 @@ // Copyright © 2025 Apple Inc. #include "mlx/fence.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" +#include +#include + namespace mlx::core { struct FenceImpl { uint32_t count; cu::AtomicEvent event; + std::exception_ptr error; + std::mutex mtx; }; Fence::Fence(Stream s) { fence_ = std::shared_ptr( - new FenceImpl{0, cu::device(s.device)}, + new FenceImpl{0, cu::device(s.device), nullptr, {}}, [](void* ptr) { delete static_cast(ptr); }); } void Fence::wait(Stream s, const array&) { auto* fence = static_cast(fence_.get()); - fence->event.wait(fence->count); + auto count = fence->count; + if (s.device == Device::cpu) { + cpu::get_command_encoder(s).dispatch([fence_ = fence_, count]() mutable { + auto* fence = static_cast(fence_.get()); + fence->event.wait(count); + std::exception_ptr error; + { + std::lock_guard lk(fence->mtx); + error = fence->error; + } + if (error) { + std::rethrow_exception(error); + } + }); + } else { + fence->event.wait(s, count); + } } void Fence::update(Stream s, const array& a, bool cross_device) { auto* fence = static_cast(fence_.get()); + fence->count++; + auto count = fence->count; + if (s.device == Device::cpu) { + cpu::get_command_encoder(s).dispatch_unchecked( + [event = a.event(), fence_ = fence_, count]() mutable { + auto* fence = static_cast(fence_.get()); + auto error = event.valid() ? event.error() : nullptr; + if (error) { + std::lock_guard lk(fence->mtx); + if (!fence->error) { + fence->error = std::move(error); + } + } + fence->event.signal(count); + }); + return; + } if (cross_device) { // Move to managed memory if there is a device switch auto& cbuf = @@ -35,8 +74,7 @@ void Fence::update(Stream s, const array& a, bool cross_device) { cu::allocator().move_to_unified_memory(cbuf, encoder.stream()); } } - fence->count++; - fence->event.signal(s, fence->count); + fence->event.signal(s, count); } } // namespace mlx::core diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 77f48f0838..280a7ae4b1 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -1,10 +1,27 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/event.h" -#include "mlx/scheduler.h" +#include "mlx/backend/cpu/encoder.h" namespace mlx::core { +namespace { + +std::shared_ptr message_from_exception(std::exception_ptr error) { + if (!error) { + return std::make_shared("Unknown exception."); + } + try { + std::rethrow_exception(std::move(error)); + } catch (const std::exception& e) { + return std::make_shared(e.what()); + } catch (...) { + return std::make_shared("Unknown exception."); + } +} + +} // namespace + /////////////////////////////////////////////////////////////////////////////// // EventImpl implementations /////////////////////////////////////////////////////////////////////////////// @@ -39,13 +56,53 @@ void EventImpl::set_error(std::shared_ptr error) { std::atomic_store(&error_, std::move(error)); } +void EventImpl::set_error(std::exception_ptr error, uint64_t value) { + auto message = message_from_exception(error); + { + std::lock_guard lk(exception_mtx_); + if (!exception_) { + exception_ = std::move(error); + } + } + if (!std::atomic_load(&error_)) { + std::atomic_store(&error_, std::move(message)); + } + signal(value); +} + void EventImpl::check_error() { + std::exception_ptr exception; + { + std::lock_guard lk(exception_mtx_); + exception = std::move(exception_); + exception_ = nullptr; + } + if (exception) { + std::atomic_exchange(&error_, {}); + std::rethrow_exception(exception); + } + auto error = std::atomic_exchange(&error_, {}); if (error) { throw std::runtime_error(*error); } } +std::exception_ptr EventImpl::exception() const { + { + std::lock_guard lk(exception_mtx_); + if (exception_) { + return exception_; + } + } + + auto error = std::atomic_load(&error_); + if (error) { + return std::make_exception_ptr(std::runtime_error(*error)); + } + return nullptr; +} + } // namespace metal /////////////////////////////////////////////////////////////////////////////// @@ -63,9 +120,8 @@ void Event::wait() { void Event::wait(Stream stream) { auto impl = std::static_pointer_cast(event_); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() { - impl->wait(value); - }); + cpu::get_command_encoder(stream).dispatch( + [impl = std::move(impl), value = value()]() { impl->wait(value); }); } else { auto& encoder = metal::get_command_encoder(stream); encoder.wait_event(std::move(impl), value()); @@ -75,9 +131,8 @@ void Event::wait(Stream stream) { void Event::signal(Stream stream) { auto impl = std::static_pointer_cast(event_); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() { - impl->signal(value); - }); + cpu::get_command_encoder(stream).dispatch_unchecked( + [impl = std::move(impl), value = value()]() { impl->signal(value); }); } else { auto& encoder = metal::get_command_encoder(stream); encoder.signal_event(std::move(impl), value()); @@ -85,8 +140,18 @@ void Event::signal(Stream stream) { } bool Event::is_signaled() const { - auto* mtl_event = static_cast(event_.get())->mtl_event(); - return mtl_event->signaledValue() >= value(); + auto* impl = static_cast(event_.get()); + impl->check_error(); + return impl->mtl_event()->signaledValue() >= value(); +} + +void Event::set_error(std::exception_ptr error) { + static_cast(event_.get()) + ->set_error(std::move(error), value()); +} + +std::exception_ptr Event::error() const { + return static_cast(event_.get())->exception(); } } // namespace mlx::core diff --git a/mlx/backend/metal/event.h b/mlx/backend/metal/event.h index c5c82a7cd3..df6665611c 100644 --- a/mlx/backend/metal/event.h +++ b/mlx/backend/metal/event.h @@ -1,8 +1,12 @@ // Copyright © 2026 Apple Inc. +#pragma once + #include "mlx/backend/metal/device.h" #include "mlx/event.h" +#include + namespace mlx::core::metal { class EventImpl { @@ -13,7 +17,9 @@ class EventImpl { void wait(uint64_t value); void signal(uint64_t value); void set_error(std::shared_ptr error); + void set_error(std::exception_ptr error, uint64_t value); void check_error(); + std::exception_ptr exception() const; const auto& error() const { return error_; @@ -26,6 +32,8 @@ class EventImpl { private: // TODO: Use std::atomic when it gets supported in Xcode. std::shared_ptr error_; + std::exception_ptr exception_; + mutable std::mutex exception_mtx_; NS::SharedPtr mtl_event_; }; diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 6fdd57a5f6..8dad321661 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,9 +1,12 @@ // Copyright © 2024 Apple Inc. #include "mlx/fence.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/metal/device.h" -#include "mlx/scheduler.h" #include "mlx/utils.h" +#include +#include + namespace mlx::core { struct FenceImpl { @@ -33,6 +36,8 @@ struct FenceImpl { uint32_t count{0}; void* fence; std::unique_ptr event; + std::exception_ptr error; + std::mutex error_mtx; std::atomic_uint* cpu_value() { return static_cast( @@ -54,11 +59,28 @@ void Fence::wait(Stream stream, const array& x) { } if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { - auto& f = *static_cast(fence_.get()); - while (f.cpu_value()[0] < count) { - } - }); + cpu::get_command_encoder(stream).dispatch( + [fence_ = fence_, count = f.count]() mutable { + auto& f = *static_cast(fence_.get()); + while (f.cpu_value()[0] < count) { + std::exception_ptr error; + { + std::lock_guard lk(f.error_mtx); + error = f.error; + } + if (error) { + std::rethrow_exception(error); + } + } + std::exception_ptr error; + { + std::lock_guard lk(f.error_mtx); + error = f.error; + } + if (error) { + std::rethrow_exception(error); + } + }); return; } @@ -88,15 +110,34 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { if (!f.use_fast) { f.event->set_value(f.count); - f.event->signal(stream); + if (stream.device == Device::cpu) { + cpu::get_command_encoder(stream).dispatch_unchecked( + [event = x.event(), fence_event = *f.event]() mutable { + auto error = event.valid() ? event.error() : nullptr; + if (error) { + fence_event.set_error(std::move(error)); + } + }); + f.event->signal(stream); + } else { + f.event->signal(stream); + } return; } if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { - auto& f = *static_cast(fence_.get()); - f.cpu_value()[0] = count; - }); + cpu::get_command_encoder(stream).dispatch_unchecked( + [fence_ = fence_, count = f.count, event = x.event()]() mutable { + auto& f = *static_cast(fence_.get()); + auto error = event.valid() ? event.error() : nullptr; + if (error) { + std::lock_guard lk(f.error_mtx); + if (!f.error) { + f.error = std::move(error); + } + } + f.cpu_value()[0] = count; + }); return; } diff --git a/mlx/backend/no_gpu/event.cpp b/mlx/backend/no_gpu/event.cpp index 6dde047ab4..22e3a339fc 100644 --- a/mlx/backend/no_gpu/event.cpp +++ b/mlx/backend/no_gpu/event.cpp @@ -1,8 +1,9 @@ // Copyright © 2024 Apple Inc. #include "mlx/event.h" -#include "mlx/scheduler.h" +#include "mlx/backend/cpu/encoder.h" +#include #include #include @@ -10,6 +11,7 @@ namespace mlx::core { struct EventCounter { uint64_t value{0}; + std::exception_ptr error; std::mutex mtx; std::condition_variable cv; }; @@ -22,18 +24,21 @@ Event::Event(Stream stream) : stream_(stream) { void Event::wait() { auto ec = static_cast(event_.get()); std::unique_lock lk(ec->mtx); - if (ec->value >= value()) { - return; + ec->cv.wait( + lk, [value = value(), ec] { return ec->value >= value || ec->error; }); + if (ec->error) { + auto error = std::move(ec->error); + ec->error = nullptr; + std::rethrow_exception(error); } - ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); } void Event::wait(Stream stream) { - scheduler::enqueue(stream, [*this]() mutable { wait(); }); + cpu::get_command_encoder(stream).dispatch([*this]() mutable { wait(); }); } void Event::signal(Stream stream) { - scheduler::enqueue(stream, [*this]() mutable { + cpu::get_command_encoder(stream).dispatch_unchecked([*this]() mutable { auto ec = static_cast(event_.get()); { std::lock_guard lk(ec->mtx); @@ -45,9 +50,34 @@ void Event::signal(Stream stream) { bool Event::is_signaled() const { auto ec = static_cast(event_.get()); + std::exception_ptr error; { std::lock_guard lk(ec->mtx); - return (ec->value >= value()); + if (ec->error) { + error = std::move(ec->error); + ec->error = nullptr; + } else { + return (ec->value >= value()); + } } + std::rethrow_exception(error); +} + +void Event::set_error(std::exception_ptr error) { + auto ec = static_cast(event_.get()); + { + std::lock_guard lk(ec->mtx); + if (!ec->error) { + ec->error = std::move(error); + } + ec->value = std::max(ec->value, value()); + } + ec->cv.notify_all(); +} + +std::exception_ptr Event::error() const { + auto ec = static_cast(event_.get()); + std::lock_guard lk(ec->mtx); + return ec->error; } } // namespace mlx::core diff --git a/mlx/backend/no_gpu/fence.cpp b/mlx/backend/no_gpu/fence.cpp index cd66d23cfe..392cce4df0 100644 --- a/mlx/backend/no_gpu/fence.cpp +++ b/mlx/backend/no_gpu/fence.cpp @@ -1,16 +1,18 @@ // Copyright © 2024 Apple Inc. #include +#include #include +#include "mlx/backend/cpu/encoder.h" #include "mlx/fence.h" -#include "mlx/scheduler.h" namespace mlx::core { struct FenceImpl { uint32_t count{0}; uint32_t value{0}; + std::exception_ptr error; std::mutex mtx; std::condition_variable cv; }; @@ -23,29 +25,35 @@ Fence::Fence(Stream) { void Fence::wait(Stream stream, const array&) { auto& f = *static_cast(fence_.get()); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { - auto& f = *static_cast(fence_.get()); - std::unique_lock lk(f.mtx); - if (f.value >= count) { - return; - } - f.cv.wait(lk, [&f, count] { return f.value >= count; }); - }); + cpu::get_command_encoder(stream).dispatch( + [count = f.count, fence_ = fence_]() mutable { + auto& f = *static_cast(fence_.get()); + std::unique_lock lk(f.mtx); + f.cv.wait(lk, [&f, count] { return f.value >= count || f.error; }); + if (f.error) { + std::rethrow_exception(f.error); + } + }); } else { throw std::runtime_error("[Fence::wait] Invalid stream."); } } -void Fence::update(Stream stream, const array&, bool) { +void Fence::update(Stream stream, const array& x, bool) { auto& f = *static_cast(fence_.get()); f.count++; if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { - auto& f = *static_cast(fence_.get()); - std::unique_lock lk(f.mtx); - f.value = count; - f.cv.notify_all(); - }); + cpu::get_command_encoder(stream).dispatch_unchecked( + [count = f.count, event = x.event(), fence_ = fence_]() mutable { + auto& f = *static_cast(fence_.get()); + auto error = event.valid() ? event.error() : nullptr; + std::unique_lock lk(f.mtx); + if (error && !f.error) { + f.error = std::move(error); + } + f.value = count; + f.cv.notify_all(); + }); } else { throw std::runtime_error("[Fence::update] Invalid stream."); } diff --git a/mlx/event.h b/mlx/event.h index 66a6a75df5..ca27b404cf 100644 --- a/mlx/event.h +++ b/mlx/event.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include #include @@ -26,6 +27,12 @@ class Event { // Check if the event has been signaled at its current value bool is_signaled() const; + // Set an error on the event to be raised by wait(). + void set_error(std::exception_ptr error); + + // Return the current error without clearing it. + std::exception_ptr error() const; + // Check if the event is valid bool valid() const { return event_ != nullptr; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 87f8236db1..613f606471 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -58,6 +58,26 @@ class Synchronizer : public Primitive { DEFINE_NAME(Synchronize); }; +class CpuErrorEventGuard { + public: + void set(Stream stream, Event event) { + if (stream.device != Device::cpu) { + return; + } + cpu::set_error_event(stream, std::move(event)); + streams_.insert(stream); + } + + ~CpuErrorEventGuard() { + for (auto& stream : streams_) { + cpu::clear_error_event(stream); + } + } + + private: + std::set streams_; +}; + // Initialize the static tracing members from transforms_impl.h // // These are used to implement the in_tracing() function the returns true if we @@ -225,6 +245,7 @@ array eval_impl(std::vector outputs, bool async) { } std::set open_streams; + CpuErrorEventGuard cpu_error_events; while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); @@ -232,6 +253,7 @@ array eval_impl(std::vector outputs, bool async) { auto stream = arr.primitive().stream(); open_streams.insert(stream); + Event error_event; if (async) { // Lookup corresponding event auto e = events.find(stream.index); @@ -239,11 +261,15 @@ array eval_impl(std::vector outputs, bool async) { e = events.emplace(stream.index, Event{stream}).first; } e->second.set_value(1); + error_event = e->second; arr.attach_event(e->second); for (auto& s : arr.siblings()) { s.attach_event(e->second); } + } else { + error_event = synchronizer.event(); } + cpu_error_events.set(stream, error_event); for (auto& in : arr.inputs()) { if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) { @@ -252,7 +278,22 @@ array eval_impl(std::vector outputs, bool async) { // output arrays stream fences[it->second.first].wait(stream, in); } else if (in.event().valid()) { - if (in.event().is_signaled()) { + if (async && stream.device == Device::cpu) { + if (in.event().stream() == stream) { + cpu::check_error_event(stream, in.event()); + } else { + in.event().wait(stream); + } + } else if (async) { + if (auto error = in.event().error()) { + error_event.set_error(std::move(error)); + } else if (in.event().is_signaled()) { + in.detach_event(); + } else if (in.event().stream() != stream) { + // Use event to wait across async eval + in.event().wait(stream); + } + } else if (in.event().is_signaled()) { in.detach_event(); } else if (in.event().stream() != stream) { // Use event to wait across async eval diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 10fb63ea63..148e564c99 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -88,6 +88,35 @@ def test_load_npy_dtype(self): with self.assertRaises(Exception): out = mx.load(save_file, stream=mx.cpu) + def test_load_npy_read_error(self): + save_file = os.path.join(self.test_dir, "truncated.npy") + expected = np.arange(16, dtype=np.float32) + np.save(save_file, expected) + with open(save_file, "r+b") as f: + f.truncate(os.path.getsize(save_file) - expected.nbytes) + + out = mx.load(save_file, stream=mx.cpu) + with self.assertRaises(RuntimeError): + np.array(out) + + def test_async_load_npy_read_error_across_streams(self): + save_file = os.path.join(self.test_dir, "truncated_async.npy") + expected = np.arange(16, dtype=np.float32) + np.save(save_file, expected) + with open(save_file, "r+b") as f: + f.truncate(os.path.getsize(save_file) - expected.nbytes) + + producer_stream = mx.new_stream(mx.cpu) + consumer_stream = mx.new_stream(mx.cpu) + out = mx.add( + mx.load(save_file, stream=producer_stream), + 1.0, + stream=consumer_stream, + ) + mx.async_eval(out) + with self.assertRaises(RuntimeError): + np.array(out) + def test_save_and_load_safetensors(self): test_file = os.path.join(self.test_dir, "test.safetensors") with self.assertRaises(Exception): diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 491ca158c1..92b813e6aa 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -1,8 +1,11 @@ // Copyright © 2023 Apple Inc. +#include #include #include +#include #include +#include #include #include "doctest/doctest.h" @@ -10,11 +13,71 @@ #include "mlx/mlx.h" using namespace mlx::core; +using namespace std::chrono_literals; std::string get_temp_file(const std::string& name) { return std::filesystem::temp_directory_path().append(name).string(); } +class BlockingFailReader : public io::Reader { + public: + explicit BlockingFailReader(std::vector data) + : data_(std::move(data)), + release_(release_promise_.get_future().share()) {} + + bool is_open() const override { + return true; + } + + bool good() const override { + return true; + } + + size_t tell() override { + return pos_; + } + + void seek(int64_t off, std::ios_base::seekdir way) override { + if (way == std::ios_base::beg) { + pos_ = off; + } else if (way == std::ios_base::cur) { + pos_ += off; + } else { + pos_ = data_.size() + off; + } + } + + void read(char* data, size_t n) override { + std::copy(data_.begin() + pos_, data_.begin() + pos_ + n, data); + pos_ += n; + } + + void read(char*, size_t, size_t) override { + release_.wait(); + throw std::runtime_error("[read] blocked read failed"); + } + + std::string label() const override { + return "blocking fail reader"; + } + + void release() { + release_promise_.set_value(); + } + + private: + std::vector data_; + size_t pos_{0}; + std::promise release_promise_; + std::shared_future release_; +}; + +std::vector read_file(const std::string& path) { + std::ifstream in(path, std::ios::binary); + return std::vector( + std::istreambuf_iterator(in), std::istreambuf_iterator()); +} + TEST_CASE("test save_safetensors") { std::string file_path = get_temp_file("test_arr.safetensors"); auto map = std::unordered_map(); @@ -342,3 +405,84 @@ TEST_CASE("test single array serialization") { CHECK(array_equal(a, b).item()); } } + +TEST_CASE("test load propagates CPU read errors") { + std::string file_path = get_temp_file("test_truncated_arr.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto a = load(file_path, Device::cpu); + CHECK_THROWS_AS(eval(a), std::runtime_error); +} + +TEST_CASE("test load propagates CPU read errors across streams") { + auto good_stream = new_stream(Device::cpu); + auto bad_stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_truncated_arr_stream.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto good = arange(4, float32, good_stream); + auto bad = load(file_path, bad_stream); + CHECK_THROWS_AS(eval(good, bad), std::runtime_error); +} + +TEST_CASE("test async load propagates CPU read errors across streams") { + auto consumer_stream = new_stream(Device::cpu); + auto producer_stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_truncated_arr_async_stream.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto bad = load(file_path, producer_stream); + auto out = add(bad, array(1.0f), consumer_stream); + async_eval(out); + CHECK_THROWS_AS(out.wait(), std::runtime_error); +} + +TEST_CASE("test async load error poisons same-stream dependent eval") { + auto stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_truncated_arr_same_stream.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto bad = load(file_path, stream); + async_eval(bad); + for (int i = 0; i < 100 && !bad.event().error(); ++i) { + std::this_thread::sleep_for(10ms); + } + REQUIRE(bad.event().error()); + + auto out = add(bad, array(1.0f), stream); + async_eval(out); + CHECK_THROWS_AS(out.wait(), std::runtime_error); +} + +TEST_CASE("test async load error poisons pending same-stream dependent eval") { + auto stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_arr_blocking_reader.npy"); + save(file_path, arange(16, float32, Device::cpu)); + auto reader = std::make_shared(read_file(file_path)); + + auto bad = load(reader, stream); + async_eval(bad); + + auto out = add(bad, array(1.0f), stream); + async_eval(out); + + reader->release(); + CHECK_THROWS_AS(out.wait(), std::runtime_error); + CHECK_THROWS_AS(bad.wait(), std::runtime_error); +} diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index 3a8f7e86b7..2e6d991176 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -1,11 +1,16 @@ // Copyright © 2023 Apple Inc. +#include +#include +#include + #include "doctest/doctest.h" #include "mlx/mlx.h" #include "mlx/scheduler.h" using namespace mlx::core; +using namespace std::chrono_literals; TEST_CASE("test stream management") { auto s1 = default_stream(default_device()); @@ -196,6 +201,44 @@ TEST_CASE("test asynchronous launch") { CHECK_EQ(x, 10); } +TEST_CASE("test event error wakes waiters") { + auto e = Event{default_stream(Device::cpu)}; + e.set_value(1); + e.signal(default_stream(Device::cpu)); + synchronize(default_stream(Device::cpu)); + + e.set_value(2); + + auto waiter = std::async(std::launch::async, [e]() mutable { e.wait(); }); + + std::this_thread::sleep_for(10ms); + e.set_error(std::make_exception_ptr(std::runtime_error("test error"))); + + if (waiter.wait_for(1s) == std::future_status::timeout) { + e.signal(default_stream(Device::cpu)); + FAIL("event waiter was not woken by set_error"); + } + CHECK_THROWS_AS(waiter.get(), std::runtime_error); +} + +TEST_CASE("test event error leaves event signaled after consumption") { + auto e = Event{default_stream(Device::cpu)}; + e.set_value(1); + e.signal(default_stream(Device::cpu)); + synchronize(default_stream(Device::cpu)); + + e.set_value(2); + e.set_error(std::make_exception_ptr(std::runtime_error("test error"))); + CHECK_THROWS_AS(e.wait(), std::runtime_error); + + auto waiter = std::async(std::launch::async, [e]() mutable { e.wait(); }); + if (waiter.wait_for(1s) == std::future_status::timeout) { + e.signal(default_stream(Device::cpu)); + FAIL("event was not signaled after error was consumed"); + } + CHECK_NOTHROW(waiter.get()); +} + TEST_CASE("test stream placement") { auto s1 = default_stream(Device::cpu); auto s2 = new_stream(Device::cpu);