Skip to content

Commit b0c4fe8

Browse files
committed
Better cancel
Signed-off-by: Adam Gutglick <adam@spiraldb.com>
1 parent 0c8dabd commit b0c4fe8

1 file changed

Lines changed: 89 additions & 9 deletions

File tree

vortex-file/src/segments/source.rs

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use futures::FutureExt;
1212
use futures::StreamExt;
1313
use futures::channel::mpsc;
1414
use futures::future;
15+
use futures::future::BoxFuture;
16+
use futures::future::Shared;
1517
use vortex_array::buffer::BufferHandle;
1618
use vortex_buffer::Alignment;
1719
use vortex_buffer::ByteBuffer;
@@ -35,6 +37,8 @@ use crate::read::IoRequestStream;
3537
use crate::read::ReadRequest;
3638
use crate::read::RequestId;
3739

40+
type IoStreamClosed = Shared<BoxFuture<'static, ()>>;
41+
3842
#[derive(Debug)]
3943
pub enum ReadEvent {
4044
Request(ReadRequest),
@@ -67,6 +71,8 @@ pub struct FileSegmentSource {
6771
segments: Arc<[SegmentSpec]>,
6872
/// A queue for sending read request events to the I/O stream.
6973
events: mpsc::UnboundedSender<ReadEvent>,
74+
/// Resolves when the spawned I/O driver stream is dropped.
75+
io_stream_closed: IoStreamClosed,
7076
/// The next read request ID.
7177
next_id: Arc<AtomicUsize>,
7278
}
@@ -108,7 +114,15 @@ impl FileSegmentSource {
108114
)
109115
.boxed();
110116

117+
let (io_stream_closed_send, io_stream_closed_recv) = oneshot::channel();
118+
let io_stream_closed = async move {
119+
let _ = io_stream_closed_recv.into_future().await;
120+
}
121+
.boxed()
122+
.shared();
123+
111124
let drive_fut = async move {
125+
let _io_stream_closed = IoStreamClosedNotifier::new(io_stream_closed_send);
112126
stream
113127
.map(move |req| {
114128
let reader = reader.clone();
@@ -124,11 +138,28 @@ impl FileSegmentSource {
124138
Self {
125139
segments,
126140
events: send,
141+
io_stream_closed,
127142
next_id: Arc::new(AtomicUsize::new(0)),
128143
}
129144
}
130145
}
131146

147+
struct IoStreamClosedNotifier(Option<oneshot::Sender<()>>);
148+
149+
impl IoStreamClosedNotifier {
150+
fn new(send: oneshot::Sender<()>) -> Self {
151+
Self(Some(send))
152+
}
153+
}
154+
155+
impl Drop for IoStreamClosedNotifier {
156+
fn drop(&mut self) {
157+
if let Some(send) = self.0.take() {
158+
drop(send.send(()));
159+
}
160+
}
161+
}
162+
132163
async fn drive_request<R: VortexReadAt>(reader: R, req: IoRequest) {
133164
if req.is_cancelled() {
134165
tracing::debug!(
@@ -183,6 +214,7 @@ impl SegmentSource for FileSegmentSource {
183214
polled: false,
184215
finished: false,
185216
events: self.events.clone(),
217+
io_stream_closed: self.io_stream_closed.clone(),
186218
};
187219

188220
// One allocation: we only box the returned SegmentFuture, not the inner ReadFuture.
@@ -200,6 +232,7 @@ struct ReadFuture {
200232
polled: bool,
201233
finished: bool,
202234
events: mpsc::UnboundedSender<ReadEvent>,
235+
io_stream_closed: IoStreamClosed,
203236
}
204237

205238
impl Future for ReadFuture {
@@ -212,21 +245,34 @@ impl Future for ReadFuture {
212245
// note: we are skipping polled and dropped events for this if the future
213246
// is ready on the first poll, that means this request was completed
214247
// before it was polled, as part of a coalesced request.
215-
Poll::Ready(
248+
return Poll::Ready(
216249
result.unwrap_or_else(|e| {
217250
Err(vortex_err!("ReadRequest dropped by runtime: {e}"))
218251
}),
219-
)
252+
);
220253
}
221-
Poll::Pending if !self.polled => {
222-
self.polled = true;
223-
// Notify the I/O stream that this request has been polled.
224-
match self.events.unbounded_send(ReadEvent::Polled(self.id)) {
225-
Ok(()) => Poll::Pending,
226-
Err(e) => Poll::Ready(Err(vortex_err!("ReadRequest dropped by runtime: {e}"))),
254+
Poll::Pending => {}
255+
}
256+
257+
if self.io_stream_closed.poll_unpin(cx).is_ready() {
258+
self.finished = true;
259+
return Poll::Ready(Err(vortex_err!(
260+
"ReadRequest dropped by runtime: I/O request stream closed"
261+
)));
262+
}
263+
264+
if !self.polled {
265+
self.polled = true;
266+
// Notify the I/O stream that this request has been polled.
267+
match self.events.unbounded_send(ReadEvent::Polled(self.id)) {
268+
Ok(()) => Poll::Pending,
269+
Err(e) => {
270+
self.finished = true;
271+
Poll::Ready(Err(vortex_err!("ReadRequest dropped by runtime: {e}")))
227272
}
228273
}
229-
_ => Poll::Pending,
274+
} else {
275+
Poll::Pending
230276
}
231277
}
232278
}
@@ -319,6 +365,16 @@ mod tests {
319365
use super::*;
320366
use crate::read::CoalescedRequest;
321367

368+
fn io_stream_closed_pair() -> (IoStreamClosedNotifier, IoStreamClosed) {
369+
let (send, recv) = oneshot::channel();
370+
let closed = async move {
371+
let _ = recv.into_future().await;
372+
}
373+
.boxed()
374+
.shared();
375+
(IoStreamClosedNotifier::new(send), closed)
376+
}
377+
322378
#[derive(Clone, Default)]
323379
struct CountingRead {
324380
read_count: Arc<AtomicUsize>,
@@ -436,4 +492,28 @@ mod tests {
436492
assert_eq!(reader.read_count.load(Ordering::Relaxed), 1);
437493
Ok(())
438494
}
495+
496+
#[tokio::test]
497+
async fn read_future_finishes_when_io_stream_closes_after_poll() {
498+
let (_callback, recv) = oneshot::channel();
499+
let (events, _event_recv) = mpsc::unbounded();
500+
let (notifier, io_stream_closed) = io_stream_closed_pair();
501+
502+
let read = ReadFuture {
503+
id: 0,
504+
recv: recv.into_future(),
505+
polled: true,
506+
finished: false,
507+
events,
508+
io_stream_closed,
509+
};
510+
511+
drop(notifier);
512+
513+
let err = read.await.unwrap_err();
514+
assert!(
515+
err.to_string().contains("I/O request stream closed"),
516+
"unexpected error: {err}"
517+
);
518+
}
439519
}

0 commit comments

Comments
 (0)