diff --git a/src/cli/self_update.rs b/src/cli/self_update.rs index d23d82bd62..137b5e71e2 100644 --- a/src/cli/self_update.rs +++ b/src/cli/self_update.rs @@ -63,7 +63,7 @@ use crate::{ DistOptions, PartialToolchainDesc, Profile, TargetTuple, ToolchainDesc, download::DownloadCfg, }, - download::download_file, + download::DownloadOptions, errors::RustupError, install::{InstallMethod, UpdateStatus}, process::Process, @@ -1340,7 +1340,10 @@ pub(crate) async fn prepare_update(dl_cfg: &DownloadCfg<'_>) -> Result) -> Result(&release_toml_str) .context("unable to parse rustup release file")?; diff --git a/src/cli/self_update/windows.rs b/src/cli/self_update/windows.rs index 5f3eb5e0bd..07578e3ec1 100644 --- a/src/cli/self_update/windows.rs +++ b/src/cli/self_update/windows.rs @@ -22,8 +22,7 @@ use super::{InstallOpts, install_bins, report_error}; use crate::cli::markdown::md; use crate::config::Cfg; use crate::dist::TargetTuple; -use crate::dist::download::DownloadCfg; -use crate::download::download_file; +use crate::download::DownloadOptions; use crate::process::{ColorableTerminal, Process}; use crate::utils; @@ -271,16 +270,11 @@ pub(crate) async fn try_install_msvc( .context("error creating temp directory")?; let visual_studio = tempdir.path().join("vs_setup.exe"); - let dl_cfg = DownloadCfg::new(cfg); info!("downloading Visual Studio installer"); - download_file( - &visual_studio_url, - &visual_studio, - None, - None, - dl_cfg.process, - ) - .await?; + DownloadOptions::try_from(cfg.process)? + .start(&visual_studio_url, &visual_studio) + .download() + .await?; // Run the installer. Arguments are documented at: // https://docs.microsoft.com/en-us/visualstudio/install/use-command-line-parameters-to-install-visual-studio diff --git a/src/dist/download.rs b/src/dist/download.rs index 39283f7f38..f1d25f10e8 100644 --- a/src/dist/download.rs +++ b/src/dist/download.rs @@ -15,7 +15,7 @@ use url::Url; use crate::config::Cfg; use crate::dist::manifest::{Manifest, ManifestWithHash}; use crate::dist::{Channel, DEFAULT_DIST_SERVER, ToolchainDesc, temp}; -use crate::download::{download_file, download_file_with_resume, is_network_failure}; +use crate::download::{DownloadOptions, is_network_failure}; use crate::errors::RustupError; use crate::process::Process; use crate::utils; @@ -82,17 +82,13 @@ impl<'a> DownloadCfg<'a> { let partial_file_existed = partial_file_path.exists(); let mut hasher = Sha256::new(); + let download = DownloadOptions::try_from(self.process)? + .start(url, &partial_file_path) + .with_hasher(&mut hasher) + .with_status(status) + .with_resume(); - if let Err(e) = download_file_with_resume( - url, - &partial_file_path, - Some(&mut hasher), - true, - Some(status), - self.process, - ) - .await - { + if let Err(e) = download.download().await { let is_network_failure = is_network_failure(&e); let err = Err(e); return match (partial_file_existed, is_network_failure) { @@ -142,9 +138,10 @@ impl<'a> DownloadCfg<'a> { async fn download_hash(&self, url: &str) -> Result { let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?; let hash_file = self.tmp_cx.new_file()?; - - download_file(&hash_url, &hash_file, None, None, self.process).await?; - + DownloadOptions::try_from(self.process)? + .start(&hash_url, &hash_file) + .download() + .await?; utils::read_file("hash", &hash_file).map(|s| s[0..64].to_owned()) } @@ -267,7 +264,16 @@ impl<'a> DownloadCfg<'a> { let file = self.tmp_cx.new_file_with_ext("", ext)?; let mut hasher = Sha256::new(); - download_file(&url, &file, Some(&mut hasher), status, self.process).await?; + let download = DownloadOptions::try_from(self.process)? + .start(&url, &file) + .with_hasher(&mut hasher); + + let download = match status { + Some(status) => download.with_status(status), + None => download, + }; + + download.download().await?; let actual_hash = faster_hex::hex_string(&hasher.finalize()); if hash != actual_hash { diff --git a/src/dist/manifestation/tests.rs b/src/dist/manifestation/tests.rs index 00691c29c1..232e8eb466 100644 --- a/src/dist/manifestation/tests.rs +++ b/src/dist/manifestation/tests.rs @@ -22,7 +22,7 @@ use crate::{ prefix::InstallPrefix, temp, }, - download::download_file, + download::DownloadOptions, errors::RustupError, process::TestProcess, test::{ @@ -490,7 +490,10 @@ impl TestContext { // Download the dist manifest and place it into the installation prefix let manifest_url = make_manifest_url(&self.url, &self.toolchain)?; let manifest_file = self.tmp_cx.new_file()?; - download_file(&manifest_url, &manifest_file, None, None, dl_cfg.process).await?; + DownloadOptions::try_from(dl_cfg.process)? + .start(&manifest_url, &manifest_file) + .download() + .await?; let manifest_str = utils::read_file("manifest", &manifest_file)?; let manifest = Manifest::parse(&manifest_str)?; diff --git a/src/download/mod.rs b/src/download/mod.rs index d930d0f998..a8060c665e 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -1,226 +1,185 @@ //! Easy file downloading -use std::fs::remove_file; +use std::cell::RefCell; +use std::fs::{self, OpenOptions, remove_file}; +use std::io::{self, Read, Seek, SeekFrom, Write}; use std::num::NonZero; use std::path::Path; use std::str::FromStr; +#[cfg(feature = "reqwest-rustls-tls")] +use std::sync::Arc; +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +use std::sync::OnceLock; use std::time::Duration; -use anyhow::Context; -#[cfg(any( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") -))] -use anyhow::anyhow; +use anyhow::{Context, anyhow}; +use reqwest::{Client, ClientBuilder, Proxy, header}; +#[cfg(feature = "reqwest-rustls-tls")] +use rustls::crypto::aws_lc_rs; +#[cfg(feature = "reqwest-rustls-tls")] +use rustls_platform_verifier::Verifier; use sha2::Sha256; use thiserror::Error; -use tracing::debug; -use tracing::warn; +use tokio_stream::StreamExt; +use tracing::{debug, warn}; use url::Url; +#[cfg(all(feature = "reqwest-rustls-tls", not(target_os = "android")))] +use crate::anchors::RUSTUP_TRUST_ANCHORS; use crate::{dist::download::DownloadStatus, errors::RustupError, process::Process}; #[cfg(test)] mod tests; -pub(crate) async fn download_file( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - download_file_with_resume(url, path, hasher, false, status, process).await +#[derive(Debug, Clone, Copy)] +pub struct DownloadOptions { + tls: Tls, + timeout: Duration, } -pub(crate) async fn download_file_with_resume( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - match download_file_(url, path, hasher, resume_from_partial, status, process).await { - Ok(_) => Ok(()), - Err(e) => { - if e.downcast_ref::().is_some() { - return Err(e); - } - let is_client_error = match e.downcast_ref::() { - // Specifically treat the bad partial range error as not our - // fault in case it was something odd which happened. - Some(DownloadError::HttpStatus(416)) => false, - Some(DownloadError::HttpStatus(400..=499)) | Some(DownloadError::FileNotFound) => { - true - } - _ => false, - }; - Err(e).with_context(|| { - if is_client_error { - RustupError::DownloadNotExists { - url: url.clone(), - path: path.to_path_buf(), - } - } else { - RustupError::DownloadingFile { - url: url.clone(), - path: path.to_path_buf(), - } - } - }) +impl DownloadOptions { + pub fn start<'a>(&self, url: &'a Url, path: &'a Path) -> Download<'a> { + Download { + url, + path, + hasher: None, + status: None, + resume: false, + options: *self, } } } -pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { - match err.downcast_ref::() { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - Some(DownloadError::Reqwest(e)) => e.is_timeout() || e.is_connect(), - _ => false, - } -} +impl TryFrom<&Process> for DownloadOptions { + type Error = anyhow::Error; -async fn download_file_( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use crate::download::{Backend, Event}; - use sha2::Digest; - use std::cell::RefCell; - - debug!(url = %url, "downloading file"); - let hasher = RefCell::new(hasher); - - // This callback will write the download to disk and optionally - // hash the contents, then forward the notification up the stack - let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { - if let Event::DownloadDataReceived(data) = msg - && let Some(h) = hasher.borrow_mut().as_mut() - { - h.update(data); + fn try_from(process: &Process) -> Result { + let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); + if use_rustls == Some(false) { + warn!( + "RUSTUP_USE_RUSTLS is set to `0`; the native-tls backend is deprecated, + please file an issue if the default download backend does not work for your use case" + ); } - match msg { - Event::DownloadContentLengthReceived(len) => { - if let Some(status) = status { - status.received_length(len) - } + let tls = match use_rustls { + // If the environment explicitly selects a TLS backend that's unavailable, error out. + #[cfg(not(feature = "reqwest-rustls-tls"))] + Some(true) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" + )); } - Event::DownloadDataReceived(data) => { - if let Some(status) = status { - status.received_data(data.len()) - } + #[cfg(not(feature = "reqwest-native-tls"))] + Some(false) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" + )); } - Event::ResumingPartialDownload => debug!("resuming partial download"), - } - Ok(()) - }; + // Prefer explicit selections before falling back to the default TLS stack. + #[cfg(feature = "reqwest-native-tls")] + Some(false) => Tls::NativeTls, - // Download the file + // The default fallback is `rustls`, which should be used whenever available. + #[cfg(feature = "reqwest-rustls-tls")] + _ => Tls::Rustls, - let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); - if use_rustls == Some(false) { - warn!( - "RUSTUP_USE_RUSTLS is set to `0`; the native-tls backend is deprecated, - please file an issue if the default download backend does not work for your use case" - ); - } + // The `rustls` feature is disabled, fall back to `native-tls` instead. + #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] + _ => Tls::NativeTls, + }; - let backend = match use_rustls { - // If the environment explicitly selects a TLS backend that's unavailable, error out. - #[cfg(not(feature = "reqwest-rustls-tls"))] - Some(true) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" - )); - } - #[cfg(not(feature = "reqwest-native-tls"))] - Some(false) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" - )); - } + let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") { + Ok(s) => NonZero::from_str(&s) + .context( + "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", + )? + .get(), + Err(_) => 180, + }); - // Prefer explicit selections before falling back to the default TLS stack. - #[cfg(feature = "reqwest-native-tls")] - Some(false) => Backend::NativeTls, - - // The default fallback is `rustls`, which should be used whenever available. - #[cfg(feature = "reqwest-rustls-tls")] - _ => Backend::Rustls, - - // The `rustls` feature is disabled, fall back to `native-tls` instead. - #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - _ => Backend::NativeTls, - }; - - let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") { - Ok(s) => NonZero::from_str(&s) - .context( - "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", - )? - .get(), - Err(_) => 180, - }); - - debug!("downloading with reqwest"); - - let res = backend - .download_to_path(url, path, resume_from_partial, Some(callback), timeout) - .await; - - // The notification should only be sent if the download was successful (i.e. didn't timeout) - if let Some(status) = status { - match &res { - Ok(_) => status.finished(), - Err(_) => status.failed(), - }; + Ok(Self { tls, timeout }) } +} - res +pub struct Download<'a> { + url: &'a Url, + path: &'a Path, + hasher: Option>, + status: Option<&'a DownloadStatus>, + resume: bool, + options: DownloadOptions, } -/// User agent header value for HTTP request. -/// See: https://github.com/rust-lang/rustup/issues/2860. -#[cfg(feature = "reqwest-native-tls")] -const REQWEST_DEFAULT_TLS_USER_AGENT: &str = concat!( - "rustup/", - env!("CARGO_PKG_VERSION"), - " (reqwest; default-tls)" -); +impl<'a> Download<'a> { + pub(crate) fn with_hasher(mut self, hasher: &'a mut Sha256) -> Self { + self.hasher = Some(RefCell::new(hasher)); + self + } -#[cfg(feature = "reqwest-rustls-tls")] -const REQWEST_RUSTLS_TLS_USER_AGENT: &str = - concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)"); + pub(crate) fn with_status(mut self, status: &'a DownloadStatus) -> Self { + self.status = Some(status); + self + } -#[derive(Debug, Copy, Clone)] -enum Backend { - #[cfg(feature = "reqwest-rustls-tls")] - Rustls, - #[cfg(feature = "reqwest-native-tls")] - NativeTls, -} + pub(crate) fn with_resume(mut self) -> Self { + self.resume = true; + self + } -impl Backend { - async fn download_to_path( - self, - url: &Url, - path: &Path, - resume_from_partial: bool, - callback: Option>, - timeout: Duration, - ) -> anyhow::Result<()> { - let Err(err) = self - .download_impl(url, path, resume_from_partial, callback, timeout) - .await - else { + pub(crate) async fn download(&self) -> anyhow::Result<()> { + match self.download_file_().await { + Ok(_) => Ok(()), + Err(e) => { + if e.downcast_ref::().is_some() { + return Err(e); + } + let is_client_error = match e.downcast_ref::() { + // Specifically treat the bad partial range error as not our + // fault in case it was something odd which happened. + Some(DownloadError::HttpStatus(416)) => false, + Some(DownloadError::HttpStatus(400..=499)) + | Some(DownloadError::FileNotFound) => true, + _ => false, + }; + Err(e).with_context(|| { + if is_client_error { + RustupError::DownloadNotExists { + url: self.url.clone(), + path: self.path.to_path_buf(), + } + } else { + RustupError::DownloadingFile { + url: self.url.clone(), + path: self.path.to_path_buf(), + } + } + }) + } + } + } + + async fn download_file_(&self) -> anyhow::Result<()> { + debug!(url = %self.url, "downloading file"); + + // Download the file + + let res = self.download_to_path().await; + + // The notification should only be sent if the download was successful (i.e. didn't timeout) + if let Some(status) = self.status { + match &res { + Ok(_) => status.finished(), + Err(_) => status.failed(), + }; + } + + res + } + + async fn download_to_path(&self) -> anyhow::Result<()> { + let Err(err) = self.download_impl().await else { return Ok(()); }; @@ -228,8 +187,9 @@ impl Backend { // if there was a network failure from the client side. // It may be worth looking for other cases where removal is also not desired. Err( - if !(resume_from_partial && is_network_failure(&err)) - && let Err(file_err) = remove_file(path).context("cleaning up cached downloads") + if !(self.resume && is_network_failure(&err)) + && let Err(file_err) = + remove_file(self.path).context("cleaning up cached downloads") { file_err.context(err) } else { @@ -238,26 +198,14 @@ impl Backend { ) } - async fn download_impl( - self, - url: &Url, - path: &Path, - resume_from_partial: bool, - callback: Option>, - timeout: Duration, - ) -> anyhow::Result<()> { - use std::cell::RefCell; - use std::fs::OpenOptions; - use std::io::{Read, Seek, SeekFrom, Write}; - - let (file, resume_from) = if resume_from_partial { + async fn download_impl(&self) -> anyhow::Result<()> { + let (mut file, resume_from) = if self.resume { // TODO: blocking call - let possible_partial = OpenOptions::new().read(true).open(path); + let possible_partial = OpenOptions::new().read(true).open(self.path); let downloaded_so_far = if let Ok(mut partial) = possible_partial { - if let Some(cb) = callback { - cb(Event::ResumingPartialDownload)?; - + debug!("resuming partial download"); + if let Some(status) = self.status { let mut buf = vec![0; 32768]; let mut downloaded_so_far = 0; loop { @@ -266,7 +214,7 @@ impl Backend { if n == 0 { break; } - cb(Event::DownloadDataReceived(&buf[..n]))?; + status.received_data(n); } downloaded_so_far @@ -283,7 +231,7 @@ impl Backend { .write(true) .create(true) .truncate(false) - .open(path) + .open(self.path) .context("error opening file for download")?; possible_partial.seek(SeekFrom::End(0))?; @@ -295,106 +243,73 @@ impl Backend { .write(true) .create(true) .truncate(true) - .open(path) + .open(self.path) .context("error creating file for download")?, 0, ) }; - let file = RefCell::new(file); - - // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. - self.download(url, resume_from, timeout, &|event| { - if let Event::DownloadDataReceived(data) = event { - file.borrow_mut() - .write_all(data) - .context("unable to write download to disk")?; - } - match callback { - Some(cb) => cb(event), - None => Ok(()), - } - }) - .await?; + let client = match self.options.tls { + #[cfg(feature = "reqwest-rustls-tls")] + Tls::Rustls => rustls_client(self.options.timeout)?, + #[cfg(feature = "reqwest-native-tls")] + Tls::NativeTls => native_tls_client(self.options.timeout)?, + }; - file.borrow_mut() - .sync_data() + self.execute(resume_from, &mut file, client).await?; + file.sync_data() .context("unable to sync download to disk")?; Ok::<(), anyhow::Error>(()) } - #[cfg_attr( - all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") - ), - allow(unused_variables) - )] - async fn download( - self, - url: &Url, + async fn execute( + &self, resume_from: u64, - timeout: Duration, - callback: DownloadCallback<'_>, + file: &mut fs::File, + client: &Client, ) -> anyhow::Result<()> { - let client = match self { - #[cfg(feature = "reqwest-rustls-tls")] - Self::Rustls => reqwest_be::rustls_client(timeout)?, - #[cfg(feature = "reqwest-native-tls")] - Self::NativeTls => reqwest_be::native_tls_client(timeout)?, - }; - - reqwest_be::download(url, resume_from, callback, client).await - } -} - -#[derive(Debug, Copy, Clone)] -enum Event<'a> { - ResumingPartialDownload, - /// Received the Content-Length of the to-be downloaded data. - DownloadContentLengthReceived(u64), - /// Received some data. - DownloadDataReceived(&'a [u8]), -} + // Short-circuit reqwest for the "file:" URL scheme + // The file scheme is mostly for use by tests to mock the dist server + let url = self.url; + if url.scheme() == "file" { + let src = url + .to_file_path() + .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; + if !src.is_file() { + // Because some of rustup's logic depends on checking + // the error when a downloaded file doesn't exist, make + // the file case return the same error value as the + // network case. + return Err(anyhow!(DownloadError::FileNotFound)); + } -type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>; + let mut f = fs::File::open(src).context("unable to open downloaded file")?; + Seek::seek(&mut f, SeekFrom::Start(resume_from))?; -#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] -mod reqwest_be { - #[cfg(feature = "reqwest-rustls-tls")] - use std::sync::Arc; - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use std::sync::OnceLock; - use std::{io, time::Duration}; + let mut buffer = vec![0u8; 0x10000]; + loop { + let bytes_read = Read::read(&mut f, &mut buffer)?; + if bytes_read == 0 { + break; + } - #[cfg(all(feature = "reqwest-rustls-tls", not(target_os = "android")))] - use crate::anchors::RUSTUP_TRUST_ANCHORS; - use anyhow::{Context, anyhow}; - use reqwest::{Client, ClientBuilder, Proxy, Response, header}; - #[cfg(feature = "reqwest-rustls-tls")] - use rustls::crypto::aws_lc_rs; - #[cfg(feature = "reqwest-rustls-tls")] - use rustls_platform_verifier::Verifier; - use tokio_stream::StreamExt; - use url::Url; + file.write_all(&buffer[..bytes_read]) + .context("unable to write download to disk")?; - use super::{DownloadError, Event}; + if let Some(status) = self.status { + status.received_data(bytes_read); + } + } - pub(super) async fn download( - url: &Url, - resume_from: u64, - callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, - client: &Client, - ) -> anyhow::Result<()> { - // Short-circuit reqwest for the "file:" URL scheme - if download_from_file_url(url, resume_from, callback)? { return Ok(()); } - let res = request(url, resume_from, client) - .await - .context("error downloading file")?; + let mut req = client.get(url.as_str()); + if resume_from != 0 { + req = req.header(header::RANGE, format!("bytes={resume_from}-")); + } + let res = req.send().await.context("error downloading file")?; // If a download is being resumed, we expect a 206 response; // otherwise, if the server ignored the range header, @@ -407,147 +322,126 @@ mod reqwest_be { if let Some(len) = res.content_length() { let len = len + resume_from; - callback(Event::DownloadContentLengthReceived(len))?; + if let Some(status) = self.status { + status.received_length(len); + } } let mut stream = res.bytes_stream(); while let Some(item) = stream.next().await { let bytes = item.map_err(DownloadError::Reqwest)?; - callback(Event::DownloadDataReceived(&bytes))?; + file.write_all(&bytes) + .context("unable to write download to disk")?; + if let Some(status) = self.status { + status.received_data(bytes.len()); + } } Ok(()) } +} - fn client_generic() -> ClientBuilder { - Client::builder() - // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying - // `hyper` library that causes the `reqwest` client to hang in some cases. - // See for more details. - .pool_max_idle_per_host(0) - .gzip(false) - .proxy(Proxy::custom(env_proxy)) +pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { + match err.downcast_ref::() { + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + Some(DownloadError::Reqwest(e)) => e.is_timeout() || e.is_connect(), + _ => false, } +} - #[cfg(feature = "reqwest-rustls-tls")] - pub(super) fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { - // If the client is already initialized, the passed timeout is ignored. - if let Some(client) = CLIENT_RUSTLS_TLS.get() { - return Ok(client); - } +/// User agent header value for HTTP request. +/// See: https://github.com/rust-lang/rustup/issues/2860. +#[cfg(feature = "reqwest-native-tls")] +const REQWEST_DEFAULT_TLS_USER_AGENT: &str = concat!( + "rustup/", + env!("CARGO_PKG_VERSION"), + " (reqwest; default-tls)" +); - let provider = Arc::new(aws_lc_rs::default_provider()); - #[cfg(not(target_os = "android"))] - let result = - Verifier::new_with_extra_roots(RUSTUP_TRUST_ANCHORS.iter().cloned(), provider.clone()); - #[cfg(target_os = "android")] - let result = Verifier::new(provider.clone()); - let verifier = result.map_err(|err| { - DownloadError::Message(format!("failed to initialize platform verifier: {err}")) - })?; - - let mut tls_config = rustls::ClientConfig::builder_with_provider(provider) - .with_safe_default_protocol_versions() - .unwrap() - .dangerous() // We're using a rustls verifier, so it's okay - .with_custom_certificate_verifier(Arc::new(verifier)) - .with_no_client_auth(); - tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - let client = client_generic() - .read_timeout(timeout) - .use_preconfigured_tls(tls_config) - .user_agent(super::REQWEST_RUSTLS_TLS_USER_AGENT) - .build() - .map_err(DownloadError::Reqwest)?; - - let _ = CLIENT_RUSTLS_TLS.set(client); - // "The cell is guaranteed to contain a value when `set` returns, though not necessarily - // the one provided." - Ok(CLIENT_RUSTLS_TLS.get().unwrap()) - } +#[cfg(feature = "reqwest-rustls-tls")] +const REQWEST_RUSTLS_TLS_USER_AGENT: &str = + concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)"); +#[derive(Debug, Copy, Clone)] +enum Tls { #[cfg(feature = "reqwest-rustls-tls")] - static CLIENT_RUSTLS_TLS: OnceLock = OnceLock::new(); - + Rustls, #[cfg(feature = "reqwest-native-tls")] - pub(super) fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { - // If the client is already initialized, the passed timeout is ignored. - if let Some(client) = CLIENT_NATIVE_TLS.get() { - return Ok(client); - } - - let client = client_generic() - .read_timeout(timeout) - .user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT) - .build() - .map_err(DownloadError::Reqwest)?; - - let _ = CLIENT_NATIVE_TLS.set(client); - // "The cell is guaranteed to contain a value when `set` returns, though not necessarily - // the one provided." - Ok(CLIENT_NATIVE_TLS.get().unwrap()) - } + NativeTls, +} - #[cfg(feature = "reqwest-native-tls")] - static CLIENT_NATIVE_TLS: OnceLock = OnceLock::new(); +fn client_generic() -> ClientBuilder { + Client::builder() + // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying + // `hyper` library that causes the `reqwest` client to hang in some cases. + // See for more details. + .pool_max_idle_per_host(0) + .gzip(false) + .proxy(Proxy::custom(|url| env_proxy::for_url(url).to_url())) +} - fn env_proxy(url: &Url) -> Option { - env_proxy::for_url(url).to_url() +#[cfg(feature = "reqwest-rustls-tls")] +fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { + // If the client is already initialized, the passed timeout is ignored. + if let Some(client) = CLIENT_RUSTLS_TLS.get() { + return Ok(client); } - async fn request( - url: &Url, - resume_from: u64, - client: &Client, - ) -> Result { - let mut req = client.get(url.as_str()); + let provider = Arc::new(aws_lc_rs::default_provider()); + #[cfg(not(target_os = "android"))] + let result = + Verifier::new_with_extra_roots(RUSTUP_TRUST_ANCHORS.iter().cloned(), provider.clone()); + #[cfg(target_os = "android")] + let result = Verifier::new(provider.clone()); + let verifier = result.map_err(|err| { + DownloadError::Message(format!("failed to initialize platform verifier: {err}")) + })?; + + let mut tls_config = rustls::ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() // We're using a rustls verifier, so it's okay + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth(); + tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + let client = client_generic() + .read_timeout(timeout) + .use_preconfigured_tls(tls_config) + .user_agent(REQWEST_RUSTLS_TLS_USER_AGENT) + .build() + .map_err(DownloadError::Reqwest)?; + + let _ = CLIENT_RUSTLS_TLS.set(client); + // "The cell is guaranteed to contain a value when `set` returns, though not necessarily + // the one provided." + Ok(CLIENT_RUSTLS_TLS.get().unwrap()) +} - if resume_from != 0 { - req = req.header(header::RANGE, format!("bytes={resume_from}-")); - } +#[cfg(feature = "reqwest-rustls-tls")] +static CLIENT_RUSTLS_TLS: OnceLock = OnceLock::new(); - Ok(req.send().await?) +#[cfg(feature = "reqwest-native-tls")] +fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { + // If the client is already initialized, the passed timeout is ignored. + if let Some(client) = CLIENT_NATIVE_TLS.get() { + return Ok(client); } - fn download_from_file_url( - url: &Url, - resume_from: u64, - callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, - ) -> anyhow::Result { - use std::fs; + let client = client_generic() + .read_timeout(timeout) + .user_agent(REQWEST_DEFAULT_TLS_USER_AGENT) + .build() + .map_err(DownloadError::Reqwest)?; - // The file scheme is mostly for use by tests to mock the dist server - if url.scheme() == "file" { - let src = url - .to_file_path() - .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; - if !src.is_file() { - // Because some of rustup's logic depends on checking - // the error when a downloaded file doesn't exist, make - // the file case return the same error value as the - // network case. - return Err(anyhow!(DownloadError::FileNotFound)); - } - - let mut f = fs::File::open(src).context("unable to open downloaded file")?; - io::Seek::seek(&mut f, io::SeekFrom::Start(resume_from))?; - - let mut buffer = vec![0u8; 0x10000]; - loop { - let bytes_read = io::Read::read(&mut f, &mut buffer)?; - if bytes_read == 0 { - break; - } - callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; - } - - Ok(true) - } else { - Ok(false) - } - } + let _ = CLIENT_NATIVE_TLS.set(client); + // "The cell is guaranteed to contain a value when `set` returns, though not necessarily + // the one provided." + Ok(CLIENT_NATIVE_TLS.get().unwrap()) } +#[cfg(feature = "reqwest-native-tls")] +static CLIENT_NATIVE_TLS: OnceLock = OnceLock::new(); + #[derive(Debug, Error)] enum DownloadError { #[error("http request returned an unsuccessful status code: {0}")] @@ -557,7 +451,7 @@ enum DownloadError { #[error("{0}")] Message(String), #[error(transparent)] - IoError(#[from] std::io::Error), + IoError(#[from] io::Error), #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] #[error(transparent)] Reqwest(#[from] ::reqwest::Error), diff --git a/src/download/tests.rs b/src/download/tests.rs index 12d4760b9f..00675fe18a 100644 --- a/src/download/tests.rs +++ b/src/download/tests.rs @@ -20,8 +20,7 @@ mod reqwest { use std::env::set_var; use std::error::Error; use std::net::TcpListener; - use std::sync::Mutex; - use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; use std::time::Duration; @@ -30,13 +29,18 @@ mod reqwest { use url::Url; use super::{scrub_env, serve_file, tmp_dir, write_file}; - use crate::download::{Backend, Event}; + use crate::download::{DownloadOptions, Tls}; + + const OPTIONS: DownloadOptions = DownloadOptions { + tls: DOWNLOAD_BACKEND, + timeout: Duration::from_secs(180), + }; #[cfg(feature = "reqwest-rustls-tls")] - const DOWNLOAD_BACKEND: Backend = Backend::Rustls; + const DOWNLOAD_BACKEND: Tls = Tls::Rustls; #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - const DOWNLOAD_BACKEND: Backend = Backend::NativeTls; + const DOWNLOAD_BACKEND: Tls = Tls::NativeTls; // Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy #[tokio::test] @@ -110,69 +114,13 @@ mod reqwest { write_file(&target_path, "123"); let from_url = Url::from_file_path(&from_path).unwrap(); - DOWNLOAD_BACKEND - .download_to_path( - &from_url, - &target_path, - true, - None, - Duration::from_secs(180), - ) - .await - .expect("Test download failed"); - - assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); - } - - #[tokio::test] - async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { - let _guard = scrub_env().await; - let tmpdir = tmp_dir(); - let target_path = tmpdir.path().join("downloaded"); - write_file(&target_path, "123"); - - let addr = serve_file(b"xxx45".to_vec(), true); - - let from_url = format!("http://{addr}").parse().unwrap(); - - let callback_partial = AtomicBool::new(false); - let callback_len = Mutex::new(None); - let received_in_callback = Mutex::new(Vec::new()); - - DOWNLOAD_BACKEND - .download_to_path( - &from_url, - &target_path, - true, - Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); - } - } - } - - Ok(()) - }), - Duration::from_secs(180), - ) + OPTIONS + .start(&from_url, &target_path) + .with_resume() + .download_to_path() .await .expect("Test download failed"); - assert!(callback_partial.into_inner()); - assert_eq!(*callback_len.lock().unwrap(), Some(5)); - let observed_bytes = received_in_callback.into_inner().unwrap(); - assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); } @@ -186,14 +134,10 @@ mod reqwest { let addr = serve_file(b"xxx45".to_vec(), false); let from_url = format!("http://{addr}").parse().unwrap(); - DOWNLOAD_BACKEND - .download_to_path( - &from_url, - &target_path, - true, - None, - Duration::from_secs(180), - ) + OPTIONS + .start(&from_url, &target_path) + .with_resume() + .download_to_path() .await .expect_err("download should fail if server ignores range"); @@ -211,10 +155,15 @@ mod reqwest { write_file(&target_path, "123"); let from_url = "http://240.0.0.0:1080".parse().unwrap(); - DOWNLOAD_BACKEND - .download_to_path(&from_url, &target_path, true, None, Duration::from_secs(1)) - .await - .expect_err("download should fail with a connect error"); + DownloadOptions { + tls: DOWNLOAD_BACKEND, + timeout: Duration::from_secs(1), + } + .start(&from_url, &target_path) + .with_resume() + .download_to_path() + .await + .expect_err("download should fail with a connect error"); assert!(target_path.exists(), "partial file should not be deleted"); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "123");