diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index ce41963de7..2f3fe8755b 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -3,8 +3,8 @@ #include #include +#include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" -#include "mlx/scheduler.h" namespace { @@ -51,7 +51,8 @@ void Load::eval_cpu(const std::vector& inputs, array& out) { } }; auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); - scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); }); + cpu::get_command_encoder(stream()).dispatch( + [fut = std::move(fut)]() { fut.get(); }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h index cd015623f6..df8e25647b 100644 --- a/mlx/backend/cpu/encoder.h +++ b/mlx/backend/cpu/encoder.h @@ -2,9 +2,11 @@ #pragma once +#include #include #include "mlx/array.h" +#include "mlx/event.h" #include "mlx/scheduler.h" namespace mlx::core::cpu { @@ -36,28 +38,80 @@ struct MLX_API CommandEncoder { std::make_move_iterator(arrays.end())); } + void set_error_event(Event event) { + event_ = std::move(event); + } + std::vector& temporaries() { return temporaries_; } template void dispatch(F&& f, Args&&... args) { + dispatch_impl(true, std::forward(f), std::forward(args)...); + } + + template + void dispatch_unchecked(F&& f, Args&&... args) { + dispatch_impl(false, std::forward(f), std::forward(args)...); + } + + private: + template + void dispatch_impl(bool skip_on_error, F&& f, Args&&... args) { num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK; auto task = std::bind(std::forward(f), std::forward(args)...); + auto event = event_; if (num_ops_ == 0) { scheduler::notify_new_task(stream_); - auto task_wrap = [s = stream_, task = std::move(task)]() mutable { - task(); - scheduler::notify_task_completion(s); + auto task_wrap = [s = stream_, + event = std::move(event), + skip_on_error, + task = std::move(task)]() mutable { + struct CompletionNotifier { + Stream stream; + ~CompletionNotifier() { + scheduler::notify_task_completion(stream); + } + } completion{s}; + if (skip_on_error && event.valid() && event.error()) { + return; + } + try { + task(); + } catch (...) { + if (event.valid()) { + event.set_error(std::current_exception()); + } else { + throw; + } + } }; scheduler::enqueue(stream_, std::move(task_wrap)); } else { - scheduler::enqueue(stream_, std::move(task)); + scheduler::enqueue( + stream_, + [event = std::move(event), + skip_on_error, + task = std::move(task)]() mutable { + if (skip_on_error && event.valid() && event.error()) { + return; + } + try { + task(); + } catch (...) { + if (event.valid()) { + event.set_error(std::current_exception()); + } else { + throw; + } + } + }); } } - private: Stream stream_; + Event event_; std::vector temporaries_; int num_ops_{0}; }; diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp index 354820f0fe..2212305f5e 100644 --- a/mlx/backend/cpu/eval.cpp +++ b/mlx/backend/cpu/eval.cpp @@ -7,6 +7,33 @@ namespace mlx::core::cpu { +void set_error_event(Stream s, Event event) { + get_command_encoder(s).set_error_event(std::move(event)); +} + +void clear_error_event(Stream s) { + auto& encoders = get_command_encoders(); + auto it = encoders.find(s.index); + if (it != encoders.end()) { + it->second.set_error_event(Event{}); + return; + } + + auto& global_encoders = get_global_command_encoders(); + it = global_encoders.find(s.index); + if (it != global_encoders.end()) { + it->second.set_error_event(Event{}); + } +} + +void check_error_event(Stream s, Event event) { + get_command_encoder(s).dispatch([event = std::move(event)]() { + if (auto error = event.error()) { + std::rethrow_exception(error); + } + }); +} + void new_stream(Stream s) { auto& encoders = get_command_encoders(); encoders.try_emplace(s.index, s); diff --git a/mlx/backend/cpu/eval.h b/mlx/backend/cpu/eval.h index 775f3e46f0..7bd75a87d1 100644 --- a/mlx/backend/cpu/eval.h +++ b/mlx/backend/cpu/eval.h @@ -7,6 +7,9 @@ namespace mlx::core::cpu { +void set_error_event(Stream s, Event event); +void clear_error_event(Stream s); +void check_error_event(Stream s, Event event); void new_stream(Stream s); void new_thread_unsafe_stream(Stream s); void eval(array& arr); diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index b73937ec38..6668c9307d 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -1,13 +1,14 @@ // Copyright © 2024 Apple Inc. +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" #include "mlx/backend/gpu/device_info.h" #include "mlx/event.h" -#include "mlx/scheduler.h" #include +#include #include #include @@ -131,7 +132,7 @@ class CopyableCudaEvent { void wait(Stream s) { if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [*this]() mutable { + cpu::get_command_encoder(s).dispatch([*this]() mutable { check_recorded(); event_->wait(); }); @@ -265,7 +266,8 @@ void AtomicEvent::wait(cudaStream_t stream, uint32_t value) { void AtomicEvent::wait(Stream s, uint32_t value) { nvtx3::scoped_range r("cu::AtomicEvent::wait(s)"); if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + cpu::get_command_encoder(s).dispatch( + [*this, value]() mutable { wait(value); }); } else { auto& encoder = get_command_encoder(s); encoder.commit(); @@ -292,8 +294,8 @@ void AtomicEvent::signal(Stream s, uint32_t value) { if (s.device == mlx::core::Device::cpu) { // Signal through a GPU stream so the atomic is updated in GPU - updating // the atomic in CPU sometimes does not get GPU notified. - scheduler::enqueue( - s, [*this, value]() mutable { signal(signal_stream(), value); }); + cpu::get_command_encoder(s).dispatch_unchecked( + [*this, value]() mutable { signal(signal_stream(), value); }); } else { auto& encoder = get_command_encoder(s); encoder.commit(); @@ -339,6 +341,8 @@ struct EventImpl { // 2. signal value other than 1 has been specified. std::unique_ptr cuda; std::unique_ptr atomic; + std::exception_ptr error; + std::mutex mtx; bool is_created() const { return cuda || atomic; @@ -356,6 +360,35 @@ struct EventImpl { cuda = std::make_unique(d); } } + + void set_error(std::exception_ptr err, uint64_t value) { + { + std::lock_guard lk(mtx); + if (!error) { + error = std::move(err); + } + } + if (atomic) { + atomic->signal(value); + } + } + + void check_error() { + std::exception_ptr err; + { + std::lock_guard lk(mtx); + err = std::move(error); + error = nullptr; + } + if (err) { + std::rethrow_exception(err); + } + } + + std::exception_ptr get_error() { + std::lock_guard lk(mtx); + return error; + } }; } // namespace @@ -367,6 +400,7 @@ Event::Event(Stream s) : stream_(s) { void Event::wait() { auto* event = static_cast(event_.get()); + event->check_error(); assert(event->is_created()); if (event->cuda) { assert(value() == 1); @@ -374,17 +408,35 @@ void Event::wait() { } else { event->atomic->wait(value()); } + event->check_error(); CHECK_CUDA_ERROR(cudaPeekAtLastError()); } void Event::wait(Stream s) { - auto* event = static_cast(event_.get()); - assert(event->is_created()); - if (event->cuda) { - assert(value() == 1); - event->cuda->wait(s); + if (s.device == mlx::core::Device::cpu) { + cpu::get_command_encoder(s).dispatch( + [event_ = event_, value = value()]() mutable { + auto* event = static_cast(event_.get()); + event->check_error(); + assert(event->is_created()); + if (event->cuda) { + assert(value == 1); + event->cuda->wait(); + } else { + event->atomic->wait(value); + } + event->check_error(); + }); } else { - event->atomic->wait(s, value()); + auto* event = static_cast(event_.get()); + event->check_error(); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->atomic->wait(s, value()); + } } } @@ -401,6 +453,7 @@ void Event::signal(Stream s) { bool Event::is_signaled() const { auto* event = static_cast(event_.get()); + event->check_error(); if (!event->is_created()) { return false; } @@ -412,4 +465,12 @@ bool Event::is_signaled() const { } } +void Event::set_error(std::exception_ptr error) { + static_cast(event_.get())->set_error(std::move(error), value()); +} + +std::exception_ptr Event::error() const { + return static_cast(event_.get())->get_error(); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp index c6a41f0e60..ceb5a87b19 100644 --- a/mlx/backend/cuda/fence.cpp +++ b/mlx/backend/cuda/fence.cpp @@ -1,30 +1,69 @@ // Copyright © 2025 Apple Inc. #include "mlx/fence.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" +#include +#include + namespace mlx::core { struct FenceImpl { uint32_t count; cu::AtomicEvent event; + std::exception_ptr error; + std::mutex mtx; }; Fence::Fence(Stream s) { fence_ = std::shared_ptr( - new FenceImpl{0, cu::device(s.device)}, + new FenceImpl{0, cu::device(s.device), nullptr, {}}, [](void* ptr) { delete static_cast(ptr); }); } void Fence::wait(Stream s, const array&) { auto* fence = static_cast(fence_.get()); - fence->event.wait(fence->count); + auto count = fence->count; + if (s.device == Device::cpu) { + cpu::get_command_encoder(s).dispatch([fence_ = fence_, count]() mutable { + auto* fence = static_cast(fence_.get()); + fence->event.wait(count); + std::exception_ptr error; + { + std::lock_guard lk(fence->mtx); + error = fence->error; + } + if (error) { + std::rethrow_exception(error); + } + }); + } else { + fence->event.wait(s, count); + } } void Fence::update(Stream s, const array& a, bool cross_device) { auto* fence = static_cast(fence_.get()); + fence->count++; + auto count = fence->count; + if (s.device == Device::cpu) { + cpu::get_command_encoder(s).dispatch_unchecked( + [event = a.event(), fence_ = fence_, count]() mutable { + auto* fence = static_cast(fence_.get()); + auto error = event.valid() ? event.error() : nullptr; + if (error) { + std::lock_guard lk(fence->mtx); + if (!fence->error) { + fence->error = std::move(error); + } + } + fence->event.signal(count); + }); + return; + } if (cross_device) { // Move to managed memory if there is a device switch auto& cbuf = @@ -35,8 +74,7 @@ void Fence::update(Stream s, const array& a, bool cross_device) { cu::allocator().move_to_unified_memory(cbuf, encoder.stream()); } } - fence->count++; - fence->event.signal(s, fence->count); + fence->event.signal(s, count); } } // namespace mlx::core diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 77f48f0838..280a7ae4b1 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -1,10 +1,27 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/event.h" -#include "mlx/scheduler.h" +#include "mlx/backend/cpu/encoder.h" namespace mlx::core { +namespace { + +std::shared_ptr message_from_exception(std::exception_ptr error) { + if (!error) { + return std::make_shared("Unknown exception."); + } + try { + std::rethrow_exception(std::move(error)); + } catch (const std::exception& e) { + return std::make_shared(e.what()); + } catch (...) { + return std::make_shared("Unknown exception."); + } +} + +} // namespace + /////////////////////////////////////////////////////////////////////////////// // EventImpl implementations /////////////////////////////////////////////////////////////////////////////// @@ -39,13 +56,53 @@ void EventImpl::set_error(std::shared_ptr error) { std::atomic_store(&error_, std::move(error)); } +void EventImpl::set_error(std::exception_ptr error, uint64_t value) { + auto message = message_from_exception(error); + { + std::lock_guard lk(exception_mtx_); + if (!exception_) { + exception_ = std::move(error); + } + } + if (!std::atomic_load(&error_)) { + std::atomic_store(&error_, std::move(message)); + } + signal(value); +} + void EventImpl::check_error() { + std::exception_ptr exception; + { + std::lock_guard lk(exception_mtx_); + exception = std::move(exception_); + exception_ = nullptr; + } + if (exception) { + std::atomic_exchange(&error_, {}); + std::rethrow_exception(exception); + } + auto error = std::atomic_exchange(&error_, {}); if (error) { throw std::runtime_error(*error); } } +std::exception_ptr EventImpl::exception() const { + { + std::lock_guard lk(exception_mtx_); + if (exception_) { + return exception_; + } + } + + auto error = std::atomic_load(&error_); + if (error) { + return std::make_exception_ptr(std::runtime_error(*error)); + } + return nullptr; +} + } // namespace metal /////////////////////////////////////////////////////////////////////////////// @@ -63,9 +120,8 @@ void Event::wait() { void Event::wait(Stream stream) { auto impl = std::static_pointer_cast(event_); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() { - impl->wait(value); - }); + cpu::get_command_encoder(stream).dispatch( + [impl = std::move(impl), value = value()]() { impl->wait(value); }); } else { auto& encoder = metal::get_command_encoder(stream); encoder.wait_event(std::move(impl), value()); @@ -75,9 +131,8 @@ void Event::wait(Stream stream) { void Event::signal(Stream stream) { auto impl = std::static_pointer_cast(event_); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() { - impl->signal(value); - }); + cpu::get_command_encoder(stream).dispatch_unchecked( + [impl = std::move(impl), value = value()]() { impl->signal(value); }); } else { auto& encoder = metal::get_command_encoder(stream); encoder.signal_event(std::move(impl), value()); @@ -85,8 +140,18 @@ void Event::signal(Stream stream) { } bool Event::is_signaled() const { - auto* mtl_event = static_cast(event_.get())->mtl_event(); - return mtl_event->signaledValue() >= value(); + auto* impl = static_cast(event_.get()); + impl->check_error(); + return impl->mtl_event()->signaledValue() >= value(); +} + +void Event::set_error(std::exception_ptr error) { + static_cast(event_.get()) + ->set_error(std::move(error), value()); +} + +std::exception_ptr Event::error() const { + return static_cast(event_.get())->exception(); } } // namespace mlx::core diff --git a/mlx/backend/metal/event.h b/mlx/backend/metal/event.h index c5c82a7cd3..df6665611c 100644 --- a/mlx/backend/metal/event.h +++ b/mlx/backend/metal/event.h @@ -1,8 +1,12 @@ // Copyright © 2026 Apple Inc. +#pragma once + #include "mlx/backend/metal/device.h" #include "mlx/event.h" +#include + namespace mlx::core::metal { class EventImpl { @@ -13,7 +17,9 @@ class EventImpl { void wait(uint64_t value); void signal(uint64_t value); void set_error(std::shared_ptr error); + void set_error(std::exception_ptr error, uint64_t value); void check_error(); + std::exception_ptr exception() const; const auto& error() const { return error_; @@ -26,6 +32,8 @@ class EventImpl { private: // TODO: Use std::atomic when it gets supported in Xcode. std::shared_ptr error_; + std::exception_ptr exception_; + mutable std::mutex exception_mtx_; NS::SharedPtr mtl_event_; }; diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 6fdd57a5f6..8dad321661 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,9 +1,12 @@ // Copyright © 2024 Apple Inc. #include "mlx/fence.h" +#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/metal/device.h" -#include "mlx/scheduler.h" #include "mlx/utils.h" +#include +#include + namespace mlx::core { struct FenceImpl { @@ -33,6 +36,8 @@ struct FenceImpl { uint32_t count{0}; void* fence; std::unique_ptr event; + std::exception_ptr error; + std::mutex error_mtx; std::atomic_uint* cpu_value() { return static_cast( @@ -54,11 +59,28 @@ void Fence::wait(Stream stream, const array& x) { } if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { - auto& f = *static_cast(fence_.get()); - while (f.cpu_value()[0] < count) { - } - }); + cpu::get_command_encoder(stream).dispatch( + [fence_ = fence_, count = f.count]() mutable { + auto& f = *static_cast(fence_.get()); + while (f.cpu_value()[0] < count) { + std::exception_ptr error; + { + std::lock_guard lk(f.error_mtx); + error = f.error; + } + if (error) { + std::rethrow_exception(error); + } + } + std::exception_ptr error; + { + std::lock_guard lk(f.error_mtx); + error = f.error; + } + if (error) { + std::rethrow_exception(error); + } + }); return; } @@ -88,15 +110,34 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { if (!f.use_fast) { f.event->set_value(f.count); - f.event->signal(stream); + if (stream.device == Device::cpu) { + cpu::get_command_encoder(stream).dispatch_unchecked( + [event = x.event(), fence_event = *f.event]() mutable { + auto error = event.valid() ? event.error() : nullptr; + if (error) { + fence_event.set_error(std::move(error)); + } + }); + f.event->signal(stream); + } else { + f.event->signal(stream); + } return; } if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { - auto& f = *static_cast(fence_.get()); - f.cpu_value()[0] = count; - }); + cpu::get_command_encoder(stream).dispatch_unchecked( + [fence_ = fence_, count = f.count, event = x.event()]() mutable { + auto& f = *static_cast(fence_.get()); + auto error = event.valid() ? event.error() : nullptr; + if (error) { + std::lock_guard lk(f.error_mtx); + if (!f.error) { + f.error = std::move(error); + } + } + f.cpu_value()[0] = count; + }); return; } diff --git a/mlx/backend/no_gpu/event.cpp b/mlx/backend/no_gpu/event.cpp index 6dde047ab4..22e3a339fc 100644 --- a/mlx/backend/no_gpu/event.cpp +++ b/mlx/backend/no_gpu/event.cpp @@ -1,8 +1,9 @@ // Copyright © 2024 Apple Inc. #include "mlx/event.h" -#include "mlx/scheduler.h" +#include "mlx/backend/cpu/encoder.h" +#include #include #include @@ -10,6 +11,7 @@ namespace mlx::core { struct EventCounter { uint64_t value{0}; + std::exception_ptr error; std::mutex mtx; std::condition_variable cv; }; @@ -22,18 +24,21 @@ Event::Event(Stream stream) : stream_(stream) { void Event::wait() { auto ec = static_cast(event_.get()); std::unique_lock lk(ec->mtx); - if (ec->value >= value()) { - return; + ec->cv.wait( + lk, [value = value(), ec] { return ec->value >= value || ec->error; }); + if (ec->error) { + auto error = std::move(ec->error); + ec->error = nullptr; + std::rethrow_exception(error); } - ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); } void Event::wait(Stream stream) { - scheduler::enqueue(stream, [*this]() mutable { wait(); }); + cpu::get_command_encoder(stream).dispatch([*this]() mutable { wait(); }); } void Event::signal(Stream stream) { - scheduler::enqueue(stream, [*this]() mutable { + cpu::get_command_encoder(stream).dispatch_unchecked([*this]() mutable { auto ec = static_cast(event_.get()); { std::lock_guard lk(ec->mtx); @@ -45,9 +50,34 @@ void Event::signal(Stream stream) { bool Event::is_signaled() const { auto ec = static_cast(event_.get()); + std::exception_ptr error; { std::lock_guard lk(ec->mtx); - return (ec->value >= value()); + if (ec->error) { + error = std::move(ec->error); + ec->error = nullptr; + } else { + return (ec->value >= value()); + } } + std::rethrow_exception(error); +} + +void Event::set_error(std::exception_ptr error) { + auto ec = static_cast(event_.get()); + { + std::lock_guard lk(ec->mtx); + if (!ec->error) { + ec->error = std::move(error); + } + ec->value = std::max(ec->value, value()); + } + ec->cv.notify_all(); +} + +std::exception_ptr Event::error() const { + auto ec = static_cast(event_.get()); + std::lock_guard lk(ec->mtx); + return ec->error; } } // namespace mlx::core diff --git a/mlx/backend/no_gpu/fence.cpp b/mlx/backend/no_gpu/fence.cpp index cd66d23cfe..392cce4df0 100644 --- a/mlx/backend/no_gpu/fence.cpp +++ b/mlx/backend/no_gpu/fence.cpp @@ -1,16 +1,18 @@ // Copyright © 2024 Apple Inc. #include +#include #include +#include "mlx/backend/cpu/encoder.h" #include "mlx/fence.h" -#include "mlx/scheduler.h" namespace mlx::core { struct FenceImpl { uint32_t count{0}; uint32_t value{0}; + std::exception_ptr error; std::mutex mtx; std::condition_variable cv; }; @@ -23,29 +25,35 @@ Fence::Fence(Stream) { void Fence::wait(Stream stream, const array&) { auto& f = *static_cast(fence_.get()); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { - auto& f = *static_cast(fence_.get()); - std::unique_lock lk(f.mtx); - if (f.value >= count) { - return; - } - f.cv.wait(lk, [&f, count] { return f.value >= count; }); - }); + cpu::get_command_encoder(stream).dispatch( + [count = f.count, fence_ = fence_]() mutable { + auto& f = *static_cast(fence_.get()); + std::unique_lock lk(f.mtx); + f.cv.wait(lk, [&f, count] { return f.value >= count || f.error; }); + if (f.error) { + std::rethrow_exception(f.error); + } + }); } else { throw std::runtime_error("[Fence::wait] Invalid stream."); } } -void Fence::update(Stream stream, const array&, bool) { +void Fence::update(Stream stream, const array& x, bool) { auto& f = *static_cast(fence_.get()); f.count++; if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { - auto& f = *static_cast(fence_.get()); - std::unique_lock lk(f.mtx); - f.value = count; - f.cv.notify_all(); - }); + cpu::get_command_encoder(stream).dispatch_unchecked( + [count = f.count, event = x.event(), fence_ = fence_]() mutable { + auto& f = *static_cast(fence_.get()); + auto error = event.valid() ? event.error() : nullptr; + std::unique_lock lk(f.mtx); + if (error && !f.error) { + f.error = std::move(error); + } + f.value = count; + f.cv.notify_all(); + }); } else { throw std::runtime_error("[Fence::update] Invalid stream."); } diff --git a/mlx/event.h b/mlx/event.h index 66a6a75df5..ca27b404cf 100644 --- a/mlx/event.h +++ b/mlx/event.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include #include @@ -26,6 +27,12 @@ class Event { // Check if the event has been signaled at its current value bool is_signaled() const; + // Set an error on the event to be raised by wait(). + void set_error(std::exception_ptr error); + + // Return the current error without clearing it. + std::exception_ptr error() const; + // Check if the event is valid bool valid() const { return event_ != nullptr; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 87f8236db1..613f606471 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -58,6 +58,26 @@ class Synchronizer : public Primitive { DEFINE_NAME(Synchronize); }; +class CpuErrorEventGuard { + public: + void set(Stream stream, Event event) { + if (stream.device != Device::cpu) { + return; + } + cpu::set_error_event(stream, std::move(event)); + streams_.insert(stream); + } + + ~CpuErrorEventGuard() { + for (auto& stream : streams_) { + cpu::clear_error_event(stream); + } + } + + private: + std::set streams_; +}; + // Initialize the static tracing members from transforms_impl.h // // These are used to implement the in_tracing() function the returns true if we @@ -225,6 +245,7 @@ array eval_impl(std::vector outputs, bool async) { } std::set open_streams; + CpuErrorEventGuard cpu_error_events; while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); @@ -232,6 +253,7 @@ array eval_impl(std::vector outputs, bool async) { auto stream = arr.primitive().stream(); open_streams.insert(stream); + Event error_event; if (async) { // Lookup corresponding event auto e = events.find(stream.index); @@ -239,11 +261,15 @@ array eval_impl(std::vector outputs, bool async) { e = events.emplace(stream.index, Event{stream}).first; } e->second.set_value(1); + error_event = e->second; arr.attach_event(e->second); for (auto& s : arr.siblings()) { s.attach_event(e->second); } + } else { + error_event = synchronizer.event(); } + cpu_error_events.set(stream, error_event); for (auto& in : arr.inputs()) { if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) { @@ -252,7 +278,22 @@ array eval_impl(std::vector outputs, bool async) { // output arrays stream fences[it->second.first].wait(stream, in); } else if (in.event().valid()) { - if (in.event().is_signaled()) { + if (async && stream.device == Device::cpu) { + if (in.event().stream() == stream) { + cpu::check_error_event(stream, in.event()); + } else { + in.event().wait(stream); + } + } else if (async) { + if (auto error = in.event().error()) { + error_event.set_error(std::move(error)); + } else if (in.event().is_signaled()) { + in.detach_event(); + } else if (in.event().stream() != stream) { + // Use event to wait across async eval + in.event().wait(stream); + } + } else if (in.event().is_signaled()) { in.detach_event(); } else if (in.event().stream() != stream) { // Use event to wait across async eval diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 10fb63ea63..148e564c99 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -88,6 +88,35 @@ def test_load_npy_dtype(self): with self.assertRaises(Exception): out = mx.load(save_file, stream=mx.cpu) + def test_load_npy_read_error(self): + save_file = os.path.join(self.test_dir, "truncated.npy") + expected = np.arange(16, dtype=np.float32) + np.save(save_file, expected) + with open(save_file, "r+b") as f: + f.truncate(os.path.getsize(save_file) - expected.nbytes) + + out = mx.load(save_file, stream=mx.cpu) + with self.assertRaises(RuntimeError): + np.array(out) + + def test_async_load_npy_read_error_across_streams(self): + save_file = os.path.join(self.test_dir, "truncated_async.npy") + expected = np.arange(16, dtype=np.float32) + np.save(save_file, expected) + with open(save_file, "r+b") as f: + f.truncate(os.path.getsize(save_file) - expected.nbytes) + + producer_stream = mx.new_stream(mx.cpu) + consumer_stream = mx.new_stream(mx.cpu) + out = mx.add( + mx.load(save_file, stream=producer_stream), + 1.0, + stream=consumer_stream, + ) + mx.async_eval(out) + with self.assertRaises(RuntimeError): + np.array(out) + def test_save_and_load_safetensors(self): test_file = os.path.join(self.test_dir, "test.safetensors") with self.assertRaises(Exception): diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 491ca158c1..92b813e6aa 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -1,8 +1,11 @@ // Copyright © 2023 Apple Inc. +#include #include #include +#include #include +#include #include #include "doctest/doctest.h" @@ -10,11 +13,71 @@ #include "mlx/mlx.h" using namespace mlx::core; +using namespace std::chrono_literals; std::string get_temp_file(const std::string& name) { return std::filesystem::temp_directory_path().append(name).string(); } +class BlockingFailReader : public io::Reader { + public: + explicit BlockingFailReader(std::vector data) + : data_(std::move(data)), + release_(release_promise_.get_future().share()) {} + + bool is_open() const override { + return true; + } + + bool good() const override { + return true; + } + + size_t tell() override { + return pos_; + } + + void seek(int64_t off, std::ios_base::seekdir way) override { + if (way == std::ios_base::beg) { + pos_ = off; + } else if (way == std::ios_base::cur) { + pos_ += off; + } else { + pos_ = data_.size() + off; + } + } + + void read(char* data, size_t n) override { + std::copy(data_.begin() + pos_, data_.begin() + pos_ + n, data); + pos_ += n; + } + + void read(char*, size_t, size_t) override { + release_.wait(); + throw std::runtime_error("[read] blocked read failed"); + } + + std::string label() const override { + return "blocking fail reader"; + } + + void release() { + release_promise_.set_value(); + } + + private: + std::vector data_; + size_t pos_{0}; + std::promise release_promise_; + std::shared_future release_; +}; + +std::vector read_file(const std::string& path) { + std::ifstream in(path, std::ios::binary); + return std::vector( + std::istreambuf_iterator(in), std::istreambuf_iterator()); +} + TEST_CASE("test save_safetensors") { std::string file_path = get_temp_file("test_arr.safetensors"); auto map = std::unordered_map(); @@ -342,3 +405,84 @@ TEST_CASE("test single array serialization") { CHECK(array_equal(a, b).item()); } } + +TEST_CASE("test load propagates CPU read errors") { + std::string file_path = get_temp_file("test_truncated_arr.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto a = load(file_path, Device::cpu); + CHECK_THROWS_AS(eval(a), std::runtime_error); +} + +TEST_CASE("test load propagates CPU read errors across streams") { + auto good_stream = new_stream(Device::cpu); + auto bad_stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_truncated_arr_stream.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto good = arange(4, float32, good_stream); + auto bad = load(file_path, bad_stream); + CHECK_THROWS_AS(eval(good, bad), std::runtime_error); +} + +TEST_CASE("test async load propagates CPU read errors across streams") { + auto consumer_stream = new_stream(Device::cpu); + auto producer_stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_truncated_arr_async_stream.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto bad = load(file_path, producer_stream); + auto out = add(bad, array(1.0f), consumer_stream); + async_eval(out); + CHECK_THROWS_AS(out.wait(), std::runtime_error); +} + +TEST_CASE("test async load error poisons same-stream dependent eval") { + auto stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_truncated_arr_same_stream.npy"); + auto expected = arange(16, float32, Device::cpu); + save(file_path, expected); + std::filesystem::resize_file( + file_path, std::filesystem::file_size(file_path) - expected.nbytes()); + + auto bad = load(file_path, stream); + async_eval(bad); + for (int i = 0; i < 100 && !bad.event().error(); ++i) { + std::this_thread::sleep_for(10ms); + } + REQUIRE(bad.event().error()); + + auto out = add(bad, array(1.0f), stream); + async_eval(out); + CHECK_THROWS_AS(out.wait(), std::runtime_error); +} + +TEST_CASE("test async load error poisons pending same-stream dependent eval") { + auto stream = new_stream(Device::cpu); + + std::string file_path = get_temp_file("test_arr_blocking_reader.npy"); + save(file_path, arange(16, float32, Device::cpu)); + auto reader = std::make_shared(read_file(file_path)); + + auto bad = load(reader, stream); + async_eval(bad); + + auto out = add(bad, array(1.0f), stream); + async_eval(out); + + reader->release(); + CHECK_THROWS_AS(out.wait(), std::runtime_error); + CHECK_THROWS_AS(bad.wait(), std::runtime_error); +} diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index 3a8f7e86b7..2e6d991176 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -1,11 +1,16 @@ // Copyright © 2023 Apple Inc. +#include +#include +#include + #include "doctest/doctest.h" #include "mlx/mlx.h" #include "mlx/scheduler.h" using namespace mlx::core; +using namespace std::chrono_literals; TEST_CASE("test stream management") { auto s1 = default_stream(default_device()); @@ -196,6 +201,44 @@ TEST_CASE("test asynchronous launch") { CHECK_EQ(x, 10); } +TEST_CASE("test event error wakes waiters") { + auto e = Event{default_stream(Device::cpu)}; + e.set_value(1); + e.signal(default_stream(Device::cpu)); + synchronize(default_stream(Device::cpu)); + + e.set_value(2); + + auto waiter = std::async(std::launch::async, [e]() mutable { e.wait(); }); + + std::this_thread::sleep_for(10ms); + e.set_error(std::make_exception_ptr(std::runtime_error("test error"))); + + if (waiter.wait_for(1s) == std::future_status::timeout) { + e.signal(default_stream(Device::cpu)); + FAIL("event waiter was not woken by set_error"); + } + CHECK_THROWS_AS(waiter.get(), std::runtime_error); +} + +TEST_CASE("test event error leaves event signaled after consumption") { + auto e = Event{default_stream(Device::cpu)}; + e.set_value(1); + e.signal(default_stream(Device::cpu)); + synchronize(default_stream(Device::cpu)); + + e.set_value(2); + e.set_error(std::make_exception_ptr(std::runtime_error("test error"))); + CHECK_THROWS_AS(e.wait(), std::runtime_error); + + auto waiter = std::async(std::launch::async, [e]() mutable { e.wait(); }); + if (waiter.wait_for(1s) == std::future_status::timeout) { + e.signal(default_stream(Device::cpu)); + FAIL("event was not signaled after error was consumed"); + } + CHECK_NOTHROW(waiter.get()); +} + TEST_CASE("test stream placement") { auto s1 = default_stream(Device::cpu); auto s2 = new_stream(Device::cpu);