Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion coroio/backends/iocp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ class TIOCp: public TPollerBase {
// Allocator to avoid dynamic memory allocation for each IOCP event structure.
TArenaAllocator<TIO> Allocator_;
std::vector<OVERLAPPED_ENTRY> Entries_;
std::queue<int> Results_;
};

}
1 change: 0 additions & 1 deletion coroio/backends/uring.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ class TUring: public TPollerBase {
int RingFd_; ///< File descriptor for the io_uring.
int EpollFd_; ///< Epoll file descriptor (for integration with epoll).
struct io_uring Ring_; ///< The io_uring structure.
std::queue<int> Results_; ///< Queue of results for completed operations.
std::vector<char> Buffer_; ///< Buffer used for internal I/O operations.
};

Expand Down
16 changes: 16 additions & 0 deletions coroio/corochain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ struct TFutureBase {
Coro.promise().Caller = caller;
}

void detach() {
if (Coro) {
Coro.promise().Caller = std::noop_coroutine();
}
}

using promise_type = TPromise<T>;

protected:
Expand Down Expand Up @@ -363,6 +369,11 @@ TFuture<T> Any(std::vector<TFuture<T>>&& futures) {
f.await_suspend(self);
}
co_await std::suspend_always();
for (auto& f : all) {
if (!f.done()) {
f.detach();
}
}
co_return std::find_if(all.begin(), all.end(), [](auto& f) { return f.done(); })->await_resume();
}

Expand All @@ -384,6 +395,11 @@ inline TFuture<void> Any(std::vector<TFuture<void>>&& futures) {
f.await_suspend(self);
}
co_await std::suspend_always();
for (auto& f : all) {
if (!f.done()) {
f.detach();
}
}
co_return;
}

Expand Down
45 changes: 33 additions & 12 deletions coroio/poller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <chrono>
#include <iostream>
#include <unordered_set>
#include <vector>
#include <map>
#include <queue>
Expand Down Expand Up @@ -120,21 +121,17 @@ class TPollerBase {
* @param fd The file descriptor.
*/
void RemoveEvent(int fd) {
// TODO: resume waiting coroutines here
MaxFd_ = std::max(MaxFd_, fd);
Changes_.emplace_back(TEvent{fd, TEvent::READ|TEvent::WRITE|TEvent::RHUP, {}});
if (!ReadyEvents_.empty()) {
RemovedFdsCurrentLoop_.insert(fd);
}
}
/**
* @brief No-op placeholder for future cleanup by handle.
*
* Intended to be called by destructors of unfinished futures so that
* pending waits can be unregistered. Currently unimplemented.
*
* @param h The coroutine handle (unused).
*/
void RemoveEvent(THandle /*h*/) {
// TODO: Add new vector for this type of removing
// Will be called in destuctor of unfinished futures

void RemoveReadyHandle(THandle h) {
if (!ReadyEvents_.empty()) {
RemovedHandlesCurrentLoop_.insert(h.address());
}
}
/**
* @brief Suspends execution until the specified time.
Expand Down Expand Up @@ -254,8 +251,27 @@ class TPollerBase {
*/
void WakeupReadyHandles() {
for (auto&& ev : ReadyEvents_) {
if (!ev.Handle) {
continue;
}
if (!RemovedFdsCurrentLoop_.empty()) [[unlikely]] {
if (ev.Fd >= 0 && RemovedFdsCurrentLoop_.count(ev.Fd)) {
continue;
}
}
if (!RemovedHandlesCurrentLoop_.empty()) [[unlikely]] {
if (RemovedHandlesCurrentLoop_.count(ev.Handle.address())) {
RemovedHandlesCurrentLoop_.erase(ev.Handle.address());
if (ev.Fd < 0 && !Results_.empty()) {
Results_.pop();
}
continue;
}
}
Wakeup(std::move(ev));
}
RemovedFdsCurrentLoop_.clear();
RemovedHandlesCurrentLoop_.clear();
}
/**
* @brief Sets the maximum polling duration.
Expand Down Expand Up @@ -303,6 +319,8 @@ class TPollerBase {
ReadyEvents_.clear();
Changes_.clear();
MaxFd_ = 0;
RemovedFdsCurrentLoop_.clear();
RemovedHandlesCurrentLoop_.clear();
}
/**
* @brief Processes scheduled timers.
Expand Down Expand Up @@ -333,6 +351,9 @@ class TPollerBase {
int MaxFd_ = 0; ///< Highest file descriptor in use.
std::vector<TEvent> Changes_; ///< Pending changes (registered events).
std::vector<TEvent> ReadyEvents_; ///< Events ready to wake up their coroutines.
std::queue<int> Results_; ///< Results queue for uring/IOCP completions (shared to allow mid-loop discard).
std::unordered_set<int> RemovedFdsCurrentLoop_; ///< Fds cancelled mid-loop (fd-based backends).
std::unordered_set<void*> RemovedHandlesCurrentLoop_; ///< Handles cancelled mid-loop (uring/IOCP).
unsigned TimerId_ = 0; ///< Counter for generating unique timer IDs.
std::priority_queue<TTimer> Timers_; ///< Priority queue for scheduled timers.
TTime LastTimersProcessTime_; ///< Last time timers were processed.
Expand Down
50 changes: 42 additions & 8 deletions coroio/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,12 @@ class TPollerDrivenSocket: public TSocket
struct TAwaitable {
bool await_ready() const { return false; }
void await_suspend(std::coroutine_handle<> h) {
handle_ = h;
poller->Accept(fd, reinterpret_cast<sockaddr*>(&addr[0]), &len, h);
}

TPollerDrivenSocket<T> await_resume() {
handle_ = {};
int clientfd = poller->Result();
if (clientfd < 0) {
throw std::system_error(-clientfd, std::generic_category(), "accept");
Expand All @@ -585,8 +587,13 @@ class TPollerDrivenSocket: public TSocket
return TPollerDrivenSocket<T>{TAddress{reinterpret_cast<sockaddr*>(&addr[0]), len}, clientfd, *poller};
}

~TAwaitable() {
if (handle_) { poller->RemoveReadyHandle(handle_); }
}

T* poller;
int fd;
std::coroutine_handle<> handle_;

char addr[2*(sizeof(sockaddr_in6)+16)] = {0}; // use additional memory for windows
socklen_t len = static_cast<socklen_t>(sizeof(addr));
Expand Down Expand Up @@ -616,13 +623,15 @@ class TPollerDrivenSocket: public TSocket
bool await_ready() const { return false; }

void await_suspend(std::coroutine_handle<> h) {
handle_ = h;
poller->Connect(fd, addr.first, addr.second, h);
if (deadline != TTime::max()) {
timerId = poller->AddTimer(deadline, h);
}
}

void await_resume() {
handle_ = {};
if (deadline != TTime::max() && poller->RemoveTimer(timerId, deadline)) {
poller->Cancel(fd);
throw std::system_error(std::make_error_code(std::errc::timed_out));
Expand All @@ -633,11 +642,16 @@ class TPollerDrivenSocket: public TSocket
}
}

~TAwaitable() {
if (handle_) { poller->RemoveReadyHandle(handle_); }
}

T* poller;
int fd;
std::pair<const sockaddr*, int> addr;
TTime deadline;
unsigned timerId = 0;
std::coroutine_handle<> handle_;
};
return TAwaitable{Poller_, Fd_, RemoteAddr()->RawAddr(), deadline};
}
Expand All @@ -656,10 +670,12 @@ class TPollerDrivenSocket: public TSocket
struct TAwaitable {
bool await_ready() const { return false; }
void await_suspend(std::coroutine_handle<> h) {
handle_ = h;
poller->Recv(fd, buf, size, h);
}

auto await_resume() {
handle_ = {};
auto ret = poller->Result();
if (ret < 0) {
#ifdef _WIN32
Expand All @@ -678,11 +694,15 @@ class TPollerDrivenSocket: public TSocket
return ret;
}

~TAwaitable() {
if (handle_) { poller->RemoveReadyHandle(handle_); }
}

T* poller;
int fd;

void* buf;
size_t size;
std::coroutine_handle<> handle_;
};

return TAwaitable{Poller_, Fd_, buf, size};
Expand All @@ -702,10 +722,12 @@ class TPollerDrivenSocket: public TSocket
struct TAwaitable {
bool await_ready() const { return false; }
void await_suspend(std::coroutine_handle<> h) {
handle_ = h;
poller->Send(fd, buf, size, h);
}

auto await_resume() {
handle_ = {};
auto ret = poller->Result();
if (ret < 0) {
#ifdef _WIN32
Expand All @@ -724,22 +746,24 @@ class TPollerDrivenSocket: public TSocket
return ret;
}

~TAwaitable() {
if (handle_) { poller->RemoveReadyHandle(handle_); }
}

T* poller;
int fd;

const void* buf;
size_t size;
std::coroutine_handle<> handle_;
};

return TAwaitable{Poller_, Fd_, buf, size};
}

/// The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
auto WriteSomeYield(const void* buf, size_t size) {
return WriteSome(buf, size);
}

/// The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
auto ReadSomeYield(void* buf, size_t size) {
return ReadSome(buf, size);
}
Expand Down Expand Up @@ -795,10 +819,12 @@ class TPollerDrivenFileHandle: public TFileHandle
struct TAwaitable {
bool await_ready() const { return false; }
void await_suspend(std::coroutine_handle<> h) {
handle_ = h;
poller->Read(fd, buf, size, h);
}

auto await_resume() {
handle_ = {};
auto ret = poller->Result();
if (ret < 0) {
#ifdef _WIN32
Expand All @@ -817,11 +843,15 @@ class TPollerDrivenFileHandle: public TFileHandle
return ret;
}

~TAwaitable() {
if (handle_) { poller->RemoveReadyHandle(handle_); }
}

T* poller;
int fd;

void* buf;
size_t size;
std::coroutine_handle<> handle_;
};

return TAwaitable{Poller_, Fd_, buf, size};
Expand All @@ -841,10 +871,12 @@ class TPollerDrivenFileHandle: public TFileHandle
struct TAwaitable {
bool await_ready() const { return false; }
void await_suspend(std::coroutine_handle<> h) {
handle_ = h;
poller->Write(fd, buf, size, h);
}

auto await_resume() {
handle_ = {};
auto ret = poller->Result();
if (ret < 0) {
#ifdef _WIN32
Expand All @@ -863,22 +895,24 @@ class TPollerDrivenFileHandle: public TFileHandle
return ret;
}

~TAwaitable() {
if (handle_) { poller->RemoveReadyHandle(handle_); }
}

T* poller;
int fd;

const void* buf;
size_t size;
std::coroutine_handle<> handle_;
};

return TAwaitable{Poller_, Fd_, buf, size};
}

/// The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
auto WriteSomeYield(const void* buf, size_t size) {
return WriteSome(buf, size);
}

/// The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
auto ReadSomeYield(void* buf, size_t size) {
return ReadSome(buf, size);
}
Expand Down
32 changes: 32 additions & 0 deletions tests/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,37 @@ void test_zero_copy_line_splitter_wrap(void**) {
assert_string_equal("cc\n", result.data());
}

template<typename TPoller>
void test_any_simultaneous(void**) {
TLoop<TPoller> loop;

int fds1[2], fds2[2];
assert_int_equal(pipe(fds1), 0);
assert_int_equal(pipe(fds2), 0);
assert_int_equal(write(fds1[1], "x", 1), 1);
assert_int_equal(write(fds2[1], "x", 1), 1);

auto read_one = [](TPoller& poller, int fd) -> TFuture<int> {
typename TPoller::TFileHandle fh(fd, poller);
char buf[1];
co_return co_await fh.ReadSomeYield(buf, 1);
};

TFuture<void> task = [&](TPoller& poller) -> TFuture<void> {
std::vector<TFuture<int>> futures;
futures.push_back(read_one(poller, fds1[0]));
futures.push_back(read_one(poller, fds2[0]));
int r = co_await Any(std::move(futures));
assert_true(r == 1 || r == 2);
}(loop.Poller());

while (!task.done()) {
loop.Step();
}
close(fds1[1]);
close(fds2[1]);
}

void test_self_id(void**) {
void* id;
TFuture<void> h = [](void** id) -> TFuture<void> {
Expand Down Expand Up @@ -1338,6 +1369,7 @@ int main(int argc, char* argv[]) {
ADD_TEST(cmocka_unit_test, test_zero_copy_line_splitter);
ADD_TEST(cmocka_unit_test, test_line_splitter_wrap);
ADD_TEST(cmocka_unit_test, test_zero_copy_line_splitter_wrap);
ADD_TEST(my_unit_poller, test_any_simultaneous);
ADD_TEST(cmocka_unit_test, test_self_id);
ADD_TEST(cmocka_unit_test, test_resolv_nameservers);
ADD_TEST(my_unit_poller, test_listen);
Expand Down
Loading