Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions gss/bin/modes/enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ def common_options(func):
help="Channels to use for enhancement. Specify with comma-separated values, e.g. "
"`--channels 0,2,4`. All channels will be used by default.",
)
@click.option(
"--select-channels-by-count",
type=int,
default=None,
help="Number of channels to select for enhancement. If specified, we will use an envelope "
"variance based method to select the best channels. If `--channels` is also specified, we will "
"select the best channels from the specified channels.",
)
@click.option(
"--select-channels-by-ratio",
type=click.FloatRange(0.0, 1.0),
default=None,
help="Ratio of channels to select for enhancement. If specified, we will use an envelope "
"variance based method to select the best channels. If `--channels` is also specified, we will "
"select the best channels from the specified channels. Note that we will use at least 2 channels "
"and at most the number of channels given by `--channels`.",
)
@click.option(
"--bss-iterations",
"-i",
Expand Down Expand Up @@ -157,6 +174,8 @@ def cuts_(
cuts_per_segment,
enhanced_dir,
channels,
select_channels_by_count,
select_channels_by_ratio,
bss_iterations,
use_wpe,
context_duration,
Expand Down Expand Up @@ -195,6 +214,10 @@ def exit():

atexit.register(exit)

assert not (
select_channels_by_count and select_channels_by_ratio
), "Please specify at most one of --select-channels-by-count and --select-channels-by-ratio"

if duration_tolerance is not None:
set_audio_duration_mismatch_tolerance(duration_tolerance)

Expand Down Expand Up @@ -237,6 +260,8 @@ def exit():
num_workers=num_workers,
num_buckets=num_buckets,
force_overwrite=force_overwrite,
select_channels_by_count=select_channels_by_count,
select_channels_by_ratio=select_channels_by_ratio,
)
end = time.time()
logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors")
Expand Down Expand Up @@ -273,6 +298,8 @@ def recording_(
enhanced_dir,
recording_id,
channels,
select_channels_by_count,
select_channels_by_ratio,
bss_iterations,
use_wpe,
context_duration,
Expand Down Expand Up @@ -310,6 +337,10 @@ def exit():

atexit.register(exit)

assert not (
select_channels_by_count and select_channels_by_ratio
), "Please specify at most one of --select-channels-by-count and --select-channels-by-ratio"

enhanced_dir = Path(enhanced_dir)
enhanced_dir.mkdir(exist_ok=True, parents=True)

Expand Down Expand Up @@ -359,6 +390,8 @@ def exit():
num_workers=num_workers,
num_buckets=num_buckets,
force_overwrite=force_overwrite,
select_channels_by_count=select_channels_by_count,
select_channels_by_ratio=select_channels_by_ratio,
)
end = time.time()
logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors")
Expand Down
1 change: 1 addition & 0 deletions gss/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .activity import Activity
from .beamformer import Beamformer
from .channel_selection import EnvelopeVarianceChannelSelector
from .gss import GSS
from .stft_module import istft, stft
from .wpe import WPE
77 changes: 77 additions & 0 deletions gss/core/channel_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from dataclasses import dataclass

import cupy as cp
from lhotse.utils import EPSILON

from gss.core.stft_module import mel_scale


@dataclass
class EnvelopeVarianceChannelSelector:
    """
    Envelope Variance (EV) channel selection.

    Scores every channel of a multi-channel spectrogram by the variance of its
    cube-root compressed, mean-normalized mel sub-band envelopes and keeps the
    highest scoring channels. Long inputs are scored in overlapping chunks and
    the per-chunk scores are averaged.
    """

    # Number of mel sub-bands used for the envelopes.
    n_mels: int = 40
    # Forwarded to `mel_scale` as `n_freqs`; it has to match the size of the
    # last (frequency) axis of the array given to `__call__`.
    n_fft: int = 1024
    # STFT hop in samples; only used to convert the chunk durations to frames.
    hop_length: int = 256
    # Sampling rate in Hz.
    sampling_rate: int = 16000
    # Chunk length and chunk hop, given in seconds. NOTE: `__post_init__`
    # overwrites both attributes with the equivalent number of frames.
    chunk_size: float = 4
    chunk_stride: float = 2

    def __post_init__(self):
        # Uniform weighting of all mel sub-bands.
        self.subband_weights = cp.ones(self.n_mels)
        # seconds -> frames
        self.chunk_size = int(self.chunk_size * self.sampling_rate / self.hop_length)
        self.chunk_stride = int(
            self.chunk_stride * self.sampling_rate / self.hop_length
        )
        # Mel filterbank; used in `__call__` as a (freq, n_mels) matrix.
        self.fb = mel_scale(
            n_freqs=self.n_fft, n_mels=self.n_mels, sample_rate=self.sampling_rate
        )

    def _single_window(self, mels):
        """
        Score all channels inside a single analysis window.

        Args:
            mels: (channels, n_mels, frames) mel sub-band energies.

        Returns:
            Array of shape (channels,); a larger value means a better channel.
        """
        log_env = cp.log(mels + EPSILON)
        # Subtracting the mean in the log domain removes the per-band gain of
        # every channel before the variance is measured.
        norm_env = cp.exp(log_env - cp.mean(log_env, axis=-1, keepdims=True))
        band_var = cp.var(norm_env ** (1 / 3), axis=-1)  # (channels, n_mels)
        # Normalize every sub-band by its maximum over the *channels* (axis 0)
        # so that the channels are compared against each other band by band.
        # Normalizing over the sub-band axis instead would rescale each channel
        # independently and turn a silent channel (zero variance) into NaN,
        # which the descending argsort in `__call__` would then rank first.
        peak = cp.amax(band_var, axis=0, keepdims=True)
        # The floor avoids 0 / 0 when a sub-band is constant in all channels.
        band_var = band_var / cp.maximum(peak, EPSILON)
        weights = cp.abs(self.subband_weights)
        return cp.sum(band_var * weights, axis=-1)

    def _count_chunks(self, inlen, chunk_size, chunk_stride):
        """
        Return how many full chunks of `chunk_size` frames, spaced
        `chunk_stride` frames apart, fit into `inlen` frames.
        """
        return int((inlen - chunk_size + chunk_stride) / chunk_stride)

    def _get_chunks_indx(self, in_len, chunk_size, chunk_stride, discard_last=False):
        """
        Yield `(start, stop)` frame indices of the analysis chunks.

        All full chunks are yielded first. Unless `discard_last` is set, the
        frames not covered by the last full chunk are yielded as one final,
        shorter chunk that starts one stride after the last full chunk.
        """
        last = -1  # stays -1 when not even one full chunk fits
        for last in range(self._count_chunks(in_len, chunk_size, chunk_stride)):
            start = last * chunk_stride
            yield start, start + chunk_size
        if discard_last:
            return
        tail_start = (last + 1) * chunk_stride
        if last * chunk_stride + chunk_size < in_len and in_len - tail_start > 0:
            yield tail_start, in_len

    def __call__(self, obs, num_channels):
        """
        Rank the channels of `obs` and keep the best `num_channels`.

        Args:
            obs: (channels, time, freq). Complex input is converted to power
                before the mel filterbank is applied; real input is used as is.
            num_channels: number of channels to keep. `None` keeps all
                channels (reordered from best to worst).

        Returns:
            Tuple `(obs[selected_channels], selected_channels)` where
            `selected_channels` holds the kept channel indices, best first.

        Raises:
            ValueError: if `num_channels` is given but smaller than 1.
        """
        assert obs.ndim == 3, "obs must have shape (channels, time, freq)"
        if num_channels is not None and num_channels < 1:
            raise ValueError(
                f"num_channels must be a positive integer, got {num_channels}"
            )

        # The mel filterbank operates on energies. NOTE(review): the enhancer
        # appears to pass the raw STFT here, which is presumably complex; in
        # that case use |X|^2 instead of the complex coefficients.
        power = cp.abs(obs) ** 2 if cp.iscomplexobj(obs) else obs

        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
        mels = cp.matmul(power, self.fb).transpose(0, 2, 1)

        num_frames = mels.shape[-1]
        if num_frames > (self.chunk_size + self.chunk_stride):
            # A plain Python loop keeps the statistics free of padded values
            # (the last chunk may be shorter than the others).
            window_scores = [
                self._single_window(mels[..., start:stop])
                for start, stop in self._get_chunks_indx(
                    num_frames, self.chunk_size, self.chunk_stride
                )
            ]
            scores = cp.stack(window_scores, axis=0).mean(axis=0)
        else:
            scores = self._single_window(mels)

        # Descending order: highest envelope variance first.
        channel_ranks = cp.argsort(scores)[::-1]
        selected_channels = channel_ranks[:num_channels]
        return obs[selected_channels], selected_channels
32 changes: 30 additions & 2 deletions gss/core/enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lhotse.utils import add_durations, compute_num_samples
from torch.utils.data import DataLoader

from gss.core import GSS, WPE, Activity, Beamformer
from gss.core import GSS, WPE, Activity, Beamformer, EnvelopeVarianceChannelSelector
from gss.utils.data_utils import (
GssDataset,
activity_time_to_frequency,
Expand Down Expand Up @@ -62,6 +62,9 @@ def get_enhancer(
garbage_class=activity_garbage_class,
cuts=cuts,
),
channel_selector=EnvelopeVarianceChannelSelector(
n_fft=stft_size // 2 + 1, hop_length=stft_shift, sampling_rate=sampling_rate
),
gss_block=GSS(
iterations=bss_iterations,
iterations_post=bss_iterations_post,
Expand All @@ -86,6 +89,7 @@ class Enhancer:

wpe_block: WPE
activity: Activity
channel_selector: EnvelopeVarianceChannelSelector
gss_block: GSS
bf_block: Beamformer

Expand Down Expand Up @@ -127,6 +131,8 @@ def enhance_cuts(
num_buckets=2,
num_workers=1,
force_overwrite=False,
select_channels_by_count=None,
select_channels_by_ratio=None,
):
"""
Enhance the given CutSet.
Expand Down Expand Up @@ -238,6 +244,8 @@ def _save_worker(orig_cuts, x_hat, recording_id, speaker):
num_chunks=num_chunks,
left_context=batch.left_context,
right_context=batch.right_context,
select_channels_by_count=select_channels_by_count,
select_channels_by_ratio=select_channels_by_ratio,
)
break
except cp.cuda.memory.OutOfMemoryError:
Expand Down Expand Up @@ -280,7 +288,15 @@ def _save_worker(orig_cuts, x_hat, recording_id, speaker):
)

def enhance_batch(
self, obs, activity, speaker_id, num_chunks=1, left_context=0, right_context=0
self,
obs,
activity,
speaker_id,
num_chunks=1,
left_context=0,
right_context=0,
select_channels_by_count=None,
select_channels_by_ratio=None,
):

logging.debug(f"Converting activity to frequency domain")
Expand All @@ -298,6 +314,18 @@ def enhance_batch(
logging.debug(f"Computing STFT")
Obs = self.stft(obs)

if select_channels_by_count is not None or select_channels_by_ratio is not None:
if select_channels_by_ratio:
D = Obs.shape[0]
select_channels_by_count = min(
max(2, int(select_channels_by_ratio * D)), D
)

Obs, selected_channels = self.channel_selector(
Obs, num_channels=select_channels_by_count
)
logging.debug(f"Selected channels: {selected_channels}")

D, T, F = Obs.shape

# Process observation in chunks
Expand Down
Loading