diff --git a/gss/bin/modes/utils.py b/gss/bin/modes/utils.py index 6d995ec..99d01ad 100644 --- a/gss/bin/modes/utils.py +++ b/gss/bin/modes/utils.py @@ -56,7 +56,9 @@ def rttm_to_supervisions_(rttm_path, out_path, channels): def gpu_check_(num_jobs, cmd): if cmd == "run.pl" and num_jobs > 1: used_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") - assert num_jobs <= len(used_devices), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !" + assert num_jobs <= len( + used_devices + ), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !" for device in used_devices: grep_res = subprocess.check_output(("nvidia-smi", "-i", f"{device}", "-q")) check = re.findall("Compute Mode\s+:\sDefault", str(grep_res)) diff --git a/gss/core/enhancer.py b/gss/core/enhancer.py index 4f3962f..79f0fcf 100644 --- a/gss/core/enhancer.py +++ b/gss/core/enhancer.py @@ -351,18 +351,9 @@ def enhance_batch( distortion_mask = cp.sum(masks, axis=0) - target_mask logging.debug("Applying beamforming with computed masks") - X_hat = [] - for i in range(num_chunks): - st = i * chunk_size - en = min(F, (i + 1) * chunk_size) - X_hat_chunk = self.bf_block( - Obs[:, :, st:en], - target_mask=target_mask[:, st:en], - distortion_mask=distortion_mask[:, st:en], - ) - X_hat.append(X_hat_chunk) - - X_hat = cp.concatenate(X_hat, axis=1) # freq axis again + X_hat = self.bf_block( + Obs, target_mask=target_mask, distortion_mask=distortion_mask + ) logging.debug("Computing inverse STFT") x_hat = self.istft(X_hat) # returns a numpy array