diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 322246ca..f8d77720 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -2640,6 +2640,7 @@ dependencies = [ "semver", "serde", "serde_json", + "sha2", "shell-words", "tar", "tauri", @@ -2652,6 +2653,7 @@ dependencies = [ "tauri-plugin-log", "tauri-plugin-opener", "tauri-plugin-process", + "tempfile", "tokio", "tower-http 0.5.2", "urlencoding", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 61df4ad2..7218b96f 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -42,6 +42,7 @@ axum = { version = "0.7", features = ["ws"] } tower-http = { version = "0.5", features = ["cors", "fs"] } futures-util = "0.3" bytes = "1" +sha2 = "0.10" libc = "0.2.180" semver = "1.0" os_info = "3" @@ -51,6 +52,9 @@ shell-words = "1.1.1" maa-framework = { version = "1", features = ["dynamic"] } rust-embed = "8" +[dev-dependencies] +tempfile = "3" + [profile.release] # 保留调试符号以生成 PDB 文件,便于崩溃分析 debug = true diff --git a/src-tauri/src/commands/download.rs b/src-tauri/src/commands/download.rs index fd8c2983..7e5ad3c8 100644 --- a/src-tauri/src/commands/download.rs +++ b/src-tauri/src/commands/download.rs @@ -1,22 +1,20 @@ //! 下载相关命令 //! -//! 提供流式文件下载功能,支持进度回调和取消 +//! 提供流式文件下载功能,支持进度回调、取消和更新包断点续传。 use log::{error, info, warn}; +use reqwest::header::{ACCEPT, AUTHORIZATION, USER_AGENT}; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; - use tauri::Emitter; +use tokio::time::{sleep, Duration}; -use super::types::GitHubRelease; -use reqwest::header::{ACCEPT, AUTHORIZATION, USER_AGENT}; - -use super::types::{DownloadProgressEvent, DownloadResult}; -use super::update::move_to_old_folder; +use super::download_core::{self, DownloadRequest}; +use super::types::{DownloadProgressEvent, DownloadResult, GitHubRelease}; use super::utils::build_user_agent; -/// 进度上报任务的守卫,在函数任意返回路径上都能确保发送停止信号 +/// 进度上报任务的守卫,在函数任意返回路径上都能确保发送停止信号。 struct ProgressEmitterGuard(Option>); impl Drop for ProgressEmitterGuard { @@ -27,42 +25,12 @@ impl Drop for ProgressEmitterGuard { } } -/// 临时文件清理守卫,在函数异常退出时自动删除 .downloading 半成品。 -/// 成功重命名后需调用 `disarm()`,避免 drop 时冗余的 `remove_file`。 -struct TempFileGuard { - path: Option, -} - -impl TempFileGuard { - fn new(path: PathBuf) -> Self { - Self { path: Some(path) } - } - - /// 成功重命名后调用,使 drop 时不再尝试删除(文件已移至目标路径)。 - fn disarm(&mut self) { - self.path = None; - } -} - -impl Drop for TempFileGuard { - fn drop(&mut self) { - if let Some(p) = self.path.take() { - // 必须同步删除,避免竞态条件: - // 如果异步删除,可能在下一次下载创建同名临时文件后才执行,导致误删。 - let _ = std::fs::remove_file(&p); - } - } -} - -/// 全局下载取消标志 +/// 全局下载取消标志。 static DOWNLOAD_CANCELLED: AtomicBool = AtomicBool::new(false); -/// 当前下载的 session ID,用于区分不同的下载任务 +/// 当前下载的 session ID,用于区分不同的下载任务。 static CURRENT_DOWNLOAD_SESSION: AtomicU64 = AtomicU64::new(0); -/// 根据版本号获取 GitHub Release URL -/// -/// 使用 GitHub API 获取指定版本的 Release 信息,支持使用 GitHub PAT 和代理 -/// 解析 GitHub API 返回的 JSON 数据,找到与 target_version 匹配的 release,并返回 URL +/// 根据版本号获取 GitHub Release URL。 #[tauri::command] pub async fn get_github_release_by_version( owner: String, @@ -72,14 +40,11 @@ pub async fn get_github_release_by_version( proxy_url: Option, ) -> Result, String> { let url = format!("https://api.github.com/repos/{}/{}/releases", owner, repo); - - // 构造请求头 let mut client_builder = reqwest::Client::builder() .user_agent("mxu") .timeout(std::time::Duration::from_secs(10)) .connect_timeout(std::time::Duration::from_secs(3)); - // 添加代理配置(如果提供) if let Some(ref proxy) = proxy_url { if !proxy.is_empty() { info!("[检查更新] 使用代理: {}", proxy); @@ -92,19 +57,19 @@ pub async fn get_github_release_by_version( ) })?; client_builder = client_builder.proxy(reqwest_proxy); + } else { + info!("[下载] 直连(无代理): {}", url); } } let client = client_builder .build() .map_err(|e| format!("创建 HTTP 客户端失败: {}", e))?; - let mut request = client .get(&url) .header(ACCEPT, "application/vnd.github.v3+json") .header(USER_AGENT, "mxu"); - // 添加 PAT 认证(如果提供) if let Some(pat) = github_pat { if !pat.trim().is_empty() { request = request.header(AUTHORIZATION, format!("token {}", pat.trim())); @@ -115,7 +80,6 @@ pub async fn get_github_release_by_version( .send() .await .map_err(|e| format!("请求失败: {}", e))?; - if !response.status().is_success() { return Err(format!("GitHub API 错误: {}", response.status())); } @@ -124,7 +88,6 @@ pub async fn get_github_release_by_version( .json() .await .map_err(|e| format!("解析 JSON 失败: {}", e))?; - let normalize = |v: &str| { v.trim_start_matches(|c| c == 'v' || c == 'V') .to_lowercase() @@ -144,14 +107,7 @@ pub async fn get_github_release_by_version( Ok(None) } -/// 流式下载文件,支持进度回调和取消 -/// -/// 使用 reqwest 进行流式下载,直接写入文件而不经过内存缓冲, -/// 解决 JavaScript 下载大文件时的性能问题 -/// -/// 返回 DownloadResult,包含 session_id 和实际保存路径 -/// 如果检测到重定向后的 URL 或 Content-Disposition 包含正确的文件名, -/// 会使用该文件名保存(替换原始 save_path 的文件名部分) +/// 流式下载文件。更新包传入 `resume_key` 后会保留可验证的半成品供下次续传。 #[tauri::command] pub async fn download_file( app: tauri::AppHandle, @@ -159,34 +115,20 @@ pub async fn download_file( save_path: String, total_size: Option, proxy_url: Option, + resume_key: Option, + sha256: Option, ) -> Result { - use futures_util::StreamExt; - use std::io::Write; - use tokio::time::{sleep, Duration}; - info!("download_file: {} -> {}", url, save_path); - // 生成新的 session ID,使旧下载的进度事件无效 let session_id = CURRENT_DOWNLOAD_SESSION.fetch_add(1, Ordering::SeqCst) + 1; - info!("download_file session_id: {}", session_id); - - // 重置取消标志 DOWNLOAD_CANCELLED.store(false, Ordering::SeqCst); + info!("download_file session_id: {}", session_id); - let save_path_obj = std::path::Path::new(&save_path); - - // 确保目录存在 - if let Some(parent) = save_path_obj.parent() { - std::fs::create_dir_all(parent).map_err(|e| format!("无法创建目录: {}", e))?; - } - - // 构建 HTTP 客户端和请求 let mut client_builder = reqwest::Client::builder() .user_agent(build_user_agent()) - .timeout(std::time::Duration::from_secs(600)) // 10 分钟超时,足够下载大文件但防止无限挂起 + .timeout(std::time::Duration::from_secs(600)) .connect_timeout(std::time::Duration::from_secs(10)); - // 配置代理(如果提供) if let Some(ref proxy) = proxy_url { if !proxy.is_empty() { info!("[下载] 使用代理: {}", proxy); @@ -199,8 +141,6 @@ pub async fn download_file( ) })?; client_builder = client_builder.proxy(reqwest_proxy); - } else { - info!("[下载] 直连(无代理): {}", url); } } else { info!("[下载] 直连(无代理): {}", url); @@ -209,373 +149,123 @@ pub async fn download_file( let client = client_builder .build() .map_err(|e| format!("创建 HTTP 客户端失败: {}", e))?; + let downloaded_shared = Arc::new(AtomicU64::new(0)); + let session_received_shared = Arc::new(AtomicU64::new(0)); + let total_shared = Arc::new(AtomicU64::new(total_size.unwrap_or(0))); + let progress_guard = start_progress_emitter( + app.clone(), + session_id, + downloaded_shared.clone(), + session_received_shared.clone(), + total_shared.clone(), + ); - let response = client - .get(&url) - .send() - .await - .map_err(|e| format!("请求失败: {}", e))?; - - if !response.status().is_success() { - return Err(format!("HTTP 错误: {}", response.status())); - } - - // 尝试从 Content-Disposition header 或最终 URL 提取文件名 - let detected_filename = extract_filename_from_response(&response); - if let Some(ref name) = detected_filename { - info!("[下载] 检测到文件名: {}", name); - } - - // 确定实际保存路径 - let actual_save_path = if let Some(ref filename) = detected_filename { - // 使用检测到的文件名,保持原目录 - if let Some(parent) = save_path_obj.parent() { - parent.join(filename).to_string_lossy().to_string() - } else { - filename.clone() - } - } else { - save_path.clone() - }; - - let actual_save_path_obj = std::path::Path::new(&actual_save_path); - - // 使用包含 session_id 的临时文件名,避免取消后立即重试时新旧任务竞争同一临时文件 - let temp_path = format!("{}.{}.downloading", actual_save_path, session_id); - let mut temp_guard = TempFileGuard::new(PathBuf::from(&temp_path)); - - // 获取文件大小 - let content_length = response.content_length(); - let total = total_size.or(content_length).unwrap_or(0); - - // 有界通道将网络读取与磁盘写入解耦: - // - 下载循环纯异步,不阻塞 runtime,可全速消费 TCP 流 - // - 写入线程用同步 BufWriter,单线程从头跑到尾,避免 tokio::fs 逐次 spawn_blocking 的调度开销 - let (write_tx, write_rx) = tokio::sync::mpsc::channel::(64); + let core_result = download_core::download( + &client, + DownloadRequest { + url, + save_path: PathBuf::from(save_path), + expected_size: total_size.filter(|size| *size > 0), + resume_key, + sha256, + session_id, + }, + downloaded_shared, + session_received_shared, + total_shared, + || { + DOWNLOAD_CANCELLED.load(Ordering::SeqCst) + || CURRENT_DOWNLOAD_SESSION.load(Ordering::SeqCst) != session_id + }, + ) + .await; - let temp_path_for_writer = temp_path.clone(); - let write_handle = tokio::task::spawn_blocking(move || -> Result<(), String> { - let file = std::fs::File::create(&temp_path_for_writer) - .map_err(|e| format!("无法创建文件: {}", e))?; - let mut writer = std::io::BufWriter::with_capacity(512 * 1024, file); - let mut write_rx = write_rx; - while let Some(chunk) = write_rx.blocking_recv() { - writer - .write_all(&chunk) - .map_err(|e| format!("写入文件失败: {}", e))?; - } - writer - .flush() - .map_err(|e| format!("刷新写入缓冲区失败: {}", e))?; - writer - .get_ref() - .sync_all() - .map_err(|e| format!("同步文件失败: {}", e))?; - Ok(()) - }); + drop(progress_guard); + let result = core_result?; + let _ = app.emit( + "download-progress", + DownloadProgressEvent { + session_id, + downloaded_size: result.downloaded_size, + total_size: result.total_size, + speed: 0, + progress: 100.0, + }, + ); - // 共享下载字节计数,用于独立的进度上报任务 - let downloaded_shared = Arc::new(AtomicU64::new(0)); - let downloaded_for_emitter = downloaded_shared.clone(); + info!( + "download_file completed: {} bytes -> {} (session {})", + result.downloaded_size, result.actual_save_path, session_id + ); + Ok(DownloadResult { + session_id, + actual_save_path: result.actual_save_path, + detected_filename: result.detected_filename, + }) +} - // 启动独立任务定期上报进度,避免在下载循环中因 emit 阻塞导致”卡卡停停” - let app_for_emitter = app.clone(); +fn start_progress_emitter( + app: tauri::AppHandle, + session_id: u64, + downloaded: Arc, + session_received: Arc, + total: Arc, +) -> ProgressEmitterGuard { let (stop_tx, mut stop_rx) = tokio::sync::oneshot::channel::<()>(); - let progress_guard = ProgressEmitterGuard(Some(stop_tx)); tokio::spawn(async move { - let mut last_downloaded = 0u64; + let mut last_session_received = session_received.load(Ordering::Relaxed); let mut last_instant = tokio::time::Instant::now(); - let mut smoothed_speed: f64 = 0.0; + let mut smoothed_speed = 0.0; const EMA_ALPHA: f64 = 0.3; + loop { tokio::select! { - _ = &mut stop_rx => { - break; - } + _ = &mut stop_rx => break, _ = sleep(Duration::from_millis(100)) => { - let downloaded = downloaded_for_emitter.load(Ordering::Relaxed); + let current = downloaded.load(Ordering::Relaxed); + let current_session_received = session_received.load(Ordering::Relaxed); + let total = total.load(Ordering::Relaxed); let now = tokio::time::Instant::now(); - let elapsed = now.duration_since(last_instant); - if elapsed.as_millis() == 0 { + let elapsed = now.duration_since(last_instant).as_secs_f64(); + if elapsed <= 0.0 { continue; } - - let bytes_in_interval = downloaded.saturating_sub(last_downloaded); - let instant_speed = if elapsed.as_secs_f64() > 0.0 { - bytes_in_interval as f64 / elapsed.as_secs_f64() - } else { - 0.0 - }; - + let instant_speed = current_session_received + .saturating_sub(last_session_received) as f64 + / elapsed; smoothed_speed = if smoothed_speed == 0.0 { instant_speed } else { EMA_ALPHA * instant_speed + (1.0 - EMA_ALPHA) * smoothed_speed }; - let progress = if total > 0 { - ((downloaded as f64 / total as f64) * 100.0).min(100.0) + ((current as f64 / total as f64) * 100.0).min(100.0) } else { 0.0 }; - - let _ = app_for_emitter.emit( + let _ = app.emit( "download-progress", DownloadProgressEvent { session_id, - downloaded_size: downloaded, + downloaded_size: current, total_size: total, speed: smoothed_speed as u64, progress, }, ); - - last_downloaded = downloaded; + last_session_received = current_session_received; last_instant = now; } } } }); - - // 说明:downloaded_shared 仅用于进度上报的近实时采样,对 UI 来说允许“最终一致”, - // 因此这里使用 Relaxed 内存序即可,避免在热路径上引入不必要的全序栅栏。 - - // 流式下载 - let mut stream = response.bytes_stream(); - let mut downloaded: u64 = 0; - let mut download_err: Option = None; - - while let Some(chunk) = stream.next().await { - if DOWNLOAD_CANCELLED.load(Ordering::SeqCst) - || CURRENT_DOWNLOAD_SESSION.load(Ordering::SeqCst) != session_id - { - info!("download_file cancelled (session {})", session_id); - download_err = Some("下载已取消".to_string()); - break; - } - - let chunk = match chunk { - Ok(c) => c, - Err(e) => { - download_err = Some(format!("下载数据失败: {}", e)); - break; - } - }; - - let len = chunk.len() as u64; - if write_tx.send(chunk).await.is_err() { - download_err = Some("磁盘写入线程异常退出".to_string()); - break; - } - downloaded += len; - downloaded_shared.store(downloaded, Ordering::Relaxed); - } - - // 最后再检查一次取消标志 - if download_err.is_none() - && (DOWNLOAD_CANCELLED.load(Ordering::SeqCst) - || CURRENT_DOWNLOAD_SESSION.load(Ordering::SeqCst) != session_id) - { - info!( - "download_file cancelled before finalization (session {})", - session_id - ); - download_err = Some("下载已取消".to_string()); - } - - // 关闭发送端,通知写入线程所有数据已发送完毕 - drop(write_tx); - - // 等待写入线程完成,确保文件句柄关闭后再进行重命名等后续操作 - let write_thread_result = write_handle - .await - .map_err(|e| format!("写入任务异常: {}", e))?; - - if let Some(err) = download_err { - // 写入线程通常持有更具体的 I/O 错误信息(如磁盘满),优先返回 - if let Err(write_err) = write_thread_result { - return Err(write_err); - } - return Err(err); - } - write_thread_result?; - - // 发送最终进度 - let _ = app.emit( - "download-progress", - DownloadProgressEvent { - session_id, - downloaded_size: downloaded, - total_size: if total > 0 { total } else { downloaded }, - speed: 0, - progress: 100.0, - }, - ); - - // 将可能存在的旧文件移动到 old 文件夹 - if actual_save_path_obj.exists() { - let _ = move_to_old_folder(actual_save_path_obj); - } - - // 重命名临时文件(使用异步版本避免阻塞 runtime 线程) - tokio::fs::rename(&temp_path, &actual_save_path) - .await - .map_err(|e| format!("重命名文件失败: {}", e))?; - temp_guard.disarm(); - - info!( - "download_file completed: {} bytes -> {} (session {})", - downloaded, actual_save_path, session_id - ); - - // 显式 drop progress_guard 以停止进度上报任务 - drop(progress_guard); - - Ok(DownloadResult { - session_id, - actual_save_path, - detected_filename, - }) + ProgressEmitterGuard(Some(stop_tx)) } -/// 取消下载 +/// 设置取消标志。半成品会在下载写入线程退出后由下载任务清理。 #[tauri::command] pub fn cancel_download(save_path: String) -> Result<(), String> { info!("cancel_download called for: {}", save_path); - - // 设置取消标志,让下载循环退出 DOWNLOAD_CANCELLED.store(true, Ordering::SeqCst); - - // 同时尝试删除临时文件(如果已经创建) - let temp_path = format!("{}.downloading", save_path); - let path = std::path::Path::new(&temp_path); - - if path.exists() { - if let Err(e) = std::fs::remove_file(path) { - // 文件可能正在被写入,记录警告但不报错 - warn!("cancel_download: failed to remove {}: {}", temp_path, e); - } else { - info!("cancel_download: removed {}", temp_path); - } - } - Ok(()) } - -/// 从 HTTP 响应中提取文件名 -/// -/// 优先级: -/// 1. Content-Disposition header 中的 filename -/// 2. 最终 URL(重定向后)的路径部分 -fn extract_filename_from_response(response: &reqwest::Response) -> Option { - // 1. 尝试从 Content-Disposition header 提取 - if let Some(cd) = response.headers().get("content-disposition") { - if let Ok(cd_str) = cd.to_str() { - if let Some(filename) = parse_content_disposition(cd_str) { - if let Some(safe) = sanitize_filename(&filename) { - return Some(safe); - } - } - } - } - - // 2. 尝试从最终 URL 提取(重定向后的 URL) - let final_url = response.url(); - let path = final_url.path(); - - // 获取路径的最后一部分 - if let Some(last_segment) = path.rsplit('/').next() { - if !last_segment.is_empty() { - // URL 解码 - if let Ok(decoded) = urlencoding::decode(last_segment) { - let filename = decoded.to_string(); - // 确保有扩展名,并清理文件名 - if filename.contains('.') { - if let Some(safe) = sanitize_filename(&filename) { - return Some(safe); - } - } - } - } - } - - None -} - -/// 清理文件名,防止目录遍历攻击 -/// -/// - 移除路径分隔符(/ 和 \) -/// - 移除 .. 片段 -/// - 只保留文件名部分 -fn sanitize_filename(filename: &str) -> Option { - // 获取最后一个路径分隔符后的部分(处理 path/to/file.exe 或 path\to\file.exe) - let name = filename - .rsplit(|c| c == '/' || c == '\\') - .next() - .unwrap_or(filename); - - // 过滤掉 .. 和空文件名 - if name.is_empty() || name == "." || name == ".." || name.starts_with("..") { - return None; - } - - // 确保有扩展名 - if !name.contains('.') { - return None; - } - - Some(name.to_string()) -} - -/// 解析 Content-Disposition header 提取文件名(大小写不敏感) -/// -/// 支持格式: -/// - attachment; filename="example.exe" -/// - attachment; filename=example.exe -/// - attachment; filename*=UTF-8''%E4%B8%AD%E6%96%87.exe -/// - Attachment; Filename="example.exe" (大小写变体) -fn parse_content_disposition(header: &str) -> Option { - let header_lower = header.to_lowercase(); - - // 首先尝试 filename*=(RFC 5987 编码,优先级更高) - if let Some(start) = header_lower.find("filename*=") { - let rest = &header[start + 10..]; - // 格式: UTF-8''encoded_filename 或 utf-8''encoded_filename - if let Some(quote_pos) = rest.find("''") { - let encoded = rest[quote_pos + 2..].split(';').next().unwrap_or("").trim(); - if let Ok(decoded) = urlencoding::decode(encoded) { - let filename = decoded.trim_matches('"').to_string(); - if !filename.is_empty() { - return Some(filename); - } - } - } - } - - // 然后尝试普通的 filename=(但要确保不是 filename*=) - // 查找 "filename=" 但排除 "filename*=" - let mut search_start = 0; - while let Some(pos) = header_lower[search_start..].find("filename=") { - let absolute_pos = search_start + pos; - // 检查是否是 filename*=(前一个字符是 *) - if absolute_pos > 0 && header.as_bytes().get(absolute_pos - 1) == Some(&b'*') { - search_start = absolute_pos + 9; - continue; - } - - let rest = &header[absolute_pos + 9..]; - let filename = rest - .split(';') - .next() - .unwrap_or("") - .trim() - .trim_matches('"') - .to_string(); - if !filename.is_empty() { - return Some(filename); - } - break; - } - - None -} diff --git a/src-tauri/src/commands/download_core.rs b/src-tauri/src/commands/download_core.rs new file mode 100644 index 00000000..8f466a79 --- /dev/null +++ b/src-tauri/src/commands/download_core.rs @@ -0,0 +1,1099 @@ +use futures_util::StreamExt; +use reqwest::header::{CONTENT_RANGE, ETAG, IF_RANGE, LAST_MODIFIED, RANGE}; +use reqwest::{Client, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::fs::{File, OpenOptions}; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +use super::update::move_to_old_folder; + +#[derive(Debug, Clone)] +pub struct DownloadRequest { + pub url: String, + pub save_path: PathBuf, + pub expected_size: Option, + pub resume_key: Option, + pub sha256: Option, + pub session_id: u64, +} + +#[derive(Debug)] +pub struct CoreDownloadResult { + pub actual_save_path: String, + pub detected_filename: Option, + pub downloaded_size: u64, + pub total_size: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ResumeMetadata { + resume_key: String, + expected_size: Option, + sha256: Option, + etag: Option, + last_modified: Option, + detected_filename: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ParsedContentRange { + Range { start: u64, end: u64, total: u64 }, + Unsatisfied { total: u64 }, +} + +pub async fn download( + client: &Client, + request: DownloadRequest, + downloaded_shared: Arc, + session_received_shared: Arc, + total_shared: Arc, + is_cancelled: F, +) -> Result +where + F: Fn() -> bool, +{ + if let Some(parent) = request.save_path.parent() { + std::fs::create_dir_all(parent).map_err(|e| format!("无法创建目录: {}", e))?; + } + + let resumable = request.resume_key.is_some(); + let part_path = if resumable { + PathBuf::from(format!("{}.downloading", request.save_path.display())) + } else { + PathBuf::from(format!( + "{}.{}.downloading", + request.save_path.display(), + request.session_id + )) + }; + let metadata_path = PathBuf::from(format!("{}.json", part_path.display())); + + if resumable { + cleanup_legacy_session_files(request.save_path.parent().unwrap_or_else(|| Path::new("."))); + } + + let normalized_sha256 = request.sha256.as_deref().map(normalize_sha256); + let mut metadata = load_compatible_metadata( + &part_path, + &metadata_path, + request.resume_key.as_deref(), + request.expected_size, + normalized_sha256.as_deref(), + ); + let mut resume_offset = metadata + .as_ref() + .and_then(|_| std::fs::metadata(&part_path).ok()) + .map(|m| m.len()) + .unwrap_or(0); + + if let Some(expected) = request.expected_size { + if resume_offset > expected { + remove_download_artifacts(&part_path, &metadata_path); + metadata = None; + resume_offset = 0; + } + } + + let mut response = send_request(client, &request.url, resume_offset, metadata.as_ref()).await?; + + if response.status() == StatusCode::RANGE_NOT_SATISFIABLE && resume_offset > 0 { + let remote_total = response + .headers() + .get(CONTENT_RANGE) + .and_then(|v| v.to_str().ok()) + .and_then(parse_content_range) + .and_then(|v| match v { + ParsedContentRange::Unsatisfied { total } => Some(total), + ParsedContentRange::Range { .. } => None, + }); + + if remote_total == Some(resume_offset) + && request + .expected_size + .map(|expected| expected == resume_offset) + .unwrap_or(true) + { + let detected_filename = extract_filename_from_response(&response) + .or_else(|| metadata.as_ref().and_then(|m| m.detected_filename.clone())); + return finalize_download( + &request, + &part_path, + &metadata_path, + detected_filename, + resume_offset, + normalized_sha256.as_deref(), + ) + .await; + } + + remove_download_artifacts(&part_path, &metadata_path); + metadata = None; + resume_offset = 0; + response = send_request(client, &request.url, 0, None).await?; + } + + if !response.status().is_success() { + return Err(format!("HTTP 错误: {}", response.status())); + } + + let mut append = false; + let response_total = if response.status() == StatusCode::PARTIAL_CONTENT { + let validators_match = metadata + .as_ref() + .map(|metadata| response_validators_match(metadata, &response)) + .unwrap_or(true); + let content_range = response + .headers() + .get(CONTENT_RANGE) + .and_then(|v| v.to_str().ok()) + .and_then(parse_content_range); + + match content_range { + Some(ParsedContentRange::Range { start, end, total }) + if resume_offset > 0 + && start == resume_offset + && end >= start + && end < total + && validators_match => + { + append = true; + Some(total) + } + _ if resume_offset > 0 => { + remove_download_artifacts(&part_path, &metadata_path); + metadata = None; + resume_offset = 0; + response = send_request(client, &request.url, 0, None).await?; + if !response.status().is_success() + || response.status() == StatusCode::PARTIAL_CONTENT + { + return Err(format!( + "服务器返回了无效的 Content-Range: {}", + response.status() + )); + } + response.content_length() + } + _ => { + remove_download_artifacts(&part_path, &metadata_path); + return Err("服务器在未请求断点续传时返回了 206".to_string()); + } + } + } else { + if resume_offset > 0 { + remove_download_artifacts(&part_path, &metadata_path); + metadata = None; + resume_offset = 0; + } + response.content_length() + }; + + if let (Some(expected), Some(remote)) = (request.expected_size, response_total) { + if expected != remote { + remove_download_artifacts(&part_path, &metadata_path); + return Err(format!( + "下载文件大小不匹配: 预期 {} 字节,服务器返回 {} 字节", + expected, remote + )); + } + } + + let total = response_total.or(request.expected_size).unwrap_or(0); + downloaded_shared.store(resume_offset, Ordering::Relaxed); + total_shared.store(total, Ordering::Relaxed); + + let response_filename = extract_filename_from_response(&response); + let detected_filename = + response_filename.or_else(|| metadata.as_ref().and_then(|m| m.detected_filename.clone())); + + if resumable && metadata.is_none() { + let new_metadata = ResumeMetadata { + resume_key: request.resume_key.clone().unwrap_or_default(), + expected_size: request.expected_size, + sha256: normalized_sha256.clone(), + etag: header_string(&response, ETAG), + last_modified: header_string(&response, LAST_MODIFIED), + detected_filename: detected_filename.clone(), + }; + write_metadata_atomic(&metadata_path, &new_metadata, request.session_id)?; + } + + let (write_tx, write_rx) = tokio::sync::mpsc::channel::(64); + let writer_path = part_path.clone(); + let write_handle = tokio::task::spawn_blocking(move || -> Result<(), String> { + let file = OpenOptions::new() + .create(true) + .write(true) + .append(append) + .truncate(!append) + .open(&writer_path) + .map_err(|e| format!("无法创建文件: {}", e))?; + let mut writer = BufWriter::with_capacity(512 * 1024, file); + let mut write_rx = write_rx; + while let Some(chunk) = write_rx.blocking_recv() { + writer + .write_all(&chunk) + .map_err(|e| format!("写入文件失败: {}", e))?; + } + writer + .flush() + .map_err(|e| format!("刷新写入缓冲区失败: {}", e))?; + writer + .get_ref() + .sync_all() + .map_err(|e| format!("同步文件失败: {}", e))?; + Ok(()) + }); + + let mut stream = response.bytes_stream(); + let mut downloaded = resume_offset; + let mut transfer_error = None; + let mut cancelled = false; + + while let Some(chunk) = stream.next().await { + if is_cancelled() { + cancelled = true; + transfer_error = Some("下载已取消".to_string()); + break; + } + + let chunk = match chunk { + Ok(chunk) => chunk, + Err(e) => { + transfer_error = Some(format!("下载数据失败: {}", e)); + break; + } + }; + + let len = chunk.len() as u64; + if write_tx.send(chunk).await.is_err() { + transfer_error = Some("磁盘写入线程异常退出".to_string()); + break; + } + downloaded += len; + downloaded_shared.store(downloaded, Ordering::Relaxed); + session_received_shared.fetch_add(len, Ordering::Relaxed); + } + + if transfer_error.is_none() && is_cancelled() { + cancelled = true; + transfer_error = Some("下载已取消".to_string()); + } + + drop(write_tx); + let write_result = write_handle + .await + .map_err(|e| format!("写入任务异常: {}", e))?; + + if let Err(write_error) = write_result { + if !resumable { + remove_download_artifacts(&part_path, &metadata_path); + } + return Err(write_error); + } + + if let Some(error) = transfer_error { + if cancelled || !resumable { + remove_download_artifacts(&part_path, &metadata_path); + } + return Err(error); + } + + let actual_size = std::fs::metadata(&part_path) + .map_err(|e| format!("无法读取下载文件大小: {}", e))? + .len(); + if total > 0 && actual_size != total { + if actual_size > total || !resumable { + remove_download_artifacts(&part_path, &metadata_path); + } + return Err(format!( + "下载文件不完整: 已下载 {} 字节,总大小 {} 字节", + actual_size, total + )); + } + + finalize_download( + &request, + &part_path, + &metadata_path, + detected_filename, + actual_size, + normalized_sha256.as_deref(), + ) + .await +} + +async fn send_request( + client: &Client, + url: &str, + offset: u64, + metadata: Option<&ResumeMetadata>, +) -> Result { + let mut request = client.get(url); + if offset > 0 { + request = request.header(RANGE, format!("bytes={}-", offset)); + if let Some(if_range) = metadata.and_then(if_range_value) { + request = request.header(IF_RANGE, if_range); + } + } + request.send().await.map_err(|e| format!("请求失败: {}", e)) +} + +async fn finalize_download( + request: &DownloadRequest, + part_path: &Path, + metadata_path: &Path, + detected_filename: Option, + downloaded_size: u64, + expected_sha256: Option<&str>, +) -> Result { + if let Some(expected) = expected_sha256 { + if let Err(error) = verify_sha256(part_path, expected).await { + remove_download_artifacts(part_path, metadata_path); + return Err(error); + } + } + + let actual_save_path = + resolve_actual_save_path(&request.save_path, detected_filename.as_deref()); + if actual_save_path.exists() { + move_to_old_folder(&actual_save_path)?; + } + std::fs::rename(part_path, &actual_save_path).map_err(|e| format!("重命名文件失败: {}", e))?; + let _ = std::fs::remove_file(metadata_path); + + Ok(CoreDownloadResult { + actual_save_path: actual_save_path.to_string_lossy().to_string(), + detected_filename, + downloaded_size, + total_size: downloaded_size, + }) +} + +fn load_compatible_metadata( + part_path: &Path, + metadata_path: &Path, + resume_key: Option<&str>, + expected_size: Option, + sha256: Option<&str>, +) -> Option { + let Some(resume_key) = resume_key else { + remove_download_artifacts(part_path, metadata_path); + return None; + }; + if !part_path.exists() || !metadata_path.exists() { + remove_download_artifacts(part_path, metadata_path); + return None; + } + + let metadata = File::open(metadata_path) + .ok() + .and_then(|file| serde_json::from_reader::<_, ResumeMetadata>(BufReader::new(file)).ok()); + let compatible = metadata.filter(|metadata| { + metadata.resume_key == resume_key + && metadata.expected_size == expected_size + && metadata.sha256.as_deref() == sha256 + }); + + if compatible.is_none() { + remove_download_artifacts(part_path, metadata_path); + } + compatible +} + +fn write_metadata_atomic( + metadata_path: &Path, + metadata: &ResumeMetadata, + session_id: u64, +) -> Result<(), String> { + let temp_path = PathBuf::from(format!("{}.{}.tmp", metadata_path.display(), session_id)); + let mut file = File::create(&temp_path).map_err(|e| format!("无法创建下载元数据: {}", e))?; + serde_json::to_writer(&mut file, metadata).map_err(|e| format!("无法写入下载元数据: {}", e))?; + file.sync_all() + .map_err(|e| format!("无法同步下载元数据: {}", e))?; + std::fs::rename(&temp_path, metadata_path).map_err(|e| { + let _ = std::fs::remove_file(&temp_path); + format!("无法保存下载元数据: {}", e) + }) +} + +fn remove_download_artifacts(part_path: &Path, metadata_path: &Path) { + let _ = std::fs::remove_file(part_path); + let _ = std::fs::remove_file(metadata_path); +} + +fn cleanup_legacy_session_files(directory: &Path) { + let Ok(entries) = std::fs::read_dir(directory) else { + return; + }; + for entry in entries.flatten() { + let path = entry.path(); + let Some(name) = path.file_name().and_then(|name| name.to_str()) else { + continue; + }; + let Some(stem) = name.strip_suffix(".downloading") else { + continue; + }; + let Some(session) = stem.rsplit('.').next() else { + continue; + }; + if !session.is_empty() && session.chars().all(|c| c.is_ascii_digit()) { + let _ = std::fs::remove_file(path); + } + } +} + +fn if_range_value(metadata: &ResumeMetadata) -> Option { + metadata + .etag + .as_ref() + .filter(|etag| !etag.trim_start().starts_with("W/")) + .cloned() + .or_else(|| metadata.last_modified.clone()) +} + +fn response_validators_match(metadata: &ResumeMetadata, response: &Response) -> bool { + if let Some(expected) = metadata + .etag + .as_ref() + .filter(|etag| !etag.trim_start().starts_with("W/")) + { + return header_string(response, ETAG) + .map(|actual| actual == *expected) + .unwrap_or(true); + } + if let Some(expected) = metadata.last_modified.as_ref() { + return header_string(response, LAST_MODIFIED) + .map(|actual| actual == *expected) + .unwrap_or(true); + } + true +} + +fn header_string(response: &Response, name: reqwest::header::HeaderName) -> Option { + response + .headers() + .get(name) + .and_then(|value| value.to_str().ok()) + .map(str::to_owned) +} + +fn parse_content_range(value: &str) -> Option { + let value = value.trim().strip_prefix("bytes ")?; + let (range, total) = value.split_once('/')?; + let total = total.parse().ok()?; + if range == "*" { + return Some(ParsedContentRange::Unsatisfied { total }); + } + let (start, end) = range.split_once('-')?; + Some(ParsedContentRange::Range { + start: start.parse().ok()?, + end: end.parse().ok()?, + total, + }) +} + +fn normalize_sha256(value: &str) -> String { + value + .trim() + .strip_prefix("sha256:") + .unwrap_or(value.trim()) + .to_ascii_lowercase() +} + +async fn verify_sha256(path: &Path, expected: &str) -> Result<(), String> { + let path = path.to_path_buf(); + let expected = expected.to_string(); + tokio::task::spawn_blocking(move || verify_sha256_sync(&path, &expected)) + .await + .map_err(|e| format!("SHA-256 校验任务失败: {}", e))? +} + +fn verify_sha256_sync(path: &Path, expected: &str) -> Result<(), String> { + if expected.len() != 64 || !expected.chars().all(|c| c.is_ascii_hexdigit()) { + return Err("更新包 SHA-256 格式无效".to_string()); + } + let file = File::open(path).map_err(|e| format!("无法读取下载文件: {}", e))?; + let mut reader = BufReader::new(file); + let mut hasher = Sha256::new(); + let mut buffer = [0u8; 64 * 1024]; + loop { + let read = reader + .read(&mut buffer) + .map_err(|e| format!("无法校验下载文件: {}", e))?; + if read == 0 { + break; + } + hasher.update(&buffer[..read]); + } + let actual = format!("{:x}", hasher.finalize()); + if actual != expected { + return Err(format!( + "更新包 SHA-256 校验失败: 预期 {},实际 {}", + expected, actual + )); + } + Ok(()) +} + +fn resolve_actual_save_path(save_path: &Path, detected_filename: Option<&str>) -> PathBuf { + detected_filename + .and_then(|filename| save_path.parent().map(|parent| parent.join(filename))) + .unwrap_or_else(|| save_path.to_path_buf()) +} + +fn extract_filename_from_response(response: &Response) -> Option { + if let Some(cd) = response.headers().get("content-disposition") { + if let Ok(cd_str) = cd.to_str() { + if let Some(filename) = parse_content_disposition(cd_str) { + if let Some(safe) = sanitize_filename(&filename) { + return Some(safe); + } + } + } + } + + let path = response.url().path(); + if let Some(last_segment) = path.rsplit('/').next() { + if !last_segment.is_empty() { + if let Ok(decoded) = urlencoding::decode(last_segment) { + let filename = decoded.to_string(); + if filename.contains('.') { + return sanitize_filename(&filename); + } + } + } + } + None +} + +fn sanitize_filename(filename: &str) -> Option { + let name = filename.rsplit(['/', '\\']).next().unwrap_or(filename); + if name.is_empty() || name == "." || name == ".." || name.starts_with("..") { + return None; + } + name.contains('.').then(|| name.to_string()) +} + +fn parse_content_disposition(header: &str) -> Option { + let header_lower = header.to_lowercase(); + if let Some(start) = header_lower.find("filename*=") { + let rest = &header[start + 10..]; + if let Some(quote_pos) = rest.find("''") { + let encoded = rest[quote_pos + 2..].split(';').next().unwrap_or("").trim(); + if let Ok(decoded) = urlencoding::decode(encoded) { + let filename = decoded.trim_matches('"').to_string(); + if !filename.is_empty() { + return Some(filename); + } + } + } + } + + let mut search_start = 0; + while let Some(pos) = header_lower[search_start..].find("filename=") { + let absolute_pos = search_start + pos; + if absolute_pos > 0 && header.as_bytes().get(absolute_pos - 1) == Some(&b'*') { + search_start = absolute_pos + 9; + continue; + } + let filename = header[absolute_pos + 9..] + .split(';') + .next() + .unwrap_or("") + .trim() + .trim_matches('"') + .to_string(); + if !filename.is_empty() { + return Some(filename); + } + break; + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::Body; + use axum::extract::State; + use axum::http::{HeaderMap, Response as HttpResponse}; + use axum::routing::get; + use axum::Router; + use bytes::Bytes; + use futures_util::stream; + use std::convert::Infallible; + use std::io; + use std::sync::atomic::AtomicBool; + use std::sync::Mutex; + use std::time::Duration; + use tempfile::TempDir; + + #[derive(Debug, Clone, Copy)] + enum ServerMode { + HonorRange, + IgnoreRange, + InvalidRange, + ChangedValidator, + Interrupt, + Slow, + } + + #[derive(Clone)] + struct TestServerState { + data: Arc>, + mode: Arc>, + requests: Arc, Option)>>>, + } + + async fn serve_file( + State(state): State, + headers: HeaderMap, + ) -> HttpResponse { + let range = headers + .get(RANGE) + .and_then(|value| value.to_str().ok()) + .map(str::to_owned); + let if_range = headers + .get(IF_RANGE) + .and_then(|value| value.to_str().ok()) + .map(str::to_owned); + state + .requests + .lock() + .unwrap() + .push((range.clone(), if_range.clone())); + + let mode = *state.mode.lock().unwrap(); + let data = state.data.clone(); + if matches!(mode, ServerMode::Interrupt) { + let split = data.len() / 2; + let first_chunk = Bytes::copy_from_slice(&data[..split]); + let body = Body::from_stream(stream::unfold(0, move |step| { + let first_chunk = first_chunk.clone(); + async move { + match step { + 0 => Some((Ok::(first_chunk), 1)), + 1 => { + tokio::time::sleep(Duration::from_millis(25)).await; + Some(( + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "test interruption", + )), + 2, + )) + } + _ => None, + } + } + })); + return response( + StatusCode::OK, + body, + &[ + ("content-length", data.len().to_string()), + ("etag", "\"etag-v1\"".into()), + ], + ); + } + + if matches!(mode, ServerMode::Slow) { + let chunks = data + .chunks(2) + .map(Bytes::copy_from_slice) + .collect::>(); + let body = Body::from_stream(stream::unfold( + chunks.into_iter(), + |mut chunks| async move { + let chunk = chunks.next()?; + tokio::time::sleep(Duration::from_millis(25)).await; + Some((Ok::(chunk), chunks)) + }, + )); + return response( + StatusCode::OK, + body, + &[ + ("content-length", data.len().to_string()), + ("etag", "\"etag-v1\"".into()), + ], + ); + } + + let requested_offset = range.as_deref().and_then(parse_range_offset); + if let Some(offset) = requested_offset { + if matches!(mode, ServerMode::ChangedValidator) + && if_range.as_deref() != Some("\"etag-v2\"") + { + return full_response(&data, "\"etag-v2\""); + } + if matches!(mode, ServerMode::IgnoreRange) { + return full_response(&data, "\"etag-v1\""); + } + if offset >= data.len() { + return response( + StatusCode::RANGE_NOT_SATISFIABLE, + Body::empty(), + &[("content-range", format!("bytes */{}", data.len()))], + ); + } + if matches!(mode, ServerMode::InvalidRange) { + return response( + StatusCode::PARTIAL_CONTENT, + Body::from(data.as_slice().to_vec()), + &[( + "content-range", + format!("bytes 0-{}/{}", data.len() - 1, data.len()), + )], + ); + } + return response( + StatusCode::PARTIAL_CONTENT, + Body::from(data[offset..].to_vec()), + &[ + ( + "content-range", + format!("bytes {}-{}/{}", offset, data.len() - 1, data.len()), + ), + ("etag", "\"etag-v1\"".into()), + ], + ); + } + + full_response(&data, "\"etag-v1\"") + } + + fn response(status: StatusCode, body: Body, headers: &[(&str, String)]) -> HttpResponse { + let mut builder = HttpResponse::builder().status(status); + for (name, value) in headers { + builder = builder.header(*name, value); + } + builder.body(body).unwrap() + } + + fn full_response(data: &[u8], etag: &str) -> HttpResponse { + response( + StatusCode::OK, + Body::from(data.to_vec()), + &[ + ("content-length", data.len().to_string()), + ("etag", etag.to_string()), + ], + ) + } + + fn parse_range_offset(value: &str) -> Option { + value + .strip_prefix("bytes=")? + .strip_suffix('-')? + .parse() + .ok() + } + + async fn start_server( + data: Vec, + mode: ServerMode, + ) -> (String, TestServerState, tokio::task::JoinHandle<()>) { + let state = TestServerState { + data: Arc::new(data), + mode: Arc::new(Mutex::new(mode)), + requests: Arc::new(Mutex::new(Vec::new())), + }; + let app = Router::new() + .route("/update.zip", get(serve_file)) + .with_state(state.clone()); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + (format!("http://{}/update.zip", address), state, server) + } + + fn runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + } + + fn request(url: String, save_path: PathBuf, size: Option, key: &str) -> DownloadRequest { + DownloadRequest { + url, + save_path, + expected_size: size, + resume_key: Some(key.to_string()), + sha256: None, + session_id: 1, + } + } + + async fn run_download( + request: DownloadRequest, + cancelled: Arc, + ) -> Result { + download( + &Client::new(), + request, + Arc::new(AtomicU64::new(0)), + Arc::new(AtomicU64::new(0)), + Arc::new(AtomicU64::new(0)), + || cancelled.load(Ordering::SeqCst), + ) + .await + } + + fn seed_partial(path: &Path, bytes: &[u8], key: &str, expected_size: Option, etag: &str) { + let part_path = PathBuf::from(format!("{}.downloading", path.display())); + let metadata_path = PathBuf::from(format!("{}.json", part_path.display())); + std::fs::write(&part_path, bytes).unwrap(); + write_metadata_atomic( + &metadata_path, + &ResumeMetadata { + resume_key: key.to_string(), + expected_size, + sha256: None, + etag: Some(etag.to_string()), + last_modified: None, + detected_filename: Some("update.zip".to_string()), + }, + 99, + ) + .unwrap(); + } + + #[test] + fn resumes_after_interrupted_transfer() { + runtime().block_on(async { + let data = b"0123456789abcdef".to_vec(); + let (url, state, server) = start_server(data.clone(), ServerMode::Interrupt).await; + let temp = TempDir::new().unwrap(); + let save_path = temp.path().join("update.zip"); + let cancelled = Arc::new(AtomicBool::new(false)); + + let first = run_download( + request( + url.clone(), + save_path.clone(), + Some(data.len() as u64), + "v2", + ), + cancelled.clone(), + ) + .await; + assert!(first.is_err()); + let part_path = PathBuf::from(format!("{}.downloading", save_path.display())); + let partial_size = std::fs::metadata(&part_path).unwrap().len(); + assert!(partial_size > 0 && partial_size < data.len() as u64); + + *state.mode.lock().unwrap() = ServerMode::HonorRange; + run_download( + request(url, save_path.clone(), Some(data.len() as u64), "v2"), + cancelled, + ) + .await + .unwrap(); + + assert_eq!(std::fs::read(&save_path).unwrap(), data); + let requests = state.requests.lock().unwrap(); + assert_eq!( + requests[1].0.as_deref(), + Some(format!("bytes={}-", partial_size).as_str()) + ); + assert_eq!(requests[1].1.as_deref(), Some("\"etag-v1\"")); + server.abort(); + }); + } + + #[test] + fn restarts_when_server_ignores_or_invalidates_range() { + runtime().block_on(async { + for mode in [ + ServerMode::IgnoreRange, + ServerMode::InvalidRange, + ServerMode::ChangedValidator, + ] { + let data = b"complete update package".to_vec(); + let (url, state, server) = start_server(data.clone(), mode).await; + let temp = TempDir::new().unwrap(); + let save_path = temp.path().join("update.zip"); + seed_partial( + &save_path, + &data[..5], + "v3", + Some(data.len() as u64), + "\"etag-v1\"", + ); + + run_download( + request(url, save_path.clone(), Some(data.len() as u64), "v3"), + Arc::new(AtomicBool::new(false)), + ) + .await + .unwrap(); + assert_eq!(std::fs::read(&save_path).unwrap(), data); + assert!(state.requests.lock().unwrap()[0].0.is_some()); + if matches!(mode, ServerMode::InvalidRange) { + assert_eq!(state.requests.lock().unwrap().len(), 2); + assert!(state.requests.lock().unwrap()[1].0.is_none()); + } + server.abort(); + } + }); + } + + #[test] + fn handles_416_only_for_complete_partial() { + runtime().block_on(async { + let data = b"already complete".to_vec(); + let (url, state, server) = start_server(data.clone(), ServerMode::HonorRange).await; + let temp = TempDir::new().unwrap(); + let save_path = temp.path().join("update.zip"); + seed_partial( + &save_path, + &data, + "v4", + Some(data.len() as u64), + "\"etag-v1\"", + ); + run_download( + request( + url.clone(), + save_path.clone(), + Some(data.len() as u64), + "v4", + ), + Arc::new(AtomicBool::new(false)), + ) + .await + .unwrap(); + assert_eq!(std::fs::read(&save_path).unwrap(), data); + + let wrong_temp = TempDir::new().unwrap(); + let wrong_path = wrong_temp.path().join("wrong.zip"); + seed_partial( + &wrong_path, + b"oversized stale bytes", + "v5", + None, + "\"etag-v1\"", + ); + let result = run_download( + request(url, wrong_path.clone(), None, "v5"), + Arc::new(AtomicBool::new(false)), + ) + .await + .unwrap(); + assert_eq!(std::fs::read(result.actual_save_path).unwrap(), data); + assert!(state + .requests + .lock() + .unwrap() + .iter() + .any(|request| request.0.is_none())); + server.abort(); + }); + } + + #[test] + fn accepts_matching_sha256() { + runtime().block_on(async { + let data = b"verified update package".to_vec(); + let expected_sha256 = format!("{:x}", Sha256::digest(&data)); + let (url, _state, server) = start_server(data.clone(), ServerMode::HonorRange).await; + let temp = TempDir::new().unwrap(); + let save_path = temp.path().join("update.zip"); + let mut download_request = request( + url, + save_path.clone(), + Some(data.len() as u64), + "matching-checksum", + ); + download_request.sha256 = Some(expected_sha256); + + run_download(download_request, Arc::new(AtomicBool::new(false))) + .await + .unwrap(); + + assert_eq!(std::fs::read(save_path).unwrap(), data); + server.abort(); + }); + } + + #[test] + fn discards_incompatible_or_corrupt_partial() { + runtime().block_on(async { + let data = b"verified package".to_vec(); + let (url, state, server) = start_server(data.clone(), ServerMode::HonorRange).await; + let temp = TempDir::new().unwrap(); + let save_path = temp.path().join("update.zip"); + seed_partial( + &save_path, + &data[..4], + "old-version", + Some(data.len() as u64), + "\"etag-v1\"", + ); + run_download( + request( + url.clone(), + save_path.clone(), + Some(data.len() as u64), + "new-version", + ), + Arc::new(AtomicBool::new(false)), + ) + .await + .unwrap(); + assert!(state.requests.lock().unwrap()[0].0.is_none()); + + let corrupt_path = temp.path().join("corrupt.zip"); + let mut corrupt_request = request( + url, + corrupt_path.clone(), + Some(data.len() as u64), + "checksum", + ); + corrupt_request.sha256 = Some("0".repeat(64)); + assert!( + run_download(corrupt_request, Arc::new(AtomicBool::new(false))) + .await + .is_err() + ); + assert!(!corrupt_path.exists()); + assert!(!PathBuf::from(format!("{}.downloading", corrupt_path.display())).exists()); + server.abort(); + }); + } + + #[test] + fn active_cancel_removes_partial_and_metadata() { + runtime().block_on(async { + let data = b"a deliberately slow update package".to_vec(); + let (url, _state, server) = start_server(data.clone(), ServerMode::Slow).await; + let temp = TempDir::new().unwrap(); + let save_path = temp.path().join("update.zip"); + let cancelled = Arc::new(AtomicBool::new(false)); + let task_cancelled = cancelled.clone(); + let task_path = save_path.clone(); + let task = tokio::spawn(async move { + run_download( + request(url, task_path, Some(data.len() as u64), "cancelled"), + task_cancelled, + ) + .await + }); + tokio::time::sleep(Duration::from_millis(60)).await; + cancelled.store(true, Ordering::SeqCst); + assert!(task.await.unwrap().is_err()); + assert!(!PathBuf::from(format!("{}.downloading", save_path.display())).exists()); + assert!(!PathBuf::from(format!("{}.downloading.json", save_path.display())).exists()); + server.abort(); + }); + } +} diff --git a/src-tauri/src/commands/mod.rs b/src-tauri/src/commands/mod.rs index e137aece..dc5c2605 100644 --- a/src-tauri/src/commands/mod.rs +++ b/src-tauri/src/commands/mod.rs @@ -19,6 +19,7 @@ pub mod utils; pub mod app_config; pub mod download; +mod download_core; pub mod file_ops; pub mod maa_agent; pub mod maa_core; diff --git a/src-tauri/src/commands/update.rs b/src-tauri/src/commands/update.rs index 5c2fb955..ca47d2ea 100644 --- a/src-tauri/src/commands/update.rs +++ b/src-tauri/src/commands/update.rs @@ -434,7 +434,7 @@ fn copy_dir_recursive(src: &std::path::Path, dst: &std::path::Path) -> Result<() /// 更新完成后清理残留产物: /// 1. 删除 target_dir/changes.json(增量包标识,更新后无需保留) -/// 2. 删除 cache_dir 下所有 *.downloading 临时文件 +/// 2. 删除 cache_dir 下所有断点续传半成品及元数据 #[tauri::command] pub fn cleanup_update_artifacts(target_dir: String, cache_dir: String) -> Result<(), String> { // 删除 target_dir/changes.json @@ -446,7 +446,7 @@ pub fn cleanup_update_artifacts(target_dir: String, cache_dir: String) -> Result } } - // 删除 cache_dir 下所有 *.downloading 文件 + // 删除 cache_dir 下所有下载半成品、元数据和未完成的元数据临时文件 let cache_path = std::path::Path::new(&cache_dir); if cache_path.exists() { if let Ok(entries) = std::fs::read_dir(cache_path) { @@ -454,7 +454,10 @@ pub fn cleanup_update_artifacts(target_dir: String, cache_dir: String) -> Result let path = entry.path(); if path.is_file() { let name = path.file_name().unwrap_or_default().to_string_lossy(); - if name.ends_with(".downloading") { + if name.ends_with(".downloading") + || name.ends_with(".downloading.json") + || (name.contains(".downloading.json.") && name.ends_with(".tmp")) + { match std::fs::remove_file(&path) { Ok(()) => info!("已删除临时下载文件: {}", path.display()), Err(e) => warn!("删除临时下载文件失败(忽略): {}", e), diff --git a/src/App.tsx b/src/App.tsx index cea49cb9..b251b9ef 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -456,6 +456,8 @@ function App() { url: updateResult.downloadUrl, savePath, totalSize: updateResult.fileSize, + resumeKey: updateResult.resumeKey, + sha256: updateResult.sha256, proxySettings: proxyForDownload, onProgress: (progress: DownloadProgress) => { setDownloadProgress(progress); diff --git a/src/components/UpdatePanel.tsx b/src/components/UpdatePanel.tsx index daaf3b58..9490421c 100644 --- a/src/components/UpdatePanel.tsx +++ b/src/components/UpdatePanel.tsx @@ -71,6 +71,8 @@ export function UpdatePanel({ onClose, anchorRef }: UpdatePanelProps) { url: updateInfo.downloadUrl, savePath, totalSize: updateInfo.fileSize, + resumeKey: updateInfo.resumeKey, + sha256: updateInfo.sha256, proxySettings: proxyForDownload, onProgress: (progress: DownloadProgress) => { setDownloadProgress(progress); diff --git a/src/components/settings/UpdateSection.tsx b/src/components/settings/UpdateSection.tsx index 493ff1a6..de3081e0 100644 --- a/src/components/settings/UpdateSection.tsx +++ b/src/components/settings/UpdateSection.tsx @@ -142,6 +142,8 @@ export function UpdateSection() { url: info.downloadUrl, savePath, totalSize: info.fileSize, + resumeKey: info.resumeKey, + sha256: info.sha256, proxySettings: proxyForDownload, onProgress: (progress) => { setDownloadProgress(progress); diff --git a/src/services/proxyService.ts b/src/services/proxyService.ts index 8db61d11..27cf7be6 100644 --- a/src/services/proxyService.ts +++ b/src/services/proxyService.ts @@ -171,6 +171,8 @@ export async function downloadWithProxy( options?: { totalSize?: number; proxyUrl?: string | null; + resumeKey?: string; + sha256?: string; }, ): Promise { const hasProxy = options?.proxyUrl && options.proxyUrl.trim() !== ''; @@ -190,5 +192,7 @@ export async function downloadWithProxy( savePath, totalSize: options?.totalSize || null, proxyUrl: options?.proxyUrl || null, + resumeKey: options?.resumeKey || null, + sha256: options?.sha256 || null, }); } diff --git a/src/services/updateService.ts b/src/services/updateService.ts index 921e259e..dfee9066 100644 --- a/src/services/updateService.ts +++ b/src/services/updateService.ts @@ -13,7 +13,7 @@ import { openPath, openUrl } from '@tauri-apps/plugin-opener'; import * as semver from 'semver'; import { backupConfigBeforeUpdate } from './configService'; -import { downloadWithProxy } from './proxyService'; +import { downloadWithProxy, type DownloadResult } from './proxyService'; const log = loggers.app; @@ -21,6 +21,8 @@ const log = loggers.app; let isDownloading = false; // 下载是否被用户主动取消 let downloadCancelled = false; +// 当前 Rust 下载 Promise;主动取消必须等待它结束,才能安全开始下一次下载。 +let currentDownloadPromise: Promise | null = null; // 安装互斥锁,防止并发 installUpdate 导致目录竞争 let isInstalling = false; @@ -57,6 +59,7 @@ export async function cancelDownload(): Promise { // 标记为用户主动取消,让 downloadUpdate 知道不需要再重置状态 downloadCancelled = true; + const downloadPromise = currentDownloadPromise; try { // 调用 Rust 后端设置取消标志 await invoke('cancel_download', { savePath: currentDownloadPath }); @@ -64,9 +67,10 @@ export async function cancelDownload(): Promise { log.warn('取消下载失败:', error); } - // 立即重置状态,允许新的下载开始 - isDownloading = false; - currentDownloadPath = null; + // 等待下载流和磁盘写入线程结束;downloadUpdate 的 finally 负责重置状态。 + if (downloadPromise) { + await downloadPromise.catch(() => undefined); + } return true; } @@ -207,6 +211,26 @@ function getArch(): string { return 'amd64'; } +function withResumeKey(updateInfo: UpdateInfo, resourceId: string): UpdateInfo { + if (!updateInfo.downloadUrl || !updateInfo.downloadSource) { + return updateInfo; + } + return { + ...updateInfo, + resumeKey: JSON.stringify([ + resourceId, + updateInfo.downloadSource, + updateInfo.versionName, + updateInfo.channel || 'stable', + updateInfo.updateType || 'full', + getOS(), + getArch(), + updateInfo.filename || '', + updateInfo.fileSize || 0, + ]), + }; +} + export interface CheckUpdateOptions { resourceId: string; // mirrorchyan_rid currentVersion: string; // 当前版本 @@ -358,6 +382,7 @@ export async function checkUpdate(options: CheckUpdateOptions): Promise void; proxySettings?: ProxySettings; // 代理设置 + resumeKey?: string; + sha256?: string; } // 当前下载的保存路径,用于取消时清理临时文件 @@ -739,7 +766,7 @@ export async function downloadUpdate( return { success: false }; } - const { url, savePath, totalSize, onProgress, proxySettings } = options; + const { url, savePath, totalSize, onProgress, proxySettings, resumeKey, sha256 } = options; log.info(`开始下载更新: ${url}`); log.info(`保存路径: ${savePath}`); @@ -750,15 +777,19 @@ export async function downloadUpdate( // 设置进度监听器 let unlisten: (() => void) | null = null; + let downloadPromise: Promise | null = null; // 当前下载的 session ID,用于过滤旧下载的进度事件 let currentSessionId: number | null = null; try { // 使用统一的代理下载接口(内部已包含日志记录) - const downloadPromise = downloadWithProxy(url, savePath, { + downloadPromise = downloadWithProxy(url, savePath, { totalSize, proxyUrl: proxySettings?.url, + resumeKey, + sha256, }); + currentDownloadPromise = downloadPromise; // 监听 Rust 后端发送的下载进度事件 if (onProgress) { @@ -811,10 +842,11 @@ export async function downloadUpdate( if (unlisten) { unlisten(); } - // 只有在未被取消时才重置状态(取消时 cancelDownload 已经重置了) - if (!downloadCancelled) { + if (downloadPromise && currentDownloadPromise === downloadPromise) { + currentDownloadPromise = null; isDownloading = false; currentDownloadPath = null; + downloadCancelled = false; } } } @@ -839,10 +871,11 @@ export async function checkAndPrepareDownload( return null; } - const { githubUrl, cdk, channel, githubPat, projectName, proxyUrl, ...checkOptions } = options; + const { resourceId, githubUrl, cdk, channel, githubPat, projectName, proxyUrl, ...checkOptions } = + options; // 始终使用 Mirror酱 检查更新 - const updateInfo = await checkUpdate({ ...checkOptions, cdk, channel }); + const updateInfo = await checkUpdate({ ...checkOptions, resourceId, cdk, channel }); if (!updateInfo || !updateInfo.hasUpdate) { return updateInfo; @@ -851,7 +884,7 @@ export async function checkAndPrepareDownload( // 如果有 CDK 且返回了下载链接,直接使用 if (cdk && updateInfo.downloadUrl) { log.info('使用 Mirror酱 下载链接'); - return updateInfo; + return withResumeKey(updateInfo, resourceId); } // 如果有错误码(如 CDK 问题),不尝试 GitHub,直接返回更新信息(包含错误) @@ -872,13 +905,17 @@ export async function checkAndPrepareDownload( }); if (githubDownload) { - return { - ...updateInfo, - downloadUrl: githubDownload.url, - fileSize: githubDownload.size, - filename: githubDownload.filename, - downloadSource: 'github', - }; + return withResumeKey( + { + ...updateInfo, + downloadUrl: githubDownload.url, + fileSize: githubDownload.size, + filename: githubDownload.filename, + downloadSource: 'github', + sha256: undefined, + }, + resourceId, + ); } log.warn('GitHub 下载链接获取失败'); diff --git a/src/stores/types.ts b/src/stores/types.ts index 2448fc4a..ab90a93b 100644 --- a/src/stores/types.ts +++ b/src/stores/types.ts @@ -56,6 +56,10 @@ export interface UpdateInfo { fileSize?: number; filename?: string; downloadSource?: 'mirrorchyan' | 'github'; + /** 服务端提供的更新包 SHA-256,仅 MirrorChyan 下载可用。 */ + sha256?: string; + /** 标识可安全复用的同一更新包,用于断点续传。 */ + resumeKey?: string; // MirrorChyan API 错误信息 errorCode?: number; errorMessage?: string;