diff --git a/gss/bin/modes/enhance.py b/gss/bin/modes/enhance.py index edd711f..d3be8bd 100644 --- a/gss/bin/modes/enhance.py +++ b/gss/bin/modes/enhance.py @@ -36,6 +36,23 @@ def common_options(func): help="Channels to use for enhancement. Specify with comma-separated values, e.g. " "`--channels 0,2,4`. All channels will be used by default.", ) + @click.option( + "--select-channels-by-count", + type=int, + default=None, + help="Number of channels to select for enhancement. If specified, we will use an envelope " + "variance based method to select the best channels. If `--channels` is also specified, we will " + "select the best channels from the specified channels.", + ) + @click.option( + "--select-channels-by-ratio", + type=click.FloatRange(0.0, 1.0), + default=None, + help="Ratio of channels to select for enhancement. If specified, we will use an envelope " + "variance based method to select the best channels. If `--channels` is also specified, we will " + "select the best channels from the specified channels. 
Note that we will use at least 2 channels " + "and at most the number of channels given by `--channels`.", + ) @click.option( "--bss-iterations", "-i", @@ -157,6 +174,8 @@ def cuts_( cuts_per_segment, enhanced_dir, channels, + select_channels_by_count, + select_channels_by_ratio, bss_iterations, use_wpe, context_duration, @@ -195,6 +214,10 @@ def exit(): atexit.register(exit) + assert not ( + select_channels_by_count and select_channels_by_ratio + ), "Please specify at most one of --select-channels-by-count and --select-channels-by-ratio" + if duration_tolerance is not None: set_audio_duration_mismatch_tolerance(duration_tolerance) @@ -237,6 +260,8 @@ def exit(): num_workers=num_workers, num_buckets=num_buckets, force_overwrite=force_overwrite, + select_channels_by_count=select_channels_by_count, + select_channels_by_ratio=select_channels_by_ratio, ) end = time.time() logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors") @@ -273,6 +298,8 @@ def recording_( enhanced_dir, recording_id, channels, + select_channels_by_count, + select_channels_by_ratio, bss_iterations, use_wpe, context_duration, @@ -310,6 +337,10 @@ def exit(): atexit.register(exit) + assert not ( + select_channels_by_count and select_channels_by_ratio + ), "Please specify at most one of --select-channels-by-count and --select-channels-by-ratio" + enhanced_dir = Path(enhanced_dir) enhanced_dir.mkdir(exist_ok=True, parents=True) @@ -359,6 +390,8 @@ def exit(): num_workers=num_workers, num_buckets=num_buckets, force_overwrite=force_overwrite, + select_channels_by_count=select_channels_by_count, + select_channels_by_ratio=select_channels_by_ratio, ) end = time.time() logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors") diff --git a/gss/core/__init__.py b/gss/core/__init__.py index b4a7321..6d8c826 100644 --- a/gss/core/__init__.py +++ b/gss/core/__init__.py @@ -1,5 +1,6 @@ from .activity import Activity from .beamformer import Beamformer +from .channel_selection import 
@dataclass
class EnvelopeVarianceChannelSelector:
    """
    Rank the channels of a multi-channel spectrogram by envelope variance (EV)
    and keep the best-scoring ones.

    Attributes:
        n_mels: Number of mel sub-bands in which the envelopes are computed.
        n_fft: Forwarded to ``mel_scale`` as ``n_freqs``. Despite its name it is
            therefore the number of frequency bins of the input (the size of
            its last axis), not the FFT size.
        hop_length: STFT shift in samples; used to convert the chunk settings
            from seconds to frames.
        sampling_rate: Sampling rate in Hz.
        chunk_size: Scoring-window length. Given in seconds at construction and
            overwritten with the equivalent number of frames in
            ``__post_init__``.
        chunk_stride: Scoring-window stride; same seconds-to-frames conversion
            as ``chunk_size``.
    """

    n_mels: int = 40
    n_fft: int = 1024
    hop_length: int = 256
    sampling_rate: int = 16000
    chunk_size: float = 4
    chunk_stride: float = 2

    def __post_init__(self):
        # Uniform sub-band weights: every mel band contributes equally.
        self.subband_weights = cp.ones(self.n_mels)
        # From here on the chunk settings are expressed in STFT frames.
        frames_per_second = self.sampling_rate / self.hop_length
        self.chunk_size = int(self.chunk_size * frames_per_second)
        self.chunk_stride = int(self.chunk_stride * frames_per_second)
        self.fb = mel_scale(
            n_freqs=self.n_fft, n_mels=self.n_mels, sample_rate=self.sampling_rate
        )

    def _single_window(self, mels):
        """
        Score every channel on a single window of sub-band envelopes.

        Args:
            mels: array of shape (channels, subbands, frames).
        Returns:
            Array of shape (channels,); a larger value means a better channel.
        """
        logmels = cp.log(mels + EPSILON)
        # Subtract each envelope's temporal mean in the log domain, which
        # removes the per-channel / per-band gain before measuring variance.
        mels = cp.exp(logmels - cp.mean(logmels, axis=-1, keepdims=True))
        var = cp.var(mels ** (1 / 3), axis=-1)  # (channels, subbands)
        # EV normalises every sub-band by its maximum over the *channels*, so
        # that the score compares channels with each other. With a
        # (channels, subbands) layout that is axis 0; reducing over axis 1
        # would scale each channel by its own largest sub-band instead.
        # The floor avoids 0/0 = NaN when a sub-band is constant everywhere.
        peak = cp.maximum(cp.amax(var, axis=0, keepdims=True), EPSILON)
        var = var / peak
        subband_weights = cp.abs(self.subband_weights)
        return cp.sum(var * subband_weights, axis=-1)

    def _count_chunks(self, inlen, chunk_size, chunk_stride):
        """Return how many full windows of ``chunk_size`` fit into ``inlen``."""
        return int((inlen - chunk_size + chunk_stride) / chunk_stride)

    def _get_chunks_indx(self, in_len, chunk_size, chunk_stride, discard_last=False):
        """
        Yield ``(start, end)`` frame indices of the scoring windows.

        All full windows are yielded first. Unless ``discard_last`` is set, a
        final shorter window covering the remaining frames is yielded as well.
        """
        i = -1  # keeps the tail logic valid when not a single full window fits
        for i in range(self._count_chunks(in_len, chunk_size, chunk_stride)):
            yield i * chunk_stride, i * chunk_stride + chunk_size
        if not discard_last and i * chunk_stride + chunk_size < in_len:
            if in_len - (i + 1) * chunk_stride > 0:
                yield (i + 1) * chunk_stride, in_len

    def __call__(self, obs, num_channels):
        """
        Select the ``num_channels`` best channels of ``obs``.

        Args:
            obs: array of shape (channels, time, freq); ``freq`` must match the
                first dimension of the mel filter bank.
            num_channels: number of channels to keep. ``None`` keeps all
                channels (sorted best first).
        Returns:
            Tuple ``(obs[selected], selected)`` where ``selected`` holds the
            channel indices ordered from best to worst score.
        Raises:
            ValueError: if ``num_channels`` is smaller than 1.
        """
        assert obs.ndim == 3
        if num_channels is not None and num_channels < 1:
            # 0 would return an empty observation and a negative value would
            # silently drop channels from the wrong end of the ranking.
            raise ValueError(
                f"num_channels must be a positive integer or None, got {num_channels}"
            )

        # NOTE(review): obs is multiplied with the real-valued filter bank
        # as-is. The caller passes the output of its STFT; if that is complex,
        # the envelopes below are complex too and cp.abs(obs) is probably what
        # is wanted here - TODO confirm.
        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
        mels = cp.matmul(obs, self.fb).transpose(0, 2, 1)
        num_frames = mels.shape[-1]

        if num_frames > (self.chunk_size + self.chunk_stride):
            # Score each window separately (a plain loop avoids having to mask
            # padded frames in the statistics) and average the window scores.
            windows = self._get_chunks_indx(
                num_frames, self.chunk_size, self.chunk_stride
            )
            all_win_ranks = [self._single_window(mels[..., s:t]) for s, t in windows]
            scores = cp.stack(all_win_ranks, axis=0).mean(axis=0)
        else:
            # Too short to be worth chunking: score the whole input at once.
            scores = self._single_window(mels)

        channel_ranks = cp.argsort(scores)[::-1]  # best channel first
        selected_channels = channel_ranks[:num_channels]
        return obs[selected_channels], selected_channels
@@ -238,6 +244,8 @@ def _save_worker(orig_cuts, x_hat, recording_id, speaker): num_chunks=num_chunks, left_context=batch.left_context, right_context=batch.right_context, + select_channels_by_count=select_channels_by_count, + select_channels_by_ratio=select_channels_by_ratio, ) break except cp.cuda.memory.OutOfMemoryError: @@ -280,7 +288,15 @@ def _save_worker(orig_cuts, x_hat, recording_id, speaker): ) def enhance_batch( - self, obs, activity, speaker_id, num_chunks=1, left_context=0, right_context=0 + self, + obs, + activity, + speaker_id, + num_chunks=1, + left_context=0, + right_context=0, + select_channels_by_count=None, + select_channels_by_ratio=None, ): logging.debug(f"Converting activity to frequency domain") @@ -298,6 +314,18 @@ def enhance_batch( logging.debug(f"Computing STFT") Obs = self.stft(obs) + if select_channels_by_count is not None or select_channels_by_ratio is not None: + if select_channels_by_ratio: + D = Obs.shape[0] + select_channels_by_count = min( + max(2, int(select_channels_by_ratio * D)), D + ) + + Obs, selected_channels = self.channel_selector( + Obs, num_channels=select_channels_by_count + ) + logging.debug(f"Selected channels: {selected_channels}") + D, T, F = Obs.shape # Process observation in chunks diff --git a/gss/core/stft_module.py b/gss/core/stft_module.py index 2be4665..e38d508 100644 --- a/gss/core/stft_module.py +++ b/gss/core/stft_module.py @@ -1,8 +1,9 @@ # The functions here are modified from: # https://github.com/fgnt/paderbox/blob/master/paderbox/transform/module_stft.py +import math import string -import typing -from math import ceil +import warnings +from typing import Optional, Union import cupy as cp import cupyx as cpx @@ -48,7 +49,7 @@ def stft( pad_width = np.zeros([ndim, 2], dtype=np.int) if fading == "half": pad_width[axis, 0] = (window_length - shift) // 2 - pad_width[axis, 1] = ceil((window_length - shift) / 2) + pad_width[axis, 1] = math.ceil((window_length - shift) / 2) else: pad_width[axis, :] = 
window_length - shift time_signal = cp.pad(time_signal, pad_width, mode="constant") @@ -127,7 +128,7 @@ def istft( size: int = 1024, shift: int = 256, *, - fading: typing.Optional[typing.Union[bool, str]] = "full", + fading: Optional[Union[bool, str]] = "full", ): """ Calculated the inverse short time Fourier transform to exactly reconstruct @@ -180,7 +181,165 @@ def istft( if fading == "half": pad_width /= 2 time_signal = time_signal[ - ..., int(pad_width) : time_signal.shape[-1] - ceil(pad_width) + ..., int(pad_width) : time_signal.shape[-1] - math.ceil(pad_width) ] return time_signal + + +# The following are modified from: +# https://github.com/pytorch/audio/blob/main/torchaudio/functional/functional.py + + +def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float: + r"""Convert Hz to Mels. + Args: + freqs (float): Frequencies in Hz + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + Returns: + mels (float): Frequency in Mels + """ + + if mel_scale not in ["slaney", "htk"]: + raise ValueError('mel_scale should be one of "htk" or "slaney".') + + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + (freq / 700.0)) + + # Fill in the linear part + f_min = 0.0 + f_sp = 200.0 / 3 + + mels = (freq - f_min) / f_sp + + # Fill in the log-scale part + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz) / logstep + + return mels + + +def _mel_to_hz(mels: cp.ndarray, mel_scale: str = "htk") -> cp.ndarray: + """Convert mel bin numbers to frequencies. + Args: + mels (Tensor): Mel frequencies + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. 
(Default: ``htk``) + Returns: + freqs (Tensor): Mels converted in Hz + """ + + if mel_scale not in ["slaney", "htk"]: + raise ValueError('mel_scale should be one of "htk" or "slaney".') + + if mel_scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + # Fill in the linear scale + f_min = 0.0 + f_sp = 200.0 / 3 + freqs = f_min + f_sp * mels + + # And now the nonlinear scale + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + + log_t = mels >= min_log_mel + freqs[log_t] = min_log_hz * cp.exp(logstep * (mels[log_t] - min_log_mel)) + + return + + +def _create_triangular_filterbank( + all_freqs: cp.ndarray, + f_pts: cp.ndarray, +) -> cp.ndarray: + """Create a triangular filter bank. + Args: + all_freqs (Array): STFT freq points of size (`n_freqs`). + f_pts (Array): Filter mid points of size (`n_filter`). + Returns: + fb (Array): The filter bank of size (`n_freqs`, `n_filter`). + """ + # Adopted from Librosa + # calculate the difference between each filter mid point and each stft freq point in hertz + f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1) + slopes = cp.expand_dims(f_pts, axis=0) - cp.expand_dims( + all_freqs, axis=1 + ) # (n_freqs, n_filter + 2) + # create overlapping triangles + zero = cp.zeros(1) + down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter) + up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter) + fb = cp.maximum(zero, cp.minimum(down_slopes, up_slopes)) + + return fb + + +def mel_scale( + n_freqs: int, + n_mels: int, + sample_rate: int, + f_min: float = 0.0, + f_max: Optional[float] = None, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> cp.ndarray: + r"""Create a frequency bin conversion matrix. + Note: + For the sake of the numerical compatibility with librosa, not all the coefficients + in the resulting filter bank has magnitude of 1. + .. 
image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png + :alt: Visualization of generated filter bank + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_mels (int): Number of mel filterbanks + sample_rate (int): Sample rate of the audio waveform + norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + Returns: + Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) + meaning number of frequencies to highlight/apply to x the number of filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A * melscale_fbanks(A.size(-1), ...)``. + """ + + if norm is not None and norm != "slaney": + raise ValueError('norm must be one of None or "slaney"') + + # freq bins + all_freqs = cp.linspace(0, sample_rate // 2, n_freqs) + + f_max = f_max or float(sample_rate // 2) + + # calculate mel freq bins + m_min = _hz_to_mel(f_min, mel_scale=mel_scale) + m_max = _hz_to_mel(f_max, mel_scale=mel_scale) + + m_pts = cp.linspace(m_min, m_max, num=n_mels + 2) + f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + fb *= cp.expand_dims(enorm, axis=0) + + if (fb.max(axis=0) == 0.0).any(): + warnings.warn( + "At least one mel filterbank has all zero values. " + f"The value for `n_mels` ({n_mels}) may be set too high. " + f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." + ) + + return fb