Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 89 additions & 11 deletions datacache/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,79 @@

logger = logging.getLogger(__name__)

# Number of bytes to read/write at a time when streaming a download to disk.
DEFAULT_CHUNK_SIZE = 2 ** 20 # 1 MB


def _content_length(header_value):
"""Parse a Content-Length header value into an int, or None if it's
absent or not a valid integer."""
if header_value is None:
return None
try:
return int(header_value)
except (TypeError, ValueError):
return None


def _stream_to_file(
download_url,
file_handle,
timeout=None,
chunk_size=DEFAULT_CHUNK_SIZE,
progress_callback=None):
"""
Stream the contents of `download_url` into an already-open binary file
handle, one chunk at a time, so the entire payload never has to be held in
memory at once.

If `progress_callback` is given it is called as
``progress_callback(bytes_downloaded, total_bytes)`` after each chunk is
written, where `total_bytes` is taken from the server's Content-Length
header (or None when the server doesn't report a size). This lets callers
drive e.g. a tqdm progress bar without datacache depending on tqdm.

Returns the total number of bytes written.
"""
bytes_downloaded = 0

def report(total_bytes):
if progress_callback is not None:
progress_callback(bytes_downloaded, total_bytes)

def _download(download_url, timeout=None):
if download_url.startswith("http"):
response = requests.get(download_url, timeout=timeout)
response = requests.get(download_url, timeout=timeout, stream=True)
response.raise_for_status()
return response.content
total_bytes = _content_length(response.headers.get("Content-Length"))
for chunk in response.iter_content(chunk_size=chunk_size):
if not chunk:
# skip keep-alive chunks that carry no data
continue
file_handle.write(chunk)
bytes_downloaded += len(chunk)
report(total_bytes)
else:
req = urllib.request.Request(download_url)
response = urllib.request.urlopen(req, data=None, timeout=timeout)
return response.read()
total_bytes = _content_length(response.headers.get("Content-Length"))
while True:
chunk = response.read(chunk_size)
if not chunk:
break
file_handle.write(chunk)
bytes_downloaded += len(chunk)
report(total_bytes)
return bytes_downloaded


def _download_to_temp_file(
download_url,
timeout=None,
base_name="download",
ext="tmp",
use_wget_if_available=False):
use_wget_if_available=False,
chunk_size=DEFAULT_CHUNK_SIZE,
progress_callback=None):

if not download_url:
raise ValueError("URL not provided")
Expand All @@ -57,8 +112,12 @@ def _download_to_temp_file(

def download_using_python():
with open(tmp_path, mode="w+b") as tmp_file:
tmp_file.write(
_download(download_url, timeout=timeout))
_stream_to_file(
download_url,
tmp_file,
timeout=timeout,
chunk_size=chunk_size,
progress_callback=progress_callback)

if not use_wget_if_available:
download_using_python()
Expand Down Expand Up @@ -91,7 +150,9 @@ def _download_and_decompress_if_necessary(
full_path,
download_url,
timeout=None,
use_wget_if_available=False):
use_wget_if_available=False,
chunk_size=DEFAULT_CHUNK_SIZE,
progress_callback=None):
"""
Downloads remote file at `download_url` to local file at `full_path`
"""
Expand All @@ -103,7 +164,9 @@ def _download_and_decompress_if_necessary(
timeout=timeout,
base_name=base_name,
ext=ext,
use_wget_if_available=use_wget_if_available)
use_wget_if_available=use_wget_if_available,
chunk_size=chunk_size,
progress_callback=progress_callback)

if download_url.endswith("zip") and not filename.endswith("zip"):
logger.info("Decompressing zip into %s...", filename)
Expand Down Expand Up @@ -159,7 +222,9 @@ def fetch_file(
subdir=None,
force=False,
timeout=None,
use_wget_if_available=False):
use_wget_if_available=False,
chunk_size=DEFAULT_CHUNK_SIZE,
progress_callback=None):
"""
Download a remote file and store it locally in a cache directory. Don't
download it again if it's already present (unless `force` is True.)
Expand Down Expand Up @@ -194,6 +259,17 @@ def fetch_file(
If the `wget` command is available, use that for download instead
of Python libraries (default True)

chunk_size : int, optional
Number of bytes to stream from the server to disk at a time when
using the Python downloader. Defaults to 1 MB.

progress_callback : callable, optional
If provided, called as ``progress_callback(bytes_downloaded,
total_bytes)`` after each chunk is written, where `total_bytes` is the
server-reported size or None if unknown. Lets callers render a progress
bar (e.g. tqdm) without datacache taking on that dependency. Only
applies to the Python downloader, not the optional `wget` path.

Returns the full path of the local file.
"""
filename = build_local_filename(download_url, filename, decompress)
Expand All @@ -204,7 +280,9 @@ def fetch_file(
full_path=full_path,
download_url=download_url,
timeout=timeout,
use_wget_if_available=use_wget_if_available)
use_wget_if_available=use_wget_if_available,
chunk_size=chunk_size,
progress_callback=progress_callback)
else:
logger.info("Cached file %s from URL %s", filename, download_url)
return full_path
Expand Down
101 changes: 101 additions & 0 deletions tests/test_streaming_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for streaming downloads and the progress_callback hook
(regression tests for https://github.com/openvax/datacache/issues/49).

These use local file:// URLs so they don't depend on the network.
"""

import os
from tempfile import NamedTemporaryFile, mkdtemp

from datacache.download import (
_stream_to_file,
_download_to_temp_file,
_download_and_decompress_if_necessary,
)


def _make_temp_file(contents=b""):
with NamedTemporaryFile(delete=False) as f:
f.write(contents)
return f.name


def _file_url(path):
return "file://" + path


def test_stream_to_file_writes_all_bytes_and_reports_progress():
contents = b"abc" * 5000 # 15000 bytes
src_path = _make_temp_file(contents)
dst_path = _make_temp_file()
progress = []
try:
with open(dst_path, "wb") as dst:
total = _stream_to_file(
_file_url(src_path),
dst,
chunk_size=4096,
progress_callback=lambda done, total: progress.append((done, total)))
assert total == len(contents)
with open(dst_path, "rb") as f:
assert f.read() == contents
# callback fired and the final report equals the full size
assert progress, "progress_callback was never called"
assert progress[-1][0] == len(contents)
# byte counts are monotonically non-decreasing
counts = [done for (done, _total) in progress]
assert counts == sorted(counts)
# more than one chunk was written for a 15kB file at 4kB chunks
assert len(progress) > 1
finally:
os.remove(src_path)
os.remove(dst_path)


def test_download_to_temp_file_streams_local_file():
contents = b"hello streaming world"
src_path = _make_temp_file(contents)
tmp_path = None
try:
tmp_path = _download_to_temp_file(_file_url(src_path))
with open(tmp_path, "rb") as f:
assert f.read() == contents
finally:
os.remove(src_path)
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)


def test_download_and_decompress_threads_progress_callback():
contents = b"x" * 1234
src_path = _make_temp_file(contents)
out_dir = mkdtemp()
out_path = os.path.join(out_dir, "out.bin")
progress = []
try:
_download_and_decompress_if_necessary(
full_path=out_path,
download_url=_file_url(src_path),
chunk_size=100,
progress_callback=lambda done, _total: progress.append(done))
with open(out_path, "rb") as f:
assert f.read() == contents
assert progress and progress[-1] == len(contents)
finally:
os.remove(src_path)
if os.path.exists(out_path):
os.remove(out_path)
os.rmdir(out_dir)
Loading