From ebbbcbcaf165ffd95e1d7df4fc873e356e8aacd9 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 4 Sep 2024 14:27:25 +0800 Subject: [PATCH 1/7] support consistency-regularized CTC --- egs/librispeech/ASR/zipformer/model.py | 171 +++++++++- egs/librispeech/ASR/zipformer/spec_augment.py | 313 ++++++++++++++++++ egs/librispeech/ASR/zipformer/train.py | 85 ++++- 3 files changed, 556 insertions(+), 13 deletions(-) create mode 100644 egs/librispeech/ASR/zipformer/spec_augment.py diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index bd1ed26d8d..cf935d8351 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,6 +25,7 @@ from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask +from spec_augment import SpecAugment, time_warp class AsrModel(nn.Module): @@ -181,6 +182,63 @@ def forward_ctc( ) return ctc_loss + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + time_mask: Optional[torch.Tensor] = None, + cr_loss_masked_scale: float = 3.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + time_mask: + Downsampled time masks of shape (2 * N, T, 1). + cr_loss_masked_scale: + The loss scale used to scale up the cr_loss at masked positions. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + if time_mask is not None: + assert time_mask.shape[:-1] == ctc_output.shape[:-1], ( + time_mask.shape, ctc_output.shape + ) + masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1 + # e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3, + # scales at unmasked positions are 1 + cr_loss = cr_loss * masked_scale # scaling up masked positions + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + def forward_transducer( self, encoder_out: torch.Tensor, @@ -296,7 +354,13 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + use_cr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + cr_loss_masked_scale: float = 3.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -316,9 +380,28 @@ def forward( lm_scale: The scale to smooth the loss with lm (output of predictor network) part + use_cr_ctc: + Whether use consistency-regularized CTC. + use_spec_aug: + Whether apply spec-augment manually, used only if use_cr_ctc is True. + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. + cr_loss_masked_scale: + The loss scale used to scale up the cr_loss at masked positions. + Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -334,6 +417,27 @@ def forward( device = x.device + if use_cr_ctc: + assert self.use_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x, time_mask = spec_augment(x.repeat(2, 1, 1)) + # time_mask: 1 for masked, 0 for unmasked + time_mask = downsample_time_mask(time_mask, x.dtype) + else: + x = x.repeat(2, 1, 1) + time_mask = None + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -351,6 +455,9 @@ def forward( am_scale=am_scale, lm_scale=lm_scale, ) + if use_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 else: simple_loss = torch.empty(0) pruned_loss = torch.empty(0) @@ -358,14 +465,28 @@ def forward( if self.use_ctc: # Compute CTC loss targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + if not use_cr_ctc: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + time_mask=time_mask, + cr_loss_masked_scale=cr_loss_masked_scale, + ) + ctc_loss = ctc_loss * 0.5 + cr_loss = cr_loss * 0.5 else: ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) if self.use_attention_decoder: attention_decoder_loss = self.attention_decoder.calc_att_loss( @@ -374,7 +495,37 @@ def forward( ys=y.to(device), ys_lens=y_lens.to(device), ) + if use_cr_ctc: + attention_decoder_loss = attention_decoder_loss * 0.5 else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss + + +def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype): + """Downsample the time masks as in Zipformer. + Args: + time_mask: shape of (N, T) + Returns: + The downsampled time masks of shape (N, T', 1), + where T' = ((T - 7) // 2 + 1) // 2 + """ + # Downsample the time masks as in Zipformer + time_mask = time_mask.to(dtype).unsqueeze(dim=1) + # as in conv-embed + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=1, padding=0 + ) # T - 2 + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=2, padding=0 + ) # (T - 3) // 2 + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=1, padding=0 + ) # (T - 7) // 2 + # as in output-downsampling + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + time_mask = time_mask.transpose(1, 2) # (N * 2, T', 1) + return time_mask diff --git a/egs/librispeech/ASR/zipformer/spec_augment.py b/egs/librispeech/ASR/zipformer/spec_augment.py new file mode 100644 index 0000000000..6ddf2b09bd --- /dev/null +++ b/egs/librispeech/ASR/zipformer/spec_augment.py @@ -0,0 +1,313 @@ +# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +# with minor modification for cr-ctc training. + + +import math +import random +from typing import Any, Dict, Optional, Tuple + +import torch +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl + + +class SpecAugment(torch.nn.Module): + """SpecAugment from lhotse with minor modification, returning time masks. + + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 2, + features_mask_size: int = 27, + num_frame_masks: int = 10, + frames_mask_size: int = 100, + max_frames_mask_fraction: float = 0.15, + p=0.9, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. + :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks >= 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features`` -- there may be either + less or more than the batch size. + The second dimension encoder three kinds of information: + the sequence index of the corresponding feature matrix in `features`, + the start frame index, and the number of frames for each segment. + :return: + - an augmented tensor of shape ``(B, T, F)``. + - the corresponding time masks of shape ``(B, T)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " "single-channel feature matrices." + ) + features = features.clone() + + time_masks = [] + + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + masked_feature, time_mask = self._forward_single(features[sequence_idx]) + features[sequence_idx] = masked_feature + time_masks.append(time_mask) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + end_frame = start_frame + num_frames + warped_feature, _ = self._forward_single( + features[sequence_idx, start_frame:end_frame], warp=True, mask=False + ) + features[sequence_idx, start_frame:end_frame] = warped_feature + # ... and then time-mask the full feature matrices. Note that in this mode, + # it might happen that masks are applied to different sequences/examples + # than the time warping. + for sequence_idx in range(features.size(0)): + masked_feature, time_mask = self._forward_single( + features[sequence_idx], warp=False, mask=True + ) + features[sequence_idx] = masked_feature + time_masks.append(time_mask) + + time_masks = torch.cat(time_masks, dim=0) + assert time_masks.shape == features.shape[:-1], (time_masks.shape == features.shape[:-1]) + return features, time_masks + + def _forward_single( + self, features: torch.Tensor, warp: bool = True, mask: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply SpecAugment to a single feature matrix of shape (T, F). + """ + if random.random() > self.p: + # Randomly choose whether this transform is applied + time_mask = torch.zeros( + 1, features.size(0), dtype=torch.bool, device=features.device + ) + return features, time_mask + + time_mask = None + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp_impl(features, factor=self.time_warp_factor) + + if mask: + mean = features.mean() + # Frequency masking + features, _ = mask_along_axis_optimized( + features, + mask_size=self.features_mask_size, + mask_times=self.num_feature_masks, + mask_value=mean, + axis=2, + ) + # Time masking + max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min( + self.num_frame_masks, + math.ceil(max_tot_mask_frames / self.frames_mask_size), + ) + max_mask_frames = min( + self.frames_mask_size, max_tot_mask_frames // num_frame_masks + ) + features, time_mask = mask_along_axis_optimized( + features, + mask_size=max_mask_frames, + mask_times=num_frame_masks, + mask_value=mean, + axis=1, + return_time_mask=True, + ) + + return features, time_mask + + def state_dict(self, **kwargs) -> Dict[str, Any]: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict[str, Any]): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def mask_along_axis_optimized( + features: torch.Tensor, + mask_size: int, + mask_times: int, + mask_value: float, + axis: int, + return_time_mask: bool = False, +) -> torch.Tensor: + """ + Apply Frequency and Time masking along axis. + Frequency and Time masking as described in the SpecAugment paper. + + :param features: input tensor of shape ``(T, F)`` + :mask_size: the width size for masking. + :mask_times: the number of masking regions. + :mask_value: Value to assign to the masked regions. + :axis: Axis to apply masking on (1 -> time, 2 -> frequency) + :return_time_mask: Whether return the time mask of shape ``(1, T)`` + """ + if axis not in [1, 2]: + raise ValueError("Only Frequency and Time masking are supported!") + + if return_time_mask and axis == 1: + time_mask = torch.zeros( + 1, features.size(0), dtype=torch.bool, device=features.device + ) + else: + time_mask = None + + features = features.unsqueeze(0) + features = features.reshape([-1] + list(features.size()[-2:])) + + values = torch.randint(int(0), int(mask_size), (1, mask_times)) + min_values = torch.rand(1, mask_times) * (features.size(axis) - values) + mask_starts = (min_values.long()).squeeze() + mask_ends = (min_values.long() + values.long()).squeeze() + + if axis == 1: + if mask_times == 1: + features[:, mask_starts:mask_ends] = mask_value + if return_time_mask: + time_mask[:, mask_starts:mask_ends] = True + return features.squeeze(0), time_mask + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, mask_start:mask_end] = mask_value + if return_time_mask: + time_mask[:, mask_start:mask_end] = True + else: + if mask_times == 1: + features[:, :, mask_starts:mask_ends] = mask_value + return features.squeeze(0), time_mask + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, :, mask_start:mask_end] = mask_value + + features = features.squeeze(0) + return features, time_mask + + +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + if random.random() > p: + # Randomly choose whether this transform is applied + continue + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = time_warp_impl( + features[sequence_idx, start_frame:end_frame], factor=time_warp_factor + ) + + return features diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9c1c7f5a78..328b3cfdd3 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -102,6 +102,7 @@ setup_logger, str2bool, ) +from spec_augment import SpecAugment LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -304,6 +305,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use attention-decoder head.", ) + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -449,6 +457,13 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.1, + help="Scale for consistency-regularization loss.", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -590,6 +605,11 @@ def get_params() -> AttributeDict: # parameters for attention-decoder "ignore_id": -1, "label_smoothing": 0.1, + # parameters used for CR-CTC + # When using cr-ctc, we increase the time-masking ratio. + "time_mask_ratio": 2.0, + # The scale used to scale up the cr_loss at masked positions. + "cr_loss_masked_scale": 3.0, "warm_step": 2000, "env_info": get_env_info(), } @@ -717,6 +737,24 @@ def get_model(params: AttributeDict) -> nn.Module: return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = 10 * params.time_mask_ratio + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -839,6 +877,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -855,8 +894,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -874,14 +913,35 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + cr_loss_masked_scale=params.cr_loss_masked_scale, ) loss = 0.0 @@ -904,6 +964,8 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -922,6 +984,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -971,6 +1035,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -997,6 +1062,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1043,6 +1110,7 @@ def save_bad_model(suffix: str = ""): sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,6 +1306,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1360,6 +1435,7 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1387,6 +1463,7 @@ def remove_short_and_long_utt(c: Cut): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1452,6 +1529,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1471,6 +1549,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() From 07d6b123643f59377325f1611967f72505eaea8f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 4 Sep 2024 19:33:41 +0800 Subject: [PATCH 2/7] update arguments of cr-ctc --- egs/librispeech/ASR/zipformer/train.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 328b3cfdd3..c2aaf8b568 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -460,10 +460,24 @@ def get_parser(): parser.add_argument( "--cr-loss-scale", type=float, - default=0.1, + default=0.15, help="Scale for consistency-regularization loss.", ) + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.0, + help="When using cr-ctc, we increase the time-masking ratio.", + ) + + parser.add_argument( + "--cr-loss-masked-scale", + type=float, + default=1.0, + help="The value used to scale up the cr_loss at masked positions", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -605,11 +619,6 @@ def get_params() -> AttributeDict: # parameters for attention-decoder "ignore_id": -1, "label_smoothing": 0.1, - # parameters used for CR-CTC - # When using cr-ctc, we increase the time-masking ratio. - "time_mask_ratio": 2.0, - # The scale used to scale up the cr_loss at masked positions. - "cr_loss_masked_scale": 3.0, "warm_step": 2000, "env_info": get_env_info(), } From cf796eefed40fe211b22b4381a4166915a76f11f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 Sep 2024 10:32:36 +0800 Subject: [PATCH 3/7] set default value of cr_loss_masked_scale to 1.0 --- egs/librispeech/ASR/zipformer/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index cf935d8351..2de1e08fee 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -189,7 +189,7 @@ def forward_cr_ctc( targets: torch.Tensor, target_lengths: torch.Tensor, time_mask: Optional[torch.Tensor] = None, - cr_loss_masked_scale: float = 3.0, + cr_loss_masked_scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute CTC loss with consistency regularization loss. Args: @@ -359,7 +359,7 @@ def forward( spec_augment: Optional[SpecAugment] = None, supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, - cr_loss_masked_scale: float = 3.0, + cr_loss_masked_scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: From a6eead6c982de726cceea8cb07a90e9bd18f2070 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 9 Sep 2024 10:10:08 +0800 Subject: [PATCH 4/7] minor fix --- egs/librispeech/ASR/zipformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index c2aaf8b568..3fde55de24 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -747,7 +747,7 @@ def get_model(params: AttributeDict) -> nn.Module: def get_spec_augment(params: AttributeDict) -> SpecAugment: - num_frame_masks = 10 * params.time_mask_ratio + num_frame_masks = int(10 * params.time_mask_ratio) max_frames_mask_fraction = 0.15 * params.time_mask_ratio logging.info( f"num_frame_masks: {num_frame_masks}, " From ae59e5d61ef1dd52ba3fd81efe6504164f661c51 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 8 Oct 2024 00:34:32 +0800 Subject: [PATCH 5/7] refactor codes --- egs/librispeech/ASR/zipformer/model.py | 56 +--- egs/librispeech/ASR/zipformer/spec_augment.py | 313 ------------------ egs/librispeech/ASR/zipformer/train.py | 16 +- icefall/utils.py | 40 +++ 4 files changed, 47 insertions(+), 378 deletions(-) delete mode 100644 egs/librispeech/ASR/zipformer/spec_augment.py diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 2de1e08fee..deebb2a754 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -24,8 +24,8 @@ from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask -from spec_augment import SpecAugment, time_warp +from icefall.utils import add_sos, make_pad_mask, time_warp +from lhotse.dataset import SpecAugment class AsrModel(nn.Module): @@ -188,8 +188,6 @@ def forward_cr_ctc( encoder_out_lens: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, - time_mask: Optional[torch.Tensor] = None, - cr_loss_masked_scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute CTC loss with consistency regularization loss. Args: @@ -200,10 +198,6 @@ def forward_cr_ctc( targets: Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed to be un-padded and concatenated within 1 dimension. - time_mask: - Downsampled time masks of shape (2 * N, T, 1). - cr_loss_masked_scale: - The loss scale used to scale up the cr_loss at masked positions. """ # Compute CTC loss ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) @@ -226,14 +220,6 @@ def forward_cr_ctc( reduction="none", log_target=True, ) # (2 * N, T, C) - if time_mask is not None: - assert time_mask.shape[:-1] == ctc_output.shape[:-1], ( - time_mask.shape, ctc_output.shape - ) - masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1 - # e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3, - # scales at unmasked positions are 1 - cr_loss = cr_loss * masked_scale # scaling up masked positions length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() @@ -359,7 +345,6 @@ def forward( spec_augment: Optional[SpecAugment] = None, supervision_segments: Optional[torch.Tensor] = None, time_warp_factor: Optional[int] = 80, - cr_loss_masked_scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: @@ -395,8 +380,6 @@ def forward( Parameter for the time warping; larger values mean more warping. Set to ``None``, or less than ``1``, to disable. Used only if use_cr_ctc is True. - cr_loss_masked_scale: - The loss scale used to scale up the cr_loss at masked positions. Returns: Return the transducer losses, CTC loss, AED loss, @@ -429,12 +412,9 @@ def forward( supervision_segments=supervision_segments, ) # Independently apply frequency masking and time masking to the two copies - x, time_mask = spec_augment(x.repeat(2, 1, 1)) - # time_mask: 1 for masked, 0 for unmasked - time_mask = downsample_time_mask(time_mask, x.dtype) + x = spec_augment(x.repeat(2, 1, 1)) else: x = x.repeat(2, 1, 1) - time_mask = None x_lens = x_lens.repeat(2) y = k2.ragged.cat([y, y], axis=0) @@ -479,8 +459,6 @@ def forward( encoder_out_lens=encoder_out_lens, targets=targets, target_lengths=y_lens, - time_mask=time_mask, - cr_loss_masked_scale=cr_loss_masked_scale, ) ctc_loss = ctc_loss * 0.5 cr_loss = cr_loss * 0.5 @@ -501,31 +479,3 @@ def forward( attention_decoder_loss = torch.empty(0) return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss - - -def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype): - """Downsample the time masks as in Zipformer. - Args: - time_mask: shape of (N, T) - Returns: - The downsampled time masks of shape (N, T', 1), - where T' = ((T - 7) // 2 + 1) // 2 - """ - # Downsample the time masks as in Zipformer - time_mask = time_mask.to(dtype).unsqueeze(dim=1) - # as in conv-embed - time_mask = nn.functional.max_pool1d( - time_mask, kernel_size=3, stride=1, padding=0 - ) # T - 2 - time_mask = nn.functional.max_pool1d( - time_mask, kernel_size=3, stride=2, padding=0 - ) # (T - 3) // 2 - time_mask = nn.functional.max_pool1d( - time_mask, kernel_size=3, stride=1, padding=0 - ) # (T - 7) // 2 - # as in output-downsampling - time_mask = nn.functional.max_pool1d( - time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True - ) - time_mask = time_mask.transpose(1, 2) # (N * 2, T', 1) - return time_mask diff --git a/egs/librispeech/ASR/zipformer/spec_augment.py b/egs/librispeech/ASR/zipformer/spec_augment.py deleted file mode 100644 index 6ddf2b09bd..0000000000 --- a/egs/librispeech/ASR/zipformer/spec_augment.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py -# with minor modification for cr-ctc training. - - -import math -import random -from typing import Any, Dict, Optional, Tuple - -import torch -from lhotse.dataset.signal_transforms import time_warp as time_warp_impl - - -class SpecAugment(torch.nn.Module): - """SpecAugment from lhotse with minor modification, returning time masks. - - SpecAugment performs three augmentations: - - time warping of the feature matrix - - masking of ranges of features (frequency bands) - - masking of ranges of frames (time) - - The current implementation works with batches, but processes each example separately - in a loop rather than simultaneously to achieve different augmentation parameters for - each example. - """ - - def __init__( - self, - time_warp_factor: Optional[int] = 80, - num_feature_masks: int = 2, - features_mask_size: int = 27, - num_frame_masks: int = 10, - frames_mask_size: int = 100, - max_frames_mask_fraction: float = 0.15, - p=0.9, - ): - """ - SpecAugment's constructor. - - :param time_warp_factor: parameter for the time warping; larger values mean more warping. - Set to ``None``, or less than ``1``, to disable. - :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. - :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). - This is the ``F`` parameter from the SpecAugment paper. - :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable. - :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). - This is the ``T`` parameter from the SpecAugment paper. - :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length - of the utterance (or supervision segment). - This is the parameter denoted by ``p`` in the SpecAugment paper. - :param p: the probability of applying this transform. - It is different from ``p`` in the SpecAugment paper! - """ - super().__init__() - assert 0 <= p <= 1 - assert num_feature_masks >= 0 - assert num_frame_masks >= 0 - assert features_mask_size > 0 - assert frames_mask_size > 0 - self.time_warp_factor = time_warp_factor - self.num_feature_masks = num_feature_masks - self.features_mask_size = features_mask_size - self.num_frame_masks = num_frame_masks - self.frames_mask_size = frames_mask_size - self.max_frames_mask_fraction = max_frames_mask_fraction - self.p = p - - def forward( - self, - features: torch.Tensor, - supervision_segments: Optional[torch.IntTensor] = None, - *args, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Computes SpecAugment for a batch of feature matrices. - - Since the batch will usually already be padded, the user can optionally - provide a ``supervision_segments`` tensor that will be used to apply SpecAugment - only to selected areas of the input. The format of this input is described below. - - :param features: a batch of feature matrices with shape ``(B, T, F)``. - :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of - supervision segments that exist in ``features`` -- there may be either - less or more than the batch size. - The second dimension encoder three kinds of information: - the sequence index of the corresponding feature matrix in `features`, - the start frame index, and the number of frames for each segment. - :return: - - an augmented tensor of shape ``(B, T, F)``. - - the corresponding time masks of shape ``(B, T)``. - """ - assert len(features.shape) == 3, ( - "SpecAugment only supports batches of " "single-channel feature matrices." - ) - features = features.clone() - - time_masks = [] - - if supervision_segments is None: - # No supervisions - apply spec augment to full feature matrices. - for sequence_idx in range(features.size(0)): - masked_feature, time_mask = self._forward_single(features[sequence_idx]) - features[sequence_idx] = masked_feature - time_masks.append(time_mask) - else: - # Supervisions provided - we will apply time warping only on the supervised areas. - for sequence_idx, start_frame, num_frames in supervision_segments: - end_frame = start_frame + num_frames - warped_feature, _ = self._forward_single( - features[sequence_idx, start_frame:end_frame], warp=True, mask=False - ) - features[sequence_idx, start_frame:end_frame] = warped_feature - # ... and then time-mask the full feature matrices. Note that in this mode, - # it might happen that masks are applied to different sequences/examples - # than the time warping. - for sequence_idx in range(features.size(0)): - masked_feature, time_mask = self._forward_single( - features[sequence_idx], warp=False, mask=True - ) - features[sequence_idx] = masked_feature - time_masks.append(time_mask) - - time_masks = torch.cat(time_masks, dim=0) - assert time_masks.shape == features.shape[:-1], (time_masks.shape == features.shape[:-1]) - return features, time_masks - - def _forward_single( - self, features: torch.Tensor, warp: bool = True, mask: bool = True - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply SpecAugment to a single feature matrix of shape (T, F). - """ - if random.random() > self.p: - # Randomly choose whether this transform is applied - time_mask = torch.zeros( - 1, features.size(0), dtype=torch.bool, device=features.device - ) - return features, time_mask - - time_mask = None - if warp: - if self.time_warp_factor is not None and self.time_warp_factor >= 1: - features = time_warp_impl(features, factor=self.time_warp_factor) - - if mask: - mean = features.mean() - # Frequency masking - features, _ = mask_along_axis_optimized( - features, - mask_size=self.features_mask_size, - mask_times=self.num_feature_masks, - mask_value=mean, - axis=2, - ) - # Time masking - max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) - num_frame_masks = min( - self.num_frame_masks, - math.ceil(max_tot_mask_frames / self.frames_mask_size), - ) - max_mask_frames = min( - self.frames_mask_size, max_tot_mask_frames // num_frame_masks - ) - features, time_mask = mask_along_axis_optimized( - features, - mask_size=max_mask_frames, - mask_times=num_frame_masks, - mask_value=mean, - axis=1, - return_time_mask=True, - ) - - return features, time_mask - - def state_dict(self, **kwargs) -> Dict[str, Any]: - return dict( - time_warp_factor=self.time_warp_factor, - num_feature_masks=self.num_feature_masks, - features_mask_size=self.features_mask_size, - num_frame_masks=self.num_frame_masks, - frames_mask_size=self.frames_mask_size, - max_frames_mask_fraction=self.max_frames_mask_fraction, - p=self.p, - ) - - def load_state_dict(self, state_dict: Dict[str, Any]): - self.time_warp_factor = state_dict.get( - "time_warp_factor", self.time_warp_factor - ) - self.num_feature_masks = state_dict.get( - "num_feature_masks", self.num_feature_masks - ) - self.features_mask_size = state_dict.get( - "features_mask_size", self.features_mask_size - ) - self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) - self.frames_mask_size = state_dict.get( - "frames_mask_size", self.frames_mask_size - ) - self.max_frames_mask_fraction = state_dict.get( - "max_frames_mask_fraction", self.max_frames_mask_fraction - ) - self.p = state_dict.get("p", self.p) - - -def mask_along_axis_optimized( - features: torch.Tensor, - mask_size: int, - mask_times: int, - mask_value: float, - axis: int, - return_time_mask: bool = False, -) -> torch.Tensor: - """ - Apply Frequency and Time masking along axis. - Frequency and Time masking as described in the SpecAugment paper. - - :param features: input tensor of shape ``(T, F)`` - :mask_size: the width size for masking. - :mask_times: the number of masking regions. - :mask_value: Value to assign to the masked regions. - :axis: Axis to apply masking on (1 -> time, 2 -> frequency) - :return_time_mask: Whether return the time mask of shape ``(1, T)`` - """ - if axis not in [1, 2]: - raise ValueError("Only Frequency and Time masking are supported!") - - if return_time_mask and axis == 1: - time_mask = torch.zeros( - 1, features.size(0), dtype=torch.bool, device=features.device - ) - else: - time_mask = None - - features = features.unsqueeze(0) - features = features.reshape([-1] + list(features.size()[-2:])) - - values = torch.randint(int(0), int(mask_size), (1, mask_times)) - min_values = torch.rand(1, mask_times) * (features.size(axis) - values) - mask_starts = (min_values.long()).squeeze() - mask_ends = (min_values.long() + values.long()).squeeze() - - if axis == 1: - if mask_times == 1: - features[:, mask_starts:mask_ends] = mask_value - if return_time_mask: - time_mask[:, mask_starts:mask_ends] = True - return features.squeeze(0), time_mask - for (mask_start, mask_end) in zip(mask_starts, mask_ends): - features[:, mask_start:mask_end] = mask_value - if return_time_mask: - time_mask[:, mask_start:mask_end] = True - else: - if mask_times == 1: - features[:, :, mask_starts:mask_ends] = mask_value - return features.squeeze(0), time_mask - for (mask_start, mask_end) in zip(mask_starts, mask_ends): - features[:, :, mask_start:mask_end] = mask_value - - features = features.squeeze(0) - return features, time_mask - - -def time_warp( - features: torch.Tensor, - p: float = 0.9, - time_warp_factor: Optional[int] = 80, - supervision_segments: Optional[torch.Tensor] = None, -): - if time_warp_factor is None or time_warp_factor < 1: - return features - assert len(features.shape) == 3, ( - "SpecAugment only supports batches of single-channel feature matrices." - ) - features = features.clone() - if supervision_segments is None: - # No supervisions - apply spec augment to full feature matrices. - for sequence_idx in range(features.size(0)): - if random.random() > p: - # Randomly choose whether this transform is applied - continue - features[sequence_idx] = time_warp_impl( - features[sequence_idx], factor=time_warp_factor - ) - else: - # Supervisions provided - we will apply time warping only on the supervised areas. - for sequence_idx, start_frame, num_frames in supervision_segments: - if random.random() > p: - # Randomly choose whether this transform is applied - continue - end_frame = start_frame + num_frames - features[sequence_idx, start_frame:end_frame] = time_warp_impl( - features[sequence_idx, start_frame:end_frame], factor=time_warp_factor - ) - - return features diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3fde55de24..3a8995c811 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -72,6 +72,7 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -102,7 +103,6 @@ setup_logger, str2bool, ) -from spec_augment import SpecAugment LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -460,22 +460,15 @@ def get_parser(): parser.add_argument( "--cr-loss-scale", type=float, - default=0.15, + default=0.2, help="Scale for consistency-regularization loss.", ) parser.add_argument( "--time-mask-ratio", type=float, - default=2.0, - help="When using cr-ctc, we increase the time-masking ratio.", - ) - - parser.add_argument( - "--cr-loss-masked-scale", - type=float, - default=1.0, - help="The value used to scale up the cr_loss at masked positions", + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", ) parser.add_argument( @@ -950,7 +943,6 @@ def compute_loss( spec_augment=spec_augment, supervision_segments=supervision_segments, time_warp_factor=params.spec_aug_time_warp_factor, - cr_loss_masked_scale=params.cr_loss_masked_scale, ) loss = 0.0 diff --git a/icefall/utils.py b/icefall/utils.py index 1dbb954ded..b0a42cefaa 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -21,6 +21,7 @@ import collections import logging import os +import random import re import subprocess from collections import defaultdict @@ -38,6 +39,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter @@ -2271,3 +2273,41 @@ def num_tokens( if 0 in ans: num_tokens -= 1 return num_tokens + + +# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + """Apply time warping on a batch of features + """ + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + if random.random() > p: + # Randomly choose whether this transform is applied + continue + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = time_warp_impl( + features[sequence_idx, start_frame:end_frame], factor=time_warp_factor + ) + + return features From b65873fb4cc5658413d75b449054d4c0cc83e8ee Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 20 Oct 2024 17:48:57 +0800 Subject: [PATCH 6/7] update RESULTS.md --- egs/librispeech/ASR/README.md | 8 +- egs/librispeech/ASR/RESULTS.md | 310 +++++++++++++++++++++++++ egs/librispeech/ASR/zipformer/train.py | 9 +- 3 files changed, 321 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 8b87ee19b4..0dbfdc931c 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer. | `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | | `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | | `zipformer-ctc` | Zipformer | Use auxiliary attention head | -| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe | +| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head (the latest recipe) | # MMI @@ -58,3 +58,9 @@ We place an additional Conv1d layer right after the input embedding layer. |------------------------------|-----------|---------------------------------------------------| | `conformer-mmi` | Conformer | | | `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding | + +# CR-CTC + +| | Encoder | Comment | +|------------------------------|--------------------|------------------------------| +| `zipformer` | Upgraded Zipformer | Could also be an auxiliary loss to improve transducer or CTC/AED (the latest recipe) | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index bc7d8a5efb..6a669f072d 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,315 @@ ## Results +### zipformer (zipformer + pruned-transducer w/ CR-CTC) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 148824074, i.e., 148.8 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| greedy_search | 1.9 | 3.96 | --epoch 50 --avg 26 | +| modified_beam_search | 1.88 | 3.95 | --epoch 50 --avg 26 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large-cr-ctc-rnnt \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 1 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --ctc-loss-scale 0.1 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.02 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1400 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 50 \ + --avg 26 \ + --exp-dir zipformer/exp-large-cr-ctc-rnnt \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 1 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 300 \ + --decoding-method $m +done +``` + +### zipformer (zipformer + CR-CTC-AED) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| attention-decoder-rescoring-no-ngram | 1.96 | 4.08 | --epoch 50 --avg 20 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large-cr-ctc-aed \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.02 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1200 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 20 \ + --exp-dir zipformer/exp-large-cr-ctc-aed/ \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 200 \ + --decoding-method attention-decoder-rescoring-no-ngram +done +``` + +### zipformer (zipformer + CR-CTC) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### small-scale model, number of model parameters: 22118279, i.e., 22.1 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 | + +The training command using 2 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-small/ \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --base-lr 0.04 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 850 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 25 \ + --exp-dir zipformer/exp-small \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +##### medium-scale model, number of model parameters: 64250603, i.e., 64.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 | + +The training command using 4 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 700 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 24 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +##### large-scale model, number of model parameters: 147010094, i.e., 147.0 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1400 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 26 \ + --exp-dir zipformer/exp-large \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 600 \ + --decoding-method $m +done +``` + ### zipformer (zipformer + CTC/AED) See for more details. diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3a8995c811..c074c32ec7 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -45,11 +45,10 @@ --max-duration 1000 It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` - - ctc loss & attention decoder loss, no transducer loss, - with `--use-transducer False --use-ctc True --use-attention-decoder True` + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) """ From 2fc53cd7ce291127cb04cac4dcb185ff0ff2ceb6 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 25 Nov 2024 15:28:05 +0800 Subject: [PATCH 7/7] add conformer exps --- egs/librispeech/ASR/conformer_ctc2/decode.py | 24 +- egs/librispeech/ASR/conformer_ctc3/decode.py | 5 +- .../ASR/conformer_ctc3/model_cr_ctc.py | 209 ++++ .../ASR/conformer_ctc3/train_cr_ctc.py | 1098 +++++++++++++++++ icefall/decode.py | 179 +++ icefall/utils.py | 3 +- 6 files changed, 1514 insertions(+), 4 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 0b271a51cf..473d4eaab1 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -44,6 +44,7 @@ nbest_oracle, one_best_decoding, rescore_with_attention_decoder, + rescore_with_attention_decoder_no_ngram_old, rescore_with_n_best_list, rescore_with_rnn_lm, rescore_with_whole_lattice, @@ -459,6 +460,27 @@ def decode_one_batch( key = "ctc-greedy-search" return {key: hyps} + if params.method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram_old( + lattice=lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. # We choose the HLG decoded lattice for speed reasons @@ -761,7 +783,7 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id - if params.method == "ctc-decoding" or params.method == "ctc-greedy-search": + if params.method == "ctc-decoding" or params.method == "ctc-greedy-search" or params.method == "attention-decoder-rescoring-no-ngram": HLG = None H = k2.ctc_topo( max_token=max_token_id, diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index e6327bb5ec..0355bcccae 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -72,7 +72,7 @@ import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_ctc_model, get_params +from train_cr_ctc import add_model_arguments, get_ctc_model, get_params from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import ( @@ -458,8 +458,9 @@ def decode_one_batch( else: encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) - nnet_output = model.get_ctc_output(encoder_out) + # nnet_output = model.get_ctc_output(encoder_out) # nnet_output is (N, T, C) + nnet_output = model.ctc_output(encoder_out) # (N, T, C) if params.decoding_method == "ctc-greedy-search": timestamps, hyps = ctc_greedy_search( diff --git a/egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py b/egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py new file mode 100644 index 0000000000..02ccf16b97 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py @@ -0,0 +1,209 @@ +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear +from icefall.utils import make_pad_mask, time_warp +from lhotse.dataset import SpecAugment + + +class CTCModel(nn.Module): + """It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf + "Connectionist Temporal Classification: Labelling Unsegmented + Sequence Data with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + encoder_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + encoder_dim: + The feature embedding dimension. + vocab_size: + The vocabulary size. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder = encoder + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + return ctc_loss + + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + warmup: float = 1.0, + use_cr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + if use_cr_ctc: + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x = spec_augment(x.repeat(2, 1, 1)) + else: + x = x.repeat(2, 1, 1) + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(encoder_out_lens > 0) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + # Compute CTC loss + targets = y.values + if not use_cr_ctc: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + ctc_loss = ctc_loss * 0.5 + cr_loss = cr_loss * 0.5 + + return ctc_loss, cr_loss diff --git a/egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py b/egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py new file mode 100755 index 0000000000..1cdaa70add --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py @@ -0,0 +1,1098 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --max-duration 550 + +# train a streaming model +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 \ + --delay-penalty 0.0 +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_cr_ctc import CTCModel +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from lhotse.dataset import SpecAugment +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) + +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc3/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, + ) + return encoder + + +def get_ctc_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + model = CTCModel( + encoder=encoder, + encoder_dim=params.encoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, + spec_augment: Optional[SpecAugment] = None, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + + with torch.set_grad_enabled(is_training): + ctc_loss, cr_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + warmup=warmup, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + ) + + loss = ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + spec_augment=spec_augment, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if params.use_cr_ctc: + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts + # strictly speaking, shuffled training cuts should be used instead + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + spec_augment=spec_augment, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + spec_augment=spec_augment, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, + spec_augment: Optional[SpecAugment] = None, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + spec_augment=spec_augment, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/decode.py b/icefall/decode.py index dd3af1e99b..b3ce147626 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1083,6 +1083,185 @@ def rescore_with_attention_decoder( return ans +def rescore_with_attention_decoder_no_ngram_old( + lattice: k2.Fsa, + num_paths: int, + model: torch.nn.Module, + memory: torch.Tensor, + memory_key_padding_mask: Optional[torch.Tensor], + sos_id: int, + eos_id: int, + attention_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(T, N, C)`. + memory_key_padding_mask: + The padding mask for memory with shape `(N, T)`. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + # max_loop_count = 10 + # loop_count = 0 + # while loop_count <= max_loop_count: + # try: + # nbest = Nbest.from_lattice( + # lattice=lattice, + # num_paths=num_paths, + # use_double_scores=use_double_scores, + # nbest_scale=nbest_scale, + # ) + # # nbest.fsa.scores are all 0s at this point + # nbest = nbest.intersect(lattice) + # break + # except RuntimeError as e: + # logging.info(f"Caught exception:\n{e}\n") + # logging.info(f"num_paths before decreasing: {num_paths}") + # num_paths = int(num_paths / 2) + # if loop_count >= max_loop_count or num_paths <= 0: + # logging.info("Return None as the resulting lattice is too large.") + # return None + # logging.info( + # "This OOM is not an error. You can ignore it. " + # "If your model does not converge well, or --max-duration " + # "is too large, or the input sound file is difficult to " + # "decode, you will meet this exception." + # ) + # logging.info(f"num_paths after decreasing: {num_paths}") + # loop_count += 1 + + # # Now nbest.fsa has its scores set. + # # Also, nbest.fsa inherits the attributes from `lattice`. + # assert hasattr(nbest.fsa, "lm_scores") + + # am_scores = nbest.compute_am_scores() + # ngram_lm_scores = nbest.compute_lm_scores() + + # # The `tokens` attribute is set inside `compile_hlg.py` + # assert hasattr(nbest.fsa, "tokens") + # assert isinstance(nbest.fsa.tokens, torch.Tensor) + + # path_to_utt_map = nbest.shape.row_ids(1).to(torch.long) + # # the shape of memory is (T, N, C), so we use axis=1 here + # expanded_memory = memory.index_select(1, path_to_utt_map) + + # if memory_key_padding_mask is not None: + # # The shape of memory_key_padding_mask is (N, T), so we + # # use axis=0 here. + # expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + # 0, path_to_utt_map + # ) + # else: + # expanded_memory_key_padding_mask = None + + # # remove axis corresponding to states. + # tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + # tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + # tokens = tokens.remove_values_leq(0) + # token_ids = tokens.tolist() + + # if len(token_ids) == 0: + # print("Warning: rescore_with_attention_decoder(): empty token-ids") + # return None + + # path is a ragged tensor with dtype torch.int32. + # It has three axes [utt][path][arc_pos] + path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + # Note that labels, aux_labels and scores contains 0s and -1s. + # The last entry in each sublist is -1. + # The axes are [path][token_id] + labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0) + aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0) + scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0) + + # Remove -1 from labels as we will use it to construct a linear FSA + labels = labels.remove_values_eq(-1) + fsa = k2.linear_fsa(labels) + fsa.aux_labels = aux_labels.values + + # utt_to_path_shape has axes [utt][path] + utt_to_path_shape = path.shape.get_layer(0) + scores = k2.RaggedTensor(utt_to_path_shape, scores.sum()) + + path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long) + # the shape of memory is (N, T, C), so we use axis=0 here + # expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map) + # expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map) + # # the shape of memory is (T, N, C), so we use axis=1 here + expanded_memory = memory.index_select(1, path_to_utt_map) + + if memory_key_padding_mask is not None: + # The shape of memory_key_padding_mask is (N, T), so we + # use axis=0 here. + expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + 0, path_to_utt_map + ) + else: + expanded_memory_key_padding_mask = None + + token_ids = aux_labels.remove_values_leq(0).tolist() + + nll = model.decoder_nll( + memory=expanded_memory, + memory_key_padding_mask=expanded_memory_key_padding_mask, + token_ids=token_ids, + sos_id=sos_id, + eos_id=eos_id, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + + for a_scale in attention_scale_list: + tot_scores = scores.values + a_scale * attention_scores + ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(fsa, max_indexes) + + key = f"attention_scale_{a_scale}" + ans[key] = best_path + + return ans + + def rescore_with_attention_decoder_with_ngram( lattice: k2.Fsa, num_paths: int, diff --git a/icefall/utils.py b/icefall/utils.py index b0a42cefaa..665e33674c 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -983,7 +983,8 @@ def write_error_stats_with_timestamps( hyp_count = corr + hyp_sub + ins print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - return float(tot_err_rate), float(mean_delay), float(var_delay) + # return float(tot_err_rate), float(mean_delay), float(var_delay) + return float(tot_err_rate), mean_delay, var_delay def write_surt_error_stats(