Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlx/backend/common/load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <algorithm>
#include <utility>

#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

namespace {

Expand Down Expand Up @@ -51,7 +51,8 @@ void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
}
};
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
cpu::get_command_encoder(stream()).dispatch(
[fut = std::move(fut)]() { fut.get(); });
}

} // namespace mlx::core
64 changes: 59 additions & 5 deletions mlx/backend/cpu/encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

#pragma once

#include <exception>
#include <unordered_map>

#include "mlx/array.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"

namespace mlx::core::cpu {
Expand Down Expand Up @@ -36,28 +38,80 @@ struct MLX_API CommandEncoder {
std::make_move_iterator(arrays.end()));
}

void set_error_event(Event event) {
event_ = std::move(event);
}

std::vector<array>& temporaries() {
return temporaries_;
}

template <class F, class... Args>
void dispatch(F&& f, Args&&... args) {
dispatch_impl(true, std::forward<F>(f), std::forward<Args>(args)...);
}

template <class F, class... Args>
void dispatch_unchecked(F&& f, Args&&... args) {
dispatch_impl(false, std::forward<F>(f), std::forward<Args>(args)...);
}

private:
template <class F, class... Args>
void dispatch_impl(bool skip_on_error, F&& f, Args&&... args) {
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
auto event = event_;
if (num_ops_ == 0) {
scheduler::notify_new_task(stream_);
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
task();
scheduler::notify_task_completion(s);
auto task_wrap = [s = stream_,
event = std::move(event),
skip_on_error,
task = std::move(task)]() mutable {
struct CompletionNotifier {
Stream stream;
~CompletionNotifier() {
scheduler::notify_task_completion(stream);
}
} completion{s};
if (skip_on_error && event.valid() && event.error()) {
return;
}
try {
task();
} catch (...) {
if (event.valid()) {
event.set_error(std::current_exception());
} else {
throw;
}
}
};
scheduler::enqueue(stream_, std::move(task_wrap));
} else {
scheduler::enqueue(stream_, std::move(task));
scheduler::enqueue(
stream_,
[event = std::move(event),
skip_on_error,
task = std::move(task)]() mutable {
if (skip_on_error && event.valid() && event.error()) {
return;
}
try {
task();
} catch (...) {
if (event.valid()) {
event.set_error(std::current_exception());
} else {
throw;
}
}
});
}
}

private:
Stream stream_;
Event event_;
std::vector<array> temporaries_;
int num_ops_{0};
};
Expand Down
27 changes: 27 additions & 0 deletions mlx/backend/cpu/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,33 @@

namespace mlx::core::cpu {

void set_error_event(Stream s, Event event) {
get_command_encoder(s).set_error_event(std::move(event));
}

void clear_error_event(Stream s) {
auto& encoders = get_command_encoders();
auto it = encoders.find(s.index);
if (it != encoders.end()) {
it->second.set_error_event(Event{});
return;
}

auto& global_encoders = get_global_command_encoders();
it = global_encoders.find(s.index);
if (it != global_encoders.end()) {
it->second.set_error_event(Event{});
}
}

void check_error_event(Stream s, Event event) {
get_command_encoder(s).dispatch([event = std::move(event)]() {
if (auto error = event.error()) {
std::rethrow_exception(error);
}
});
}

void new_stream(Stream s) {
auto& encoders = get_command_encoders();
encoders.try_emplace(s.index, s);
Expand Down
3 changes: 3 additions & 0 deletions mlx/backend/cpu/eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace mlx::core::cpu {

void set_error_event(Stream s, Event event);
void clear_error_event(Stream s);
void check_error_event(Stream s, Event event);
void new_stream(Stream s);
void new_thread_unsafe_stream(Stream s);
void eval(array& arr);
Expand Down
83 changes: 72 additions & 11 deletions mlx/backend/cuda/event.cu
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"

#include <map>
#include <mutex>
#include <vector>

#include <nvtx3/nvtx3.hpp>
Expand Down Expand Up @@ -131,7 +132,7 @@ class CopyableCudaEvent {

void wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable {
cpu::get_command_encoder(s).dispatch([*this]() mutable {
check_recorded();
event_->wait();
});
Expand Down Expand Up @@ -265,7 +266,8 @@ void AtomicEvent::wait(cudaStream_t stream, uint32_t value) {
void AtomicEvent::wait(Stream s, uint32_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
cpu::get_command_encoder(s).dispatch(
[*this, value]() mutable { wait(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.commit();
Expand All @@ -292,8 +294,8 @@ void AtomicEvent::signal(Stream s, uint32_t value) {
if (s.device == mlx::core::Device::cpu) {
// Signal through a GPU stream so the atomic is updated in GPU - updating
// the atomic in CPU sometimes does not get GPU notified.
scheduler::enqueue(
s, [*this, value]() mutable { signal(signal_stream(), value); });
cpu::get_command_encoder(s).dispatch_unchecked(
[*this, value]() mutable { signal(signal_stream(), value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.commit();
Expand Down Expand Up @@ -339,6 +341,8 @@ struct EventImpl {
// 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CopyableCudaEvent> cuda;
std::unique_ptr<cu::AtomicEvent> atomic;
std::exception_ptr error;
std::mutex mtx;

bool is_created() const {
return cuda || atomic;
Expand All @@ -356,6 +360,35 @@ struct EventImpl {
cuda = std::make_unique<cu::CopyableCudaEvent>(d);
}
}

void set_error(std::exception_ptr err, uint64_t value) {
{
std::lock_guard<std::mutex> lk(mtx);
if (!error) {
error = std::move(err);
}
}
if (atomic) {
atomic->signal(value);
}
}

void check_error() {
std::exception_ptr err;
{
std::lock_guard<std::mutex> lk(mtx);
err = std::move(error);
error = nullptr;
}
if (err) {
std::rethrow_exception(err);
}
}

std::exception_ptr get_error() {
std::lock_guard<std::mutex> lk(mtx);
return error;
}
};

} // namespace
Expand All @@ -367,24 +400,43 @@ Event::Event(Stream s) : stream_(s) {

void Event::wait() {
auto* event = static_cast<EventImpl*>(event_.get());
event->check_error();
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait();
} else {
event->atomic->wait(value());
}
event->check_error();
CHECK_CUDA_ERROR(cudaPeekAtLastError());
}

void Event::wait(Stream s) {
auto* event = static_cast<EventImpl*>(event_.get());
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait(s);
if (s.device == mlx::core::Device::cpu) {
cpu::get_command_encoder(s).dispatch(
[event_ = event_, value = value()]() mutable {
auto* event = static_cast<EventImpl*>(event_.get());
event->check_error();
assert(event->is_created());
if (event->cuda) {
assert(value == 1);
event->cuda->wait();
} else {
event->atomic->wait(value);
}
event->check_error();
});
} else {
event->atomic->wait(s, value());
auto* event = static_cast<EventImpl*>(event_.get());
event->check_error();
assert(event->is_created());
if (event->cuda) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->atomic->wait(s, value());
}
}
}

Expand All @@ -401,6 +453,7 @@ void Event::signal(Stream s) {

bool Event::is_signaled() const {
auto* event = static_cast<EventImpl*>(event_.get());
event->check_error();
if (!event->is_created()) {
return false;
}
Expand All @@ -412,4 +465,12 @@ bool Event::is_signaled() const {
}
}

void Event::set_error(std::exception_ptr error) {
static_cast<EventImpl*>(event_.get())->set_error(std::move(error), value());
}

std::exception_ptr Event::error() const {
return static_cast<EventImpl*>(event_.get())->get_error();
}

} // namespace mlx::core
Loading