Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 88 additions & 129 deletions paderbox/io/download.py
Original file line number Diff line number Diff line change
@@ -1,184 +1,143 @@

import os
import socket
import sys
import tarfile
import zipfile
import warnings
from pathlib import Path
from urllib.request import urlretrieve
from concurrent.futures import ProcessPoolExecutor

from tqdm import tqdm


def download_file(remote_file, local_file, exist_ok=False, extract=False):
    """
    Download a single remote file to ``local_file``.

    The data is first written to ``<local_file>.tmp`` and only moved to
    ``local_file`` after the transfer finished, so an aborted download never
    leaves a truncated file under the final name.

    Args:
        remote_file: URL of the file to download (converted with ``str``).
        local_file: Destination path of the downloaded file.
        exist_ok: If False, raise ``FileExistsError`` when ``local_file``
            already exists. If True, an existing file is kept and the
            download is skipped.
        extract: If True, pass the local file to ``extract_file`` after the
            download (or after an existing file was accepted).

    Returns:
        ``local_file`` as a ``pathlib.Path``.

    Raises:
        FileExistsError: If ``local_file`` exists and ``exist_ok`` is False.

    """
    local_file = Path(local_file)
    if not local_file.exists():
        tmp_file = str(local_file) + '.tmp'
        try:
            urlretrieve(
                str(remote_file),
                filename=tmp_file,
                data=None
            )
            # os.replace also works on Windows if the destination appeared
            # in the meantime (e.g. a parallel worker fetched the same file).
            os.replace(tmp_file, local_file)
        except BaseException:
            # Covers KeyboardInterrupt as well: do not leave a partial
            # download behind.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise
    elif not exist_ok:
        raise FileExistsError(local_file)
    if extract:
        extract_file(local_file, exist_ok=exist_ok)
    return local_file


def _remove_partial_file(path):
    """Delete ``path`` if it is a regular file left by an aborted extraction."""
    if path.is_file():
        os.remove(path)


def _ensure_inside(root, path):
    """Raise ``ValueError`` if ``path`` does not resolve to a location in ``root``."""
    root = root.resolve()
    resolved = path.resolve()
    if resolved != root and root not in resolved.parents:
        raise ValueError(
            'Archive member would be extracted outside of {}: {}'.format(
                root, path
            )
        )


def extract_file(local_file, target_dir=None, exist_ok=False):
    """
    Extract ``local_file`` if it is a .zip, .tar.gz or .tar archive.

    After a successful extraction the archive itself is deleted. For any
    other file name a warning is issued and the file is left untouched.

    Args:
        local_file: Path to the archive.
        target_dir: Directory to extract into. It is created if necessary.
            Defaults to the directory that contains ``local_file``.
        exist_ok: If False, raise ``FileExistsError`` when a file of the
            archive already exists in ``target_dir``. If True, existing
            files are kept and skipped. Existing directories are always
            accepted.

    Returns:
        None

    Raises:
        FileNotFoundError: If ``local_file`` does not exist.
        FileExistsError: If a file to extract already exists and
            ``exist_ok`` is False.
        ValueError: If a tar member would be written outside ``target_dir``.

    """
    local_file = Path(local_file)
    if not local_file.exists():
        raise FileNotFoundError(local_file)
    if target_dir is None:
        target_dir = local_file.parent
    else:
        target_dir = Path(target_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

    archive_name = local_file.name
    if archive_name.endswith('.zip'):
        with zipfile.ZipFile(local_file, "r") as z:
            for member in z.infolist():
                target_file = target_dir / member.filename
                if not target_file.exists():
                    try:
                        z.extract(member=member, path=target_dir)
                    except KeyboardInterrupt:
                        # Delete latest file, since most likely it
                        # was not extracted fully
                        _remove_partial_file(target_file)
                        raise
                elif not (exist_ok or target_file.is_dir()):
                    # Directories may legitimately exist already, e.g. when
                    # they were created implicitly by an earlier member.
                    raise FileExistsError(target_file)
        os.remove(local_file)

    elif archive_name.endswith(('.tar.gz', '.tar')):
        mode = "r:gz" if archive_name.endswith('.tar.gz') else "r"
        # Use the safe extraction filter where the running Python offers it
        # (feature detection as recommended by the tarfile documentation).
        extract_kwargs = (
            {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
        )
        with tarfile.open(local_file, mode) as tar:
            for tar_info in tar:
                target_file = target_dir / tar_info.name
                # Downloaded archives are untrusted: refuse "../" and
                # absolute member names. NOTE(review): symlink members are
                # only covered by the 'data' filter, not by this check.
                _ensure_inside(target_dir, target_file)
                if not target_file.exists():
                    try:
                        tar.extract(tar_info, target_dir, **extract_kwargs)
                    except KeyboardInterrupt:
                        # Delete latest file, since most likely it
                        # was not extracted fully
                        _remove_partial_file(target_file)
                        raise
                elif not (exist_ok or target_file.is_dir()):
                    raise FileExistsError(target_file)
                # Drop the member cache so that memory does not grow with
                # the number of entries of large archives.
                tar.members = []
        os.remove(local_file)
    else:
        warnings.warn("Unsupported file format: Cannot extract file.")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you didn't add an else branch that raises? If not, could you emit a warning when the file has an unsupported suffix?


def download_file_list(
        file_list, target_dir, extract=True, exist_ok=False, num_workers=1
):
    """
    Download all files of ``file_list`` to ``target_dir``.

    Each file is stored in ``target_dir`` under the last component of the
    URL path; query strings and fragments are not part of the file name.

    Args:
        file_list: URLs of the files to download.
        target_dir: Directory for the downloaded files. Created if missing.
        extract: Whether to extract the downloaded files. Either a single
            bool for all files or a sequence with one bool per file.
        exist_ok: Forwarded to ``download_file`` for every file.
        num_workers: If greater than 1, the downloads are distributed over
            a ``ProcessPoolExecutor`` with this many workers. Otherwise the
            files are downloaded sequentially in this process.

    Returns:
        None

    Raises:
        ValueError: If ``extract`` is a sequence whose length differs from
            the length of ``file_list``.

    """
    # Function-scope import: urlsplit is not among the module-level imports.
    from urllib.parse import urlsplit

    file_list = list(file_list)
    target_dir = Path(target_dir)
    os.makedirs(target_dir, exist_ok=True)

    if isinstance(extract, bool):
        extract = len(file_list) * [extract]
    if len(extract) != len(file_list):
        raise ValueError(
            'extract needs one entry per file, got {} entries for {} '
            'files'.format(len(extract), len(file_list))
        )

    # Take the name from the URL path only, so that neither a query string
    # (even one that contains "/") nor a fragment ends up in the file name.
    local_files = [
        target_dir / Path(urlsplit(str(f)).path).name for f in file_list
    ]
    exist_ok_flags = len(file_list) * [exist_ok]

    def run(map_fn):
        # The context manager closes the progress bar, also on errors.
        with tqdm(initial=0, total=len(file_list)) as pbar:
            for _ in map_fn(
                    download_file,
                    file_list,
                    local_files,
                    exist_ok_flags,
                    extract,
            ):
                pbar.update(1)

    if num_workers > 1:
        with ProcessPoolExecutor(num_workers) as ex:
            run(ex.map)
    else:
        run(map)
2 changes: 1 addition & 1 deletion paderbox/transform/module_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def resample_sox(signal: np.ndarray, *, in_rate, out_rate, normalize=True):
# input signal is much too large.
# We normalize each channel independently to avoid rounding errors leading
# to the channel doc test above to fail randomly.
normalizer = 0.95 / np.max(np.abs(signal), keepdims=True, axis=-1)
normalizer = 0.95 / np.maximum(np.max(np.abs(signal), keepdims=True, axis=-1), 1e-12)
signal = normalizer * signal

sox_type = {
Expand Down
9 changes: 8 additions & 1 deletion paderbox/transform/module_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def _samples_to_stft_frames(
2
>>> _samples_to_stft_frames(21, 16, 4)
3
>>> _samples_to_stft_frames(0, 16, 4)
0

>>> stft(np.zeros(19), 16, 4, fading=None).shape
(2, 9)
Expand Down Expand Up @@ -275,12 +277,13 @@ def _samples_to_stft_frames(
pad_width = (size - shift)
samples = samples + (1 + (fading != 'half')) * pad_width

# I changed this from np.ceil to math.ceil, to yield an integer result.
if pad:
frames = (samples - size + shift + shift - 1) // shift
else:
frames = (samples - size + shift) // shift

frames = (frames > 0) * frames

return frames


Expand Down Expand Up @@ -344,6 +347,8 @@ def sample_index_to_stft_frame_index(sample, window_length, shift, fading='full'

>>> [sample_index_to_stft_frame_index(i, 8, 1, fading=None) for i in range(12)]
[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]
>>> sample_index_to_stft_frame_index(np.arange(12), 8, 1, fading=None).tolist()
[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]
>>> [sample_index_to_stft_frame_index(i, 8, 2, fading=None) for i in range(10)]
[0, 0, 0, 0, 0, 1, 1, 2, 2, 3]
>>> [sample_index_to_stft_frame_index(i, 7, 2, fading=None) for i in range(10)]
Expand Down Expand Up @@ -404,6 +409,8 @@ def stft_frame_index_to_sample_index(

>>> stft_frame_index_to_sample_index(1, 400, 160, mode='first', fading=None)
160
>>> stft_frame_index_to_sample_index(np.ones(1), 400, 160, mode='first', fading=None)
array([160.])
>>> stft_frame_index_to_sample_index(1, 400, 160, mode='center', fading=None)
360
>>> stft_frame_index_to_sample_index(1, 400, 160, mode='last', fading=None)
Expand Down
Loading