From 27d3ac188f72c0dcb6f78a804d96f340ad7453d5 Mon Sep 17 00:00:00 2001 From: popcornell Date: Sun, 19 May 2024 23:04:17 -0400 Subject: [PATCH 1/2] do not chunk when performing beamforming --- gss/core/enhancer.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/gss/core/enhancer.py b/gss/core/enhancer.py index 4f3962f..79f0fcf 100644 --- a/gss/core/enhancer.py +++ b/gss/core/enhancer.py @@ -351,18 +351,9 @@ def enhance_batch( distortion_mask = cp.sum(masks, axis=0) - target_mask logging.debug("Applying beamforming with computed masks") - X_hat = [] - for i in range(num_chunks): - st = i * chunk_size - en = min(F, (i + 1) * chunk_size) - X_hat_chunk = self.bf_block( - Obs[:, :, st:en], - target_mask=target_mask[:, st:en], - distortion_mask=distortion_mask[:, st:en], - ) - X_hat.append(X_hat_chunk) - - X_hat = cp.concatenate(X_hat, axis=1) # freq axis again + X_hat = self.bf_block( + Obs, target_mask=target_mask, distortion_mask=distortion_mask + ) logging.debug("Computing inverse STFT") x_hat = self.istft(X_hat) # returns a numpy array From 6daef2244d0c9a35e7f8d184c0a062d113bf9b37 Mon Sep 17 00:00:00 2001 From: popcornell Date: Sun, 19 May 2024 23:15:09 -0400 Subject: [PATCH 2/2] apply black --- gss/bin/modes/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gss/bin/modes/utils.py b/gss/bin/modes/utils.py index 6d995ec..99d01ad 100644 --- a/gss/bin/modes/utils.py +++ b/gss/bin/modes/utils.py @@ -56,7 +56,9 @@ def rttm_to_supervisions_(rttm_path, out_path, channels): def gpu_check_(num_jobs, cmd): if cmd == "run.pl" and num_jobs > 1: used_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") - assert num_jobs <= len(used_devices), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !" + assert num_jobs <= len( + used_devices + ), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !" for device in used_devices: grep_res = subprocess.check_output(("nvidia-smi", "-i", f"{device}", "-q")) check = re.findall("Compute Mode\s+:\sDefault", str(grep_res))