From 0cc3a9b8693f365a44d5d690e2d38fd12205bdc7 Mon Sep 17 00:00:00 2001 From: Alex Rubinsteyn Date: Thu, 18 Jun 2026 13:37:44 -0400 Subject: [PATCH] Fix #49: stream downloads to disk + optional progress_callback Previously _download read the entire HTTP/FTP response into memory (response.content / response.read()) before writing it to the temp file, so downloading a large reference file (e.g. a ~200 MB HPA archive) spiked RAM by the full file size and gave callers no way to show progress. Replace _download with _stream_to_file, which streams the response to an open file handle one chunk at a time (requests stream=True + iter_content for http, chunked response.read for ftp/file), capping memory at one chunk. Add an optional progress_callback(bytes_downloaded, total_bytes) hook and chunk_size, threaded through _download_to_temp_file, _download_and_decompress_if_necessary and fetch_file, so consumers can drive a tqdm bar without datacache depending on tqdm. Adds network-free regression tests using local file:// URLs. Claude-Session: https://claude.ai/code/session_011bzfZPTzWnhAMVD7msyMg1 --- datacache/download.py | 100 ++++++++++++++++++++++++++---- tests/test_streaming_download.py | 101 +++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 11 deletions(-) create mode 100644 tests/test_streaming_download.py diff --git a/datacache/download.py b/datacache/download.py index f3d3a61..6abe7e5 100644 --- a/datacache/download.py +++ b/datacache/download.py @@ -27,16 +27,69 @@ logger = logging.getLogger(__name__) +# Number of bytes to read/write at a time when streaming a download to disk. 
+DEFAULT_CHUNK_SIZE = 2 ** 20 # 1 MiB + + +def _content_length(header_value): + """Parse a Content-Length header value into an int, or None if it's + absent or not a valid integer.""" + if header_value is None: + return None + try: + return int(header_value) + except (TypeError, ValueError): + return None + + +def _stream_to_file( + download_url, + file_handle, + timeout=None, + chunk_size=DEFAULT_CHUNK_SIZE, + progress_callback=None): + """ + Stream the contents of `download_url` into an already-open binary file + handle, one chunk at a time, so the entire payload never has to be held in + memory at once. + + If `progress_callback` is given it is called as + ``progress_callback(bytes_downloaded, total_bytes)`` after each chunk is + written, where `total_bytes` is taken from the server's Content-Length + header (or None when the server doesn't report a size). This lets callers + drive e.g. a tqdm progress bar without datacache depending on tqdm. + + Returns the total number of bytes written. 
+ """ + bytes_downloaded = 0 + + def report(total_bytes): + if progress_callback is not None: + progress_callback(bytes_downloaded, total_bytes) -def _download(download_url, timeout=None): if download_url.startswith("http"): - response = requests.get(download_url, timeout=timeout) + response = requests.get(download_url, timeout=timeout, stream=True) response.raise_for_status() - return response.content + total_bytes = _content_length(response.headers.get("Content-Length")) + for chunk in response.iter_content(chunk_size=chunk_size): + if not chunk: + # skip keep-alive chunks that carry no data + continue + file_handle.write(chunk) + bytes_downloaded += len(chunk) + report(total_bytes) else: req = urllib.request.Request(download_url) response = urllib.request.urlopen(req, data=None, timeout=timeout) - return response.read() + total_bytes = _content_length(response.headers.get("Content-Length")) + while True: + chunk = response.read(chunk_size) + if not chunk: + break + file_handle.write(chunk) + bytes_downloaded += len(chunk) + report(total_bytes) + return bytes_downloaded def _download_to_temp_file( @@ -44,7 +97,9 @@ def _download_to_temp_file( timeout=None, base_name="download", ext="tmp", - use_wget_if_available=False): + use_wget_if_available=False, + chunk_size=DEFAULT_CHUNK_SIZE, + progress_callback=None): if not download_url: raise ValueError("URL not provided") @@ -57,8 +112,12 @@ def _download_to_temp_file( def download_using_python(): with open(tmp_path, mode="w+b") as tmp_file: - tmp_file.write( - _download(download_url, timeout=timeout)) + _stream_to_file( + download_url, + tmp_file, + timeout=timeout, + chunk_size=chunk_size, + progress_callback=progress_callback) if not use_wget_if_available: download_using_python() @@ -91,7 +150,9 @@ def _download_and_decompress_if_necessary( full_path, download_url, timeout=None, - use_wget_if_available=False): + use_wget_if_available=False, + chunk_size=DEFAULT_CHUNK_SIZE, + progress_callback=None): """ Downloads 
remote file at `download_url` to local file at `full_path` """ @@ -103,7 +164,9 @@ def _download_and_decompress_if_necessary( timeout=timeout, base_name=base_name, ext=ext, - use_wget_if_available=use_wget_if_available) + use_wget_if_available=use_wget_if_available, + chunk_size=chunk_size, + progress_callback=progress_callback) if download_url.endswith("zip") and not filename.endswith("zip"): logger.info("Decompressing zip into %s...", filename) @@ -159,7 +222,9 @@ def fetch_file( subdir=None, force=False, timeout=None, - use_wget_if_available=False): + use_wget_if_available=False, + chunk_size=DEFAULT_CHUNK_SIZE, + progress_callback=None): """ Download a remote file and store it locally in a cache directory. Don't download it again if it's already present (unless `force` is True.) @@ -194,6 +259,17 @@ def fetch_file( If the `wget` command is available, use that for download instead of Python libraries (default True) + chunk_size : int, optional + Number of bytes to stream from the server to disk at a time when + using the Python downloader. Defaults to 1 MiB. + + progress_callback : callable, optional + If provided, called as ``progress_callback(bytes_downloaded, + total_bytes)`` after each chunk is written, where `total_bytes` is the + server-reported size or None if unknown. Lets callers render a progress + bar (e.g. tqdm) without datacache taking on that dependency. Only + applies to the Python downloader, not the optional `wget` path. + Returns the full path of the local file. 
""" filename = build_local_filename(download_url, filename, decompress) @@ -204,7 +280,9 @@ def fetch_file( full_path=full_path, download_url=download_url, timeout=timeout, - use_wget_if_available=use_wget_if_available) + use_wget_if_available=use_wget_if_available, + chunk_size=chunk_size, + progress_callback=progress_callback) else: logger.info("Cached file %s from URL %s", filename, download_url) return full_path diff --git a/tests/test_streaming_download.py b/tests/test_streaming_download.py new file mode 100644 index 0000000..c77a109 --- /dev/null +++ b/tests/test_streaming_download.py @@ -0,0 +1,101 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for streaming downloads and the progress_callback hook +(regression tests for https://github.com/openvax/datacache/issues/49). + +These use local file:// URLs so they don't depend on the network. 
+""" + +import os +from tempfile import NamedTemporaryFile, mkdtemp + +from datacache.download import ( + _stream_to_file, + _download_to_temp_file, + _download_and_decompress_if_necessary, +) + + +def _make_temp_file(contents=b""): + with NamedTemporaryFile(delete=False) as f: + f.write(contents) + return f.name + + +def _file_url(path): + return "file://" + path + + +def test_stream_to_file_writes_all_bytes_and_reports_progress(): + contents = b"abc" * 5000 # 15000 bytes + src_path = _make_temp_file(contents) + dst_path = _make_temp_file() + progress = [] + try: + with open(dst_path, "wb") as dst: + total = _stream_to_file( + _file_url(src_path), + dst, + chunk_size=4096, + progress_callback=lambda done, total: progress.append((done, total))) + assert total == len(contents) + with open(dst_path, "rb") as f: + assert f.read() == contents + # callback fired and the final report equals the full size + assert progress, "progress_callback was never called" + assert progress[-1][0] == len(contents) + # byte counts are monotonically non-decreasing + counts = [done for (done, _total) in progress] + assert counts == sorted(counts) + # more than one chunk was written for a 15kB file at 4kB chunks + assert len(progress) > 1 + finally: + os.remove(src_path) + os.remove(dst_path) + + +def test_download_to_temp_file_streams_local_file(): + contents = b"hello streaming world" + src_path = _make_temp_file(contents) + tmp_path = None + try: + tmp_path = _download_to_temp_file(_file_url(src_path)) + with open(tmp_path, "rb") as f: + assert f.read() == contents + finally: + os.remove(src_path) + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_download_and_decompress_threads_progress_callback(): + contents = b"x" * 1234 + src_path = _make_temp_file(contents) + out_dir = mkdtemp() + out_path = os.path.join(out_dir, "out.bin") + progress = [] + try: + _download_and_decompress_if_necessary( + full_path=out_path, + download_url=_file_url(src_path), + 
chunk_size=100, + progress_callback=lambda done, _total: progress.append(done)) + with open(out_path, "rb") as f: + assert f.read() == contents + assert progress and progress[-1] == len(contents) + finally: + os.remove(src_path) + if os.path.exists(out_path): + os.remove(out_path) + os.rmdir(out_dir)