Skip to content

Commit 92a68b8

Browse files
committed
download: replace callback with direct handling
1 parent 8ad6cec commit 92a68b8

2 files changed

Lines changed: 32 additions & 124 deletions

File tree

src/download/mod.rs

Lines changed: 27 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use reqwest::{Client, ClientBuilder, Proxy, header};
1818
use rustls::crypto::aws_lc_rs;
1919
#[cfg(feature = "reqwest-rustls-tls")]
2020
use rustls_platform_verifier::Verifier;
21-
use sha2::{Digest, Sha256};
21+
use sha2::Sha256;
2222
use thiserror::Error;
2323
use tokio_stream::StreamExt;
2424
use tracing::{debug, warn};
@@ -163,35 +163,9 @@ impl<'a> Download<'a> {
163163
async fn download_file_(&self) -> anyhow::Result<()> {
164164
debug!(url = %self.url, "downloading file");
165165

166-
// This callback will write the download to disk and optionally
167-
// hash the contents, then forward the notification up the stack
168-
let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| {
169-
if let Event::DownloadDataReceived(data) = msg
170-
&& let Some(h) = self.hasher.as_ref()
171-
{
172-
h.borrow_mut().update(data);
173-
}
174-
175-
match msg {
176-
Event::DownloadContentLengthReceived(len) => {
177-
if let Some(status) = self.status {
178-
status.received_length(len)
179-
}
180-
}
181-
Event::DownloadDataReceived(data) => {
182-
if let Some(status) = self.status {
183-
status.received_data(data.len())
184-
}
185-
}
186-
Event::ResumingPartialDownload => debug!("resuming partial download"),
187-
}
188-
189-
Ok(())
190-
};
191-
192166
// Download the file
193167

194-
let res = self.download_to_path(Some(callback)).await;
168+
let res = self.download_to_path().await;
195169

196170
// The notification should only be sent if the download was successful (i.e. didn't timeout)
197171
if let Some(status) = self.status {
@@ -204,8 +178,8 @@ impl<'a> Download<'a> {
204178
res
205179
}
206180

207-
async fn download_to_path(&self, callback: Option<DownloadCallback<'_>>) -> anyhow::Result<()> {
208-
let Err(err) = self.download_impl(callback).await else {
181+
async fn download_to_path(&self) -> anyhow::Result<()> {
182+
let Err(err) = self.download_impl().await else {
209183
return Ok(());
210184
};
211185

@@ -224,15 +198,14 @@ impl<'a> Download<'a> {
224198
)
225199
}
226200

227-
async fn download_impl(&self, callback: Option<DownloadCallback<'_>>) -> anyhow::Result<()> {
228-
let (file, resume_from) = if self.resume {
201+
async fn download_impl(&self) -> anyhow::Result<()> {
202+
let (mut file, resume_from) = if self.resume {
229203
// TODO: blocking call
230204
let possible_partial = OpenOptions::new().read(true).open(self.path);
231205

232206
let downloaded_so_far = if let Ok(mut partial) = possible_partial {
233-
if let Some(cb) = callback {
234-
cb(Event::ResumingPartialDownload)?;
235-
207+
debug!("resuming partial download");
208+
if let Some(status) = self.status {
236209
let mut buf = vec![0; 32768];
237210
let mut downloaded_so_far = 0;
238211
loop {
@@ -241,7 +214,7 @@ impl<'a> Download<'a> {
241214
if n == 0 {
242215
break;
243216
}
244-
cb(Event::DownloadDataReceived(&buf[..n]))?;
217+
status.received_data(n);
245218
}
246219

247220
downloaded_so_far
@@ -276,7 +249,6 @@ impl<'a> Download<'a> {
276249
)
277250
};
278251

279-
let file = RefCell::new(file);
280252
let client = match self.options.tls {
281253
#[cfg(feature = "reqwest-rustls-tls")]
282254
Tls::Rustls => rustls_client(self.options.timeout)?,
@@ -285,25 +257,9 @@ impl<'a> Download<'a> {
285257
};
286258

287259
// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
288-
self.execute(
289-
resume_from,
290-
&|event| {
291-
if let Event::DownloadDataReceived(data) = event {
292-
file.borrow_mut()
293-
.write_all(data)
294-
.context("unable to write download to disk")?;
295-
}
296-
match callback {
297-
Some(cb) => cb(event),
298-
None => Ok(()),
299-
}
300-
},
301-
client,
302-
)
303-
.await?;
260+
self.execute(resume_from, &mut file, client).await?;
304261

305-
file.borrow_mut()
306-
.sync_data()
262+
file.sync_data()
307263
.context("unable to sync download to disk")?;
308264

309265
Ok::<(), anyhow::Error>(())
@@ -312,7 +268,7 @@ impl<'a> Download<'a> {
312268
async fn execute(
313269
&self,
314270
resume_from: u64,
315-
callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>,
271+
file: &mut fs::File,
316272
client: &Client,
317273
) -> anyhow::Result<()> {
318274
// Short-circuit reqwest for the "file:" URL scheme
@@ -339,7 +295,13 @@ impl<'a> Download<'a> {
339295
if bytes_read == 0 {
340296
break;
341297
}
342-
callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?;
298+
299+
file.write_all(&buffer[..bytes_read])
300+
.context("unable to write download to disk")?;
301+
302+
if let Some(status) = self.status {
303+
status.received_data(bytes_read);
304+
}
343305
}
344306

345307
return Ok(());
@@ -362,13 +324,19 @@ impl<'a> Download<'a> {
362324

363325
if let Some(len) = res.content_length() {
364326
let len = len + resume_from;
365-
callback(Event::DownloadContentLengthReceived(len))?;
327+
if let Some(status) = self.status {
328+
status.received_length(len);
329+
}
366330
}
367331

368332
let mut stream = res.bytes_stream();
369333
while let Some(item) = stream.next().await {
370334
let bytes = item.map_err(DownloadError::Reqwest)?;
371-
callback(Event::DownloadDataReceived(&bytes))?;
335+
file.write_all(&bytes)
336+
.context("unable to write download to disk")?;
337+
if let Some(status) = self.status {
338+
status.received_data(bytes.len());
339+
}
372340
}
373341
Ok(())
374342
}
@@ -403,17 +371,6 @@ enum Tls {
403371
NativeTls,
404372
}
405373

406-
#[derive(Debug, Copy, Clone)]
407-
enum Event<'a> {
408-
ResumingPartialDownload,
409-
/// Received the Content-Length of the to-be downloaded data.
410-
DownloadContentLengthReceived(u64),
411-
/// Received some data.
412-
DownloadDataReceived(&'a [u8]),
413-
}
414-
415-
type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>;
416-
417374
fn client_generic() -> ClientBuilder {
418375
Client::builder()
419376
// HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying

src/download/tests.rs

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ mod reqwest {
2020
use std::env::set_var;
2121
use std::error::Error;
2222
use std::net::TcpListener;
23-
use std::sync::Mutex;
24-
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
23+
use std::sync::atomic::{AtomicUsize, Ordering};
2524
use std::thread;
2625
use std::time::Duration;
2726

@@ -30,7 +29,7 @@ mod reqwest {
3029
use url::Url;
3130

3231
use super::{scrub_env, serve_file, tmp_dir, write_file};
33-
use crate::download::{DownloadOptions, Event, Tls};
32+
use crate::download::{DownloadOptions, Tls};
3433

3534
const OPTIONS: DownloadOptions = DownloadOptions {
3635
tls: DOWNLOAD_BACKEND,
@@ -118,61 +117,13 @@ mod reqwest {
118117
OPTIONS
119118
.start(&from_url, &target_path)
120119
.with_resume()
121-
.download_to_path(None)
120+
.download_to_path()
122121
.await
123122
.expect("Test download failed");
124123

125124
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
126125
}
127126

128-
#[tokio::test]
129-
async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
130-
let _guard = scrub_env().await;
131-
let tmpdir = tmp_dir();
132-
let target_path = tmpdir.path().join("downloaded");
133-
write_file(&target_path, "123");
134-
135-
let addr = serve_file(b"xxx45".to_vec(), true);
136-
137-
let from_url = format!("http://{addr}").parse().unwrap();
138-
139-
let callback_partial = AtomicBool::new(false);
140-
let callback_len = Mutex::new(None);
141-
let received_in_callback = Mutex::new(Vec::new());
142-
143-
OPTIONS
144-
.start(&from_url, &target_path)
145-
.with_resume()
146-
.download_to_path(Some(&|msg| {
147-
match msg {
148-
Event::ResumingPartialDownload => {
149-
assert!(!callback_partial.load(Ordering::SeqCst));
150-
callback_partial.store(true, Ordering::SeqCst);
151-
}
152-
Event::DownloadContentLengthReceived(len) => {
153-
let mut flag = callback_len.lock().unwrap();
154-
assert!(flag.is_none());
155-
*flag = Some(len);
156-
}
157-
Event::DownloadDataReceived(data) => {
158-
for b in data.iter() {
159-
received_in_callback.lock().unwrap().push(*b);
160-
}
161-
}
162-
}
163-
164-
Ok(())
165-
}))
166-
.await
167-
.expect("Test download failed");
168-
169-
assert!(callback_partial.into_inner());
170-
assert_eq!(*callback_len.lock().unwrap(), Some(5));
171-
let observed_bytes = received_in_callback.into_inner().unwrap();
172-
assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']);
173-
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
174-
}
175-
176127
#[tokio::test]
177128
async fn resume_partial_fails_if_server_ignores_range() {
178129
let _guard = scrub_env().await;
@@ -186,7 +137,7 @@ mod reqwest {
186137
OPTIONS
187138
.start(&from_url, &target_path)
188139
.with_resume()
189-
.download_to_path(None)
140+
.download_to_path()
190141
.await
191142
.expect_err("download should fail if server ignores range");
192143

@@ -210,7 +161,7 @@ mod reqwest {
210161
}
211162
.start(&from_url, &target_path)
212163
.with_resume()
213-
.download_to_path(None)
164+
.download_to_path()
214165
.await
215166
.expect_err("download should fail with a connect error");
216167

0 commit comments

Comments
 (0)