Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion gss/bin/modes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def rttm_to_supervisions_(rttm_path, out_path, channels):
def gpu_check_(num_jobs, cmd):
if cmd == "run.pl" and num_jobs > 1:
used_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
assert num_jobs <= len(used_devices), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !"
assert num_jobs <= len(
used_devices
), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !"
for device in used_devices:
grep_res = subprocess.check_output(("nvidia-smi", "-i", f"{device}", "-q"))
check = re.findall("Compute Mode\s+:\sDefault", str(grep_res))
Expand Down
15 changes: 3 additions & 12 deletions gss/core/enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,9 @@ def enhance_batch(
distortion_mask = cp.sum(masks, axis=0) - target_mask

logging.debug("Applying beamforming with computed masks")
X_hat = []
for i in range(num_chunks):
st = i * chunk_size
en = min(F, (i + 1) * chunk_size)
X_hat_chunk = self.bf_block(
Obs[:, :, st:en],
target_mask=target_mask[:, st:en],
distortion_mask=distortion_mask[:, st:en],
)
X_hat.append(X_hat_chunk)

X_hat = cp.concatenate(X_hat, axis=1) # freq axis again
X_hat = self.bf_block(
Obs, target_mask=target_mask, distortion_mask=distortion_mask
)

logging.debug("Computing inverse STFT")
x_hat = self.istft(X_hat) # returns a numpy array
Expand Down