diff --git a/paderbox/io/download.py b/paderbox/io/download.py index 6f3b7677..d8bdcf1d 100644 --- a/paderbox/io/download.py +++ b/paderbox/io/download.py @@ -1,16 +1,16 @@ import os -import socket -import sys import tarfile import zipfile +import warnings from pathlib import Path from urllib.request import urlretrieve +from concurrent.futures import ProcessPoolExecutor from tqdm import tqdm -def download_file(remote_file, local_file, exist_ok=False): +def download_file(remote_file, local_file, exist_ok=False, extract=False): """ Download single file to local_dir @@ -18,117 +18,89 @@ def download_file(remote_file, local_file, exist_ok=False): remote_file: local_file: exist_ok: + extract: + progress_par: Returns: """ local_file = Path(local_file) if not local_file.exists(): - def progress_hook(t): - """ - https://raw.githubusercontent.com/tqdm/tqdm/master/examples/tqdm_wget.py - - Wraps tqdm instance. Don't forget to close() or __exit__() - the tqdm instance once you're done with it (easiest using - `with` syntax). - """ - - last_b = 0 - - def inner(b=1, bsize=1, tsize=None): - """ - b : int, optional - Number of blocks just transferred [default: 1]. - bsize : int, optional - Size of each block (in tqdm units) [default: 1]. - tsize : int, optional - Total size (in tqdm units). If [default: None] - remains unchanged. - """ - nonlocal last_b - if tsize is not None: - t.total = tsize - t.update((b - last_b) * bsize) - last_b = b - - return inner - tmp_file = str(local_file) + '.tmp' - with tqdm( - desc="{0: >25s}".format(Path(remote_file).stem), - file=sys.stdout, - unit='B', - unit_scale=True, - miniters=1, - leave=False, - ascii=True - ) as t: - urlretrieve( - str(remote_file), - filename=tmp_file, - reporthook=progress_hook(t), - data=None - ) + urlretrieve( + str(remote_file), + filename=tmp_file, + data=None + ) os.rename(tmp_file, local_file) elif not exist_ok: raise FileExistsError(local_file) + if extract: + extract_file(local_file, exist_ok=exist_ok) return local_file -def extract_file(local_file, exist_ok=False): +def extract_file(local_file, target_dir=None, exist_ok=False): """ If local_file is .zip or .tar.gz files are extracted. Args: local_file: + target_dir: exist_ok: Returns: """ local_file = Path(local_file) - local_dir = local_file.parent - if local_file.exists(): - - if local_file.name.endswith('.zip'): - with zipfile.ZipFile(local_file, "r") as z: - # Start extraction - members = z.infolist() - for i, member in enumerate(members): - target_file = local_dir / member.filename - if not target_file.exists(): - try: - z.extract(member=member, path=local_dir) - except KeyboardInterrupt: - # Delete latest file, since most likely it - # was not extracted fully - if target_file.exists(): - os.remove(target_file) - raise - elif not exist_ok: - raise FileExistsError(target_file) - os.remove(local_file) - - elif local_file.name.endswith('.tar.gz'): - with tarfile.open(local_file, "r:gz") as tar: - for i, tar_info in enumerate(tar): - target_file = local_dir / tar_info.name - if not target_file.exists(): - try: - tar.extract(tar_info, local_dir) - except KeyboardInterrupt: - # Delete latest file, since most likely it - # was not extracted fully - if target_file.exists(): - os.remove(target_file) - raise - elif not exist_ok: - raise FileExistsError(target_file) - tar.members = [] - os.remove(local_file) - - -def download_file_list(file_list, target_dir, exist_ok=False, logger=None): + assert local_file.exists(), local_file + if target_dir is None: + target_dir = local_file.parent + else: + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + if local_file.name.endswith('.zip'): + with zipfile.ZipFile(local_file, "r") as z: + # Start extraction + members = z.infolist() + for i, member in enumerate(members): + target_file = target_dir / member.filename + if not target_file.exists(): + try: + z.extract(member=member, path=target_dir) + except KeyboardInterrupt: + # Delete latest file, since most likely it + # was not extracted fully + if target_file.exists(): + os.remove(target_file) + raise + elif not exist_ok: + raise FileExistsError(target_file) + os.remove(local_file) + + elif local_file.name.endswith('.tar.gz') or local_file.name.endswith('.tar'): + mode = "r:gz" if local_file.name.endswith('.tar.gz') else "r" + with tarfile.open(local_file, mode) as tar: + for i, tar_info in enumerate(tar): + target_file = target_dir / tar_info.name + if not target_file.exists(): + try: + tar.extract(tar_info, target_dir) + except KeyboardInterrupt: + # Delete latest file, since most likely it + # was not extracted fully + if target_file.exists(): + os.remove(target_file) + raise + elif not exist_ok: + raise FileExistsError(target_file) + tar.members = [] + os.remove(local_file) + else: + warnings.warn("Unsupported file format: Cannot extract file.") + + +def download_file_list(file_list, target_dir, extract=True, exist_ok=False, num_workers=1): """ Download file_list to target_dir @@ -136,49 +108,36 @@ def download_file_list(file_list, target_dir, exist_ok=False, logger=None): file_list: target_dir: exist_ok: - logger: + extract: + num_workers: Returns: """ - target_dir = Path(target_dir) os.makedirs(target_dir, exist_ok=True) - item_progress = tqdm( - file_list, desc="{0: <25s}".format('Download files'), - file=sys.stdout, leave=False, ascii=True) - - local_files = list() - for remote_file in item_progress: - local_files.append( - download_file( - remote_file, - target_dir / Path(remote_file).name, - exist_ok=exist_ok - ) - ) - - item_progress = tqdm( - local_files, - desc="{0: <25s}".format('Extract files'), - file=sys.stdout, - leave=False, - ascii=True - ) - - if logger is not None: - logger.info('Starting Extraction') - for _id, local_file in enumerate(item_progress): - if local_file and local_file.exists(): - if logger is not None: - logger.info( - ' {title:<15s} [{item_id:d}/{total:d}] {package:<30s}' - .format( - title='Extract files ', - item_id=_id, - total=len(item_progress), - package=local_file - ) - ) - extract_file(local_file, exist_ok=exist_ok) + pbar = tqdm(initial=0, total=len(file_list)) + + if isinstance(extract, bool): + extract = len(file_list) * [extract] + assert len(extract) == len(file_list), (len(extract), len(file_list)) + if num_workers > 1: + with ProcessPoolExecutor(num_workers) as ex: + for _ in ex.map( + download_file, + file_list, + [target_dir / Path(f).name.split('?')[0] for f in file_list], # extract file names from urls discarding query strings + len(file_list) * [exist_ok], + extract, + ): + pbar.update(1) + else: + for _ in map( + download_file, + file_list, + [target_dir / Path(f).name.split('?')[0] for f in file_list], + len(file_list) * [exist_ok], + extract, + ): + pbar.update(1) diff --git a/paderbox/transform/module_resample.py b/paderbox/transform/module_resample.py index abfca548..d74155f9 100644 --- a/paderbox/transform/module_resample.py +++ b/paderbox/transform/module_resample.py @@ -97,7 +97,7 @@ def resample_sox(signal: np.ndarray, *, in_rate, out_rate, normalize=True): # input signal is much too large. # We normalize each channel independently to avoid rounding errors leading # to the channel doc test above to fail randomly. - normalizer = 0.95 / np.max(np.abs(signal), keepdims=True, axis=-1) + normalizer = 0.95 / np.maximum(np.max(np.abs(signal), keepdims=True, axis=-1), 1e-12) signal = normalizer * signal sox_type = { diff --git a/paderbox/transform/module_stft.py b/paderbox/transform/module_stft.py index f82629c0..34d12538 100644 --- a/paderbox/transform/module_stft.py +++ b/paderbox/transform/module_stft.py @@ -234,6 +234,8 @@ def _samples_to_stft_frames( 2 >>> _samples_to_stft_frames(21, 16, 4) 3 + >>> _samples_to_stft_frames(0, 16, 4) + 0 >>> stft(np.zeros(19), 16, 4, fading=None).shape (2, 9) @@ -275,12 +277,13 @@ def _samples_to_stft_frames( pad_width = (size - shift) samples = samples + (1 + (fading != 'half')) * pad_width - # I changed this from np.ceil to math.ceil, to yield an integer result. if pad: frames = (samples - size + shift + shift - 1) // shift else: frames = (samples - size + shift) // shift + frames = (frames > 0) * frames + return frames @@ -344,6 +347,8 @@ def sample_index_to_stft_frame_index(sample, window_length, shift, fading='full' >>> [sample_index_to_stft_frame_index(i, 8, 1, fading=None) for i in range(12)] [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8] + >>> sample_index_to_stft_frame_index(np.arange(12), 8, 1, fading=None).tolist() + [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8] >>> [sample_index_to_stft_frame_index(i, 8, 2, fading=None) for i in range(10)] [0, 0, 0, 0, 0, 1, 1, 2, 2, 3] >>> [sample_index_to_stft_frame_index(i, 7, 2, fading=None) for i in range(10)] @@ -404,6 +409,8 @@ def stft_frame_index_to_sample_index( >>> stft_frame_index_to_sample_index(1, 400, 160, mode='first', fading=None) 160 + >>> stft_frame_index_to_sample_index(np.ones(1), 400, 160, mode='first', fading=None) + array([160.]) >>> stft_frame_index_to_sample_index(1, 400, 160, mode='center', fading=None) 360 >>> stft_frame_index_to_sample_index(1, 400, 160, mode='last', fading=None) diff --git a/paderbox/utils/random_utils.py b/paderbox/utils/random_utils.py index 5639bccc..dbfe2afa 100644 --- a/paderbox/utils/random_utils.py +++ b/paderbox/utils/random_utils.py @@ -12,6 +12,7 @@ 'truncated_normal', 'log_truncated_normal', 'truncated_exponential', + 'choice', 'hermitian', 'pos_def_hermitian', 'Uniform', @@ -20,6 +21,7 @@ 'TruncatedNormal', 'LogTruncatedNormal', 'TruncatedExponential', + 'Choice' ] @@ -177,6 +179,21 @@ def _sample(self, shape): return truncexpon(self.truncation / self.scale, self.loc, self.scale).rvs(shape) +@dataclasses.dataclass +class Choice(_Sampler): + events: int = None + replace: bool = True + p: list = None + + def __post_init__(self): + super().__post_init__() + if self.events is None: + raise TypeError('Missing required argument events') + + def _sample(self, shape): + return np.random.choice(self.events, size=shape, replace=self.replace, p=self.p) + + def uniform(*shape, low=0., high=1., dtype=np.float64): """ @@ -425,6 +442,39 @@ def truncated_exponential(*shape, loc=0., scale=1., truncation=3., dtype=np.floa return TruncatedExponential(loc=loc, scale=scale, truncation=truncation, dtype=dtype)(*shape) +def choice(*shape, events, replace=True, p=None, dtype=np.float64): + """ + + Args: + *shape: + events: + replace: + dtype: + + Returns: + + >>> x = choice(events=2) + >>> x.ndim + 0 + >>> x.dtype + dtype('float64') + >>> x = choice(2, 3, events=2) + >>> x.shape, x.dtype + ((2, 3), dtype('float64')) + >>> x = choice(2, 3, events=2, dtype=np.complex128) + >>> x.shape, x.dtype + ((2, 3), dtype('complex128')) + >>> np.random.seed(2) + >>> x = choice(2, 3, events=[2, 4], p=[.1,.9]) + >>> x.shape, x.dtype + ((2, 3), dtype('float64')) + >>> x + array([[4., 2., 4.], + [4., 4., 4.]]) + """ + return Choice(events=events, replace=replace, p=p, dtype=dtype)(*shape) + + def hermitian(*shape, dtype=np.complex128): """ Assures a random positive-semidefinite hermitian matrix.