From 11d4ca6eee13c38aa7e11efb9aafbba0285676f4 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 20 Sep 2022 10:41:38 +0800 Subject: [PATCH 1/7] Copy files --- .../__init__.py | 1 + .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../conformer.py | 1 + .../pruned_transducer_stateless_mbr/decode.py | 869 +++++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless_mbr/export.py | 290 +++++ .../pruned_transducer_stateless_mbr/joiner.py | 1 + .../pruned_transducer_stateless_mbr/model.py | 1 + .../pruned_transducer_stateless_mbr/optim.py | 1 + .../scaling.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 663 ++++++++++ .../test_model.py | 1 + .../pruned_transducer_stateless_mbr/train.py | 1142 +++++++++++++++++ 17 files changed, 2977 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/beam_search.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/conformer.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode_stream.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless_mbr/export.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/optim.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/scaling.py create mode 120000 
egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/__init__.py new file mode 120000 index 0000000000..b24e5e3572 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/__init__.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/asr_datamodule.py new file mode 120000 index 0000000000..a074d60850 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/beam_search.py new file mode 120000 index 0000000000..8554e44ccf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/conformer.py new file mode 120000 index 0000000000..3b84b95739 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/conformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py new file mode 100755 index 
0000000000..8431492e6d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py @@ -0,0 +1,869 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + 
--decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless4/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(8) decode in streaming mode (take greedy search as an example) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method greedy_search + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, 
+ store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless4/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key 
+= f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += 
f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + 
filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = 
k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode_stream.py new file mode 120000 index 0000000000..30f2648139 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decoder.py new file mode 120000 index 0000000000..0793c5709c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/encoder_interface.py new file mode 120000 index 0000000000..b9aa0ae083 
--- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/export.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/export.py new file mode 100755 index 0000000000..ce7518ceb4 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/export.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. 
+""" +Usage: +./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless4/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless4/decode.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model True +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless4/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--streaming-model", + type=str2bool, + default=False, + help="""Whether to export a streaming model, if the models in exp-dir + are streaming model, this should be True. 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + if params.streaming_model: + assert params.causal_convolution + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise
ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py new file mode 120000 index 0000000000..815fd4bb6f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py new file mode 120000 index 0000000000..ebb6d774d9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/optim.py new file mode 120000 index 0000000000..e2deb44925 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/scaling.py new file mode 120000 index 0000000000..09d802cc44 --- /dev/null +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_beam_search.py new file mode 120000 index 0000000000..3a5f898338 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_decode.py new file mode 100755 index 0000000000..7af9ea9b8d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/streaming_decode.py @@ -0,0 +1,663 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./pruned_transducer_stateless4/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --left-context 32 \ + --decode-chunk-size 8 \ + --right-context 0 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --decoding_method greedy_search \ + --num-decode-streams 200 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-context", + type=int, + default=0, + help="right context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk of feature frames for each stream in decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to an utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames( + params.decode_chunk_size * params.subsampling_factor + ) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # if T is less than 7 there will be an error in time reduction layer, + # because we subsample features with ((x_len - 1) // 2 - 1) // 2 + # we add 2 here because we will cut off one frame on each side of + # encoder_embed output as they see invalid paddings.
so we need extra 2 + # frames. + tail_length = 7 + (2 + params.right_context) * params.subsampling_factor + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = [ + torch.stack([x[0] for x in states], dim=2), + torch.stack([x[1] for x in states], dim=2), + ] + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + left_context=params.left_context, + right_context=params.right_context, + processed_lens=processed_lens, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = [states[0][i], states[1][i]] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: 
spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+ decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise 
ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += 
f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + params.suffix += f"-right-context-{params.right_context}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> and <unk> are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + # Decoding in streaming requires causal convolution + params.causal_convolution = True + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") +
logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = 
[test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/test_model.py new file mode 120000 index 0000000000..4196e587cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/test_model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/test_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py new file mode 100755 index 0000000000..13a5b1a515 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py @@ -0,0 +1,1142 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 550 + +# train a streaming model +./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 + +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from 
icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. 
+ """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
def get_encoder_model(params: AttributeDict) -> nn.Module:
    """Build the Conformer encoder from the hyper-parameters in `params`.

    Args:
      params:
        Training parameters; the encoder reads the feature/model dimensions
        and the streaming related options from it.
    Returns:
      Return the constructed encoder.
    """
    # TODO: We can add an option to switch between Conformer and Transformer
    return Conformer(
        num_features=params.feature_dim,
        subsampling_factor=params.subsampling_factor,
        d_model=params.encoder_dim,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        dynamic_chunk_training=params.dynamic_chunk_training,
        short_chunk_size=params.short_chunk_size,
        num_left_chunks=params.num_left_chunks,
        causal=params.causal_convolution,
    )


def get_decoder_model(params: AttributeDict) -> nn.Module:
    """Build the prediction network (decoder) from `params`.

    Args:
      params:
        Training parameters; `vocab_size`, `decoder_dim`, `blank_id` and
        `context_size` are read from it.
    Returns:
      Return the constructed decoder.
    """
    return Decoder(
        vocab_size=params.vocab_size,
        decoder_dim=params.decoder_dim,
        blank_id=params.blank_id,
        context_size=params.context_size,
    )


def get_joiner_model(params: AttributeDict) -> nn.Module:
    """Build the joiner network from `params`.

    Args:
      params:
        Training parameters; `encoder_dim`, `decoder_dim`, `joiner_dim` and
        `vocab_size` are read from it.
    Returns:
      Return the constructed joiner.
    """
    return Joiner(
        encoder_dim=params.encoder_dim,
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
    )
def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_batch is positive, it will load the checkpoint from
    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
    params.start_epoch is larger than 1, it will load the checkpoint from
    `params.exp_dir/epoch-{params.start_epoch-1}.pt`. If neither condition
    holds, nothing is loaded and None is returned.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    `best_valid_loss` and `batch_idx_train` in `params`. When resuming from
    a batch checkpoint that recorded `cur_epoch`, `params.start_epoch` is
    overwritten with that value.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training. May be None.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info, or None
      if there is no checkpoint to resume from.
    """
    if params.start_batch > 0:
        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
    elif params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    # Restore the bookkeeping values that must survive a restart.
    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    # A checkpoint saved in the middle of an epoch carries the epoch it was
    # written in; continue from that epoch instead of the command-line value.
    if params.start_batch > 0 and "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]

    return saved_params
def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
    warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute RNN-T loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training (optionally wrapped in DDP). It is called
        with the features, their lengths and the encoded transcripts and
        must return `(simple_loss, pruned_loss)`.
      sp:
        The BPE model used to encode the transcripts into token ids.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    Returns:
      Return a tuple `(loss, info)`, where `loss` is the summed (not
      normalized) loss over the utterances in the batch and `info` holds
      the statistics of this batch.
    Raises:
      ValueError:
        If all the simple losses or all the pruned losses in the batch
        are inf or nan.
    """
    device = (
        model.device
        if isinstance(model, DDP)
        else next(model.parameters()).device
    )
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    texts = batch["supervisions"]["text"]
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
        is_finite = simple_loss_is_finite & pruned_loss_is_finite
        if not torch.all(is_finite):
            logging.info(
                "Not all losses are finite!\n"
                f"simple_loss: {simple_loss}\n"
                f"pruned_loss: {pruned_loss}"
            )
            display_and_save_batch(batch, params=params, sp=sp)
            simple_loss = simple_loss[simple_loss_is_finite]
            pruned_loss = pruned_loss[pruned_loss_is_finite]

            # If either all simple_loss or pruned_loss is inf or nan,
            # we stop the training process by raising an exception
            if torch.all(~simple_loss_is_finite) or torch.all(
                ~pruned_loss_is_finite
            ):
                raise ValueError(
                    "There are too many utterances in this batch "
                    "leading to inf or nan losses."
                )

        simple_loss = simple_loss.sum()
        pruned_loss = pruned_loss.sum()
        # after the main warmup step, we keep pruned_loss_scale small
        # for the same amount of time (model_warm_step), to avoid
        # overwhelming the simple_loss and causing it to diverge,
        # in case it had not fully learned the alignment yet.
        #
        # The middle interval is closed on the left (1.0 <= warmup < 2.0),
        # so that the scale does not jump to 1.0 for the single batch
        # where warmup is exactly 1.0.
        if warmup < 1.0:
            pruned_loss_scale = 0.0
        elif warmup < 2.0:
            pruned_loss_scale = 0.1
        else:
            pruned_loss_scale = 1.0

        loss = (
            params.simple_loss_scale * simple_loss
            + pruned_loss_scale * pruned_loss
        )

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # info["frames"] is an approximate number for two reasons:
        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
        # (2) If some utterances in the batch lead to inf/nan loss, they
        #     are filtered out.
        info["frames"] = (
            (feature_lens // params.subsampling_factor).sum().item()
        )

    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
    info["utterances"] = feature.size(0)
    # averaged input duration in frames over utterances
    info["utt_duration"] = feature_lens.sum().item()
    # averaged padding proportion over utterances
    info["utt_pad_proportion"] = (
        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
    )

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
    info["simple_loss"] = simple_loss.detach().cpu().item()
    info["pruned_loss"] = pruned_loss.detach().cpu().item()

    return loss, info
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + 
def run(rank, world_size, args):
    """Set up and run the whole training process on one device.

    It builds the model, optimizer, scheduler and data loaders, resumes from
    a checkpoint if requested, trains for the configured number of epochs and
    saves a checkpoint at the end of each epoch.

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))
    if params.full_libri is False:
        params.valid_interval = 1600

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    if params.dynamic_chunk_training:
        assert (
            params.causal_convolution
        ), "dynamic_chunk_training requires causal convolution"

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank])

    optimizer = Eve(model.parameters(), lr=params.initial_lr)

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        diagnostic = diagnostics.attach_diagnostics(model)

    librispeech = LibriSpeechAsrDataModule(args)

    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:
        train_cuts += librispeech.train_clean_360_cuts()
        train_cuts += librispeech.train_other_500_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        #
        # Caution: There is a reason to select 20.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        return 1.0 <= c.duration <= 20.0

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
        # saved in the middle of an epoch
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = librispeech.dev_clean_cuts()
    valid_cuts += librispeech.dev_other_cuts()
    valid_dl = librispeech.valid_dataloaders(valid_cuts)

    # The OOM sanity check is skipped when resuming from a batch checkpoint
    # and when we only want to print diagnostics.
    if params.start_batch <= 0 and not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            sp=sp,
            params=params,
            warmup=0.0 if params.start_epoch == 1 else 1.0,
        )

    scaler = GradScaler(enabled=params.use_fp16)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        # Epochs are counted from 1 on the command line, but the scheduler,
        # the random seed and the sampler are driven by a 0-based index.
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        save_checkpoint(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
def main():
    """Parse the command line and launch training.

    With `--world-size` larger than 1, one process per rank is spawned and
    each of them executes :func:`run`; otherwise :func:`run` is called
    directly in the current process with rank 0.
    """
    arg_parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(arg_parser)
    args = arg_parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    num_procs = args.world_size
    assert num_procs >= 1

    if num_procs == 1:
        run(rank=0, world_size=1, args=args)
        return

    mp.spawn(run, args=(num_procs, args), nprocs=num_procs, join=True)
class Joiner(nn.Module):
    """Combine encoder and decoder outputs into per-symbol logits.

    Besides the regular output projection, it owns a second projection
    (`output_linear_wer`) whose output is used to calculate the `delta_wer`
    needed by MBR training.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        # Input projections bring both streams to the joiner dimension.
        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
        # Two output heads sharing the same hidden activation.
        self.output_linear = ScaledLinear(joiner_dim, vocab_size)
        self.output_linear_wer = ScaledLinear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        extra_output: bool = False,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
          extra_output:
            If true, additionally return the output of the second projection
            layer, which will be used to calculate the `delta_wer`.
          project_input:
            If true, apply input projections encoder_proj and decoder_proj.
            If this is false, it is the user's responsibility to do this
            manually.
        Returns:
          If extra_output is False, return a tensor of shape
          (N, T, s_range, C). If extra_output is True, return a tuple of two
          tensors of that shape: the regular joiner output first, then the
          output used to calculate the `delta_wer` needed by MBR training.
        """
        if not is_jit_tracing():
            assert encoder_out.ndim == decoder_out.ndim
            assert encoder_out.shape == decoder_out.shape

        if project_input:
            combined = self.encoder_proj(encoder_out) + self.decoder_proj(
                decoder_out
            )
        else:
            combined = encoder_out + decoder_out

        hidden = torch.tanh(combined)
        joiner_logit = self.output_linear(hidden)

        if extra_output:
            return joiner_logit, self.output_linear_wer(hidden)
        return joiner_logit
def _roll_by_shifts(
    src: torch.Tensor, shifts: torch.LongTensor
) -> torch.Tensor:
    """Circularly shift the last axis of `src`, with one shift per row.

    Note:
      `src` must be a 3-D tensor of shape (B, T, S) and `shifts` must have
      shape (B, T); row `src[b, t]` is rolled to the right by
      `shifts[b, t]` positions.
    Example:
      >>> src = torch.arange(15).reshape((1,3,5))
      >>> src
      tensor([[[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14]]])
      >>> shift = torch.tensor([[1, 2, 3]])
      >>> shift
      tensor([[1, 2, 3]])
      >>> _roll_by_shifts(src, shift)
      tensor([[[ 4,  0,  1,  2,  3],
               [ 8,  9,  5,  6,  7],
               [12, 13, 14, 10, 11]]])
    """
    assert src.dim() == 3
    num_batches, num_rows, width = src.shape
    assert shifts.shape == (num_batches, num_rows)

    # Output position j of a row takes its value from input position
    # (j - shift) mod width. Broadcasting the (1, 1, S) positions against
    # the (B, T, 1) shifts yields the full (B, T, S) gather index.
    positions = torch.arange(width, device=src.device).view(1, 1, width)
    gather_index = (
        positions - shifts.reshape(num_batches, num_rows, 1)
    ) % width
    return torch.gather(src, 2, gather_index)
Note that its output + contains unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + self.decoder_dim = decoder_dim + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + def delta_wer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + init_context: torch.Tensor, + y_padded: torch.Tensor, + y_lens: torch.Tensor, + path_length: int = 20, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + encoder_out: + The output of the encoder whose shape is (batch_size, T, encoder_dim) + encoder_out_lens: + A tensor of shape (batch_size,) containing the number of frames + before padding. + init_context: + A tensor of shape (batch_size, T, context_size) containing the + initial history symbols for each frame. + y_padded: + The transcripts whose shape is (batch_size, S). + y_lens: + A tensor of shape (batch_size,) containing the number of symbols + before padding. + path_length: + The length of the sampled paths. + + Returns: + Return three tensors, + - The delta_wer, its shape is (batch_size,) + - The absolute value of wer_diff, its shape is (batch_size,) + - The absolute value of pred_wer_diff, its shape is (batch_size,). + """ + batch_size, T, encoder_dim = encoder_out.shape + device = encoder_out.device + + blank_id = self.decoder.blank_id + context_size = self.decoder.context_size + decoder_dim = self.decoder_dim + vocab_size = self.decoder.vocab_size + + max_frame = torch.max(encoder_out_lens).item() + min_frame = torch.min(encoder_out_lens).item() + + # t_index contains the frame ids we are sampling for each path. 
+ # shape : (batch_size, 1) + t_index = torch.randint( + min_frame // 2, max_frame // 2 + 1, (batch_size, 1), device=device + ) + t_index = torch.remainder(t_index, encoder_out_lens.reshape(-1, 1)) + + # we will sample two paths for each sequence + num_paths = 2 + # shape : (batch_size, num_paths) + t_index = t_index.expand(batch_size, num_paths) + + # The max frame index for each path + # shape : (batch_size, num_paths) + t_index_max = encoder_out_lens.view(batch_size, 1).expand( + batch_size, num_paths + ) + + # left_symbols contains the left contexts of decoder for each path + # shape : (batch_size, num_paths, context_size) + left_symbols = torch.gather( + init_context, + dim=1, + index=t_index.unsqueeze(2).expand( + batch_size, num_paths, context_size + ), + ) + left_symbols = left_symbols.view(batch_size * num_paths, context_size) + + # It has a shape of (batch_size,) indicating whether having different + # paths for this sequence. + has_diff = torch.zeros((batch_size,), device=device).bool() + # It has a shape of (batch_size,) indicating whether reaching final + # for this sequence + reach_final = torch.zeros((batch_size,), device=device).bool() + + # The pred_wer, default zeros. If there is no different symbol for the + # sampled paths, the pred_wers for the two paths are the same. 
+ pred_output = torch.zeros((batch_size, num_paths), device=device) + + sampled_paths_list = [] + + while len(sampled_paths_list) < path_length: + # (B, num_paths, encoder_dim) + current_encoder_out = torch.gather( + encoder_out, + dim=1, + index=t_index.unsqueeze(2).expand( + batch_size, num_paths, encoder_dim + ), + ) + + # (B, num_paths, decoder_dim) + decoder_output = self.decoder(left_symbols, need_pad=False).view( + batch_size, num_paths, decoder_dim + ) + # joiner_output : (B, num_paths, V); + # wer_output : (B, num_paths, V) + joiner_output, wer_output = self.joiner( + current_encoder_out, decoder_output, extra_output=True + ) + + probs = torch.softmax(joiner_output, -1) + # sampler: https://pytorch.org/docs/stable/distributions.html#categorical + sampler = Categorical(probs=probs) + + # sample one symbol for each path + # index : (batch_size, num_paths) + index = sampler.sample() + + # The two paths have different symbols. + mask = index[:, 0] != index[:, 1] + # shape : (batch_size,), will only be True when the two paths have + # different symbols in the first time. + meet_diff = mask & ~has_diff & ~reach_final + + has_diff |= mask + + wer_output = torch.gather( + wer_output, dim=2, index=index.unsqueeze(2) + ).squeeze(2) + + # we only get the pred_wer at the position where the two paths start + # to have different symbols. + pred_output = torch.where( + meet_diff.reshape(batch_size, 1), wer_output, pred_output + ) + + # update (t, s) for each path + # index == 0 means the sampled symbol is blank + t_mask = index == 0 + t_index = t_index + 1 + + # if reaching final, we will ignore the sampled symbols, just append + # blank_id to the paths. + index = torch.where( + reach_final.reshape(batch_size, 1).expand( + batch_size, num_paths + ), + blank_id, + index, + ) + sampled_paths_list.append(index) + + final_mask = t_index >= t_index_max + + # Set reach_final to true when one of the paths reaching final. 
+ reach_final |= final_mask[:, 0] | final_mask[:, 1] + + t_index.masked_fill_(final_mask, 0) + + left_symbols = left_symbols.view( + batch_size, num_paths, context_size + ) + current_symbols = torch.cat( + [ + left_symbols, + index.unsqueeze(2), + ], + dim=2, + ) + # if the sampled symbol is blank, we only need to roll the history + # symbols, if the sampled symbol is not blank, append the newly + # sampled symbol. + left_symbols = _roll_by_shifts( + current_symbols, t_mask.to(torch.int64) + ) + left_symbols = left_symbols[:, :, 1:] + + left_symbols = left_symbols.view( + batch_size * num_paths, context_size + ) + + # sampled_paths : (batch_size, num_paths, path_lengths) + sampled_paths = torch.stack(sampled_paths_list, dim=2).int() + + px1 = k2.RaggedTensor(sampled_paths[:, 0, :]) + px1 = px1.remove_values_eq(blank_id) + row_splits = px1.shape.row_splits(1) + px1_lens = row_splits[1:] - row_splits[:-1] + px1 = px1.pad(mode="constant", padding_value=blank_id) + + boundary = torch.cat( + [ + torch.zeros((batch_size, 2), dtype=torch.int64, device=device), + px1_lens.reshape(batch_size, 1), + y_lens.reshape(batch_size, 1), + ], + dim=1, + ) + + wer1 = k2.levenshtein_distance( + px=px1, py=y_padded.int(), boundary=boundary + ) + wer1 = torch.gather( + wer1, + 1, + boundary[:, 2] + .reshape(batch_size, 1, 1) + .expand(batch_size, 1, wer1.size(2)), + ).squeeze(1) + wer1 = torch.gather( + wer1, 1, boundary[:, 3].reshape(batch_size, 1) + ).squeeze(1) + + px2 = k2.RaggedTensor(sampled_paths[:, 1, :]) + px2 = px2.remove_values_eq(blank_id) + row_splits = px2.shape.row_splits(1) + px2_lens = row_splits[1:] - row_splits[:-1] + px2 = px2.pad(mode="constant", padding_value=blank_id) + + boundary = torch.cat( + [ + torch.zeros((batch_size, 2), dtype=torch.int64, device=device), + px2_lens.reshape(batch_size, 1), + y_lens.reshape(batch_size, 1), + ], + dim=1, + ) + + wer2 = k2.levenshtein_distance( + px=px2, py=y_padded.int(), boundary=boundary + ) + wer2 = torch.gather( + wer2, + 
1, + boundary[:, 2] + .reshape(batch_size, 1, 1) + .expand(batch_size, 1, wer2.size(2)), + ).squeeze(1) + wer2 = torch.gather( + wer2, 1, boundary[:, 3].reshape(batch_size, 1) + ).squeeze(1) + + delta_wer = torch.pow( + ((wer1 - wer2) - (pred_output[:, 0] - pred_output[:, 1])), 2 + ) + + return ( + delta_wer, + torch.abs(wer1 - wer2), + torch.abs(pred_output[:, 0] - pred_output[:, 1]), + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + path_length: int = 20, + warmup: float = 1.0, + reduction: str = "sum", + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + reduction: + "sum" to sum the losses over all utterances in the batch. + "none" to return the loss in a 1-D tensor for each utterance + in the batch. + Returns: + Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert reduction in ("sum", "none"), reduction + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + context_size = self.decoder.context_size + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + modified=True, + reduction=reduction, + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + 
lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + modified=True, + reduction=reduction, + ) + + # Get contexts for each frame according to the gradients, just like we + # do for getting pruning bounds. + (B, S, T1) = px_grad.shape + T = py_grad.shape[-1] + # shape : (B, S, T) + tot_grad = px_grad[:, :, :T] + py_grad[:, :S, :] + # shape : (B, T) + best_idx = torch.argmax(tot_grad, dim=1) + # shape : (B, T, context_size) + state_idx = best_idx.reshape((B, T, 1)).expand( + (B, T, context_size) + ) + torch.arange(context_size, device=px_grad.device) + # shape : (B, context_size) + init_context = torch.tensor( + [blank_id], dtype=torch.int64, device=px_grad.device + ).expand(B, context_size) + # shape : (B, S + context_size) + sos_y_padded = torch.cat([init_context, y_padded], dim=1) + init_context = torch.gather( + sos_y_padded.unsqueeze(1).expand(B, T, S + context_size), + dim=2, + index=state_idx, + ) + + delta_wer, wer_diff, pred_wer_diff = self.delta_wer( + encoder_out=encoder_out, + encoder_out_lens=x_lens, + init_context=init_context, + y_padded=y_padded, + y_lens=y_lens, + path_length=path_length, + ) + + return (simple_loss, pruned_loss, delta_wer, wer_diff, pred_wer_diff) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py index 13a5b1a515..70c861e7a1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py @@ -335,6 +335,20 @@ def 
get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--path-length", + type=int, + default=20, + help="The length of the sampled paths for MBR training.", + ) + + parser.add_argument( + "--delta-wer-scale", + type=float, + default=0.1, + help="The scale applying to delta_wer when it adds to the loss", + ) + add_model_arguments(parser) return parser @@ -391,7 +405,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 50, + "log_interval": 1, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 # parameters for conformer @@ -628,13 +642,14 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, delta_wer, wer_diff, pred_wer_diff = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + path_length=params.path_length, warmup=warmup, reduction="none", ) @@ -663,6 +678,9 @@ def compute_loss( simple_loss = simple_loss.sum() pruned_loss = pruned_loss.sum() + delta_wer = delta_wer.sum() + wer_diff = wer_diff.sum() + pred_wer_diff = pred_wer_diff.sum() # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid # overwhelming the simple_loss and causing it to diverge, @@ -675,6 +693,7 @@ def compute_loss( loss = ( params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + params.delta_wer_scale * delta_wer ) assert loss.requires_grad == is_training @@ -703,6 +722,9 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["delta_wer"] = delta_wer.detach().cpu().item() + info["wer_diff"] = wer_diff.detach().cpu().item() + info["pred_wer_diff"] = 
pred_wer_diff.detach().cpu().item() return loss, info diff --git a/icefall/utils.py b/icefall/utils.py index ad079222e5..f6bb2087cf 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -577,7 +577,7 @@ def norm_items(self) -> List[Tuple[str, float]]: continue norm_value = ( float(v) / num_frames - if "utt_" not in k + if "utt_" not in k and "wer" not in k else float(v) / num_utterances ) ans.append((k, norm_value)) From 32b8a027182bff551025b724e85b2fb7149f338e Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 13 Nov 2022 21:24:27 +0800 Subject: [PATCH 3/7] Add enhanced embedding module --- .../attention.py | 1 + .../pruned_transducer_stateless_mbr/joiner.py | 86 +----- .../label_smoothing.py | 1 + .../pruned_transducer_stateless_mbr/model.py | 289 +++++++++++++++--- .../subsampling.py | 1 + .../pruned_transducer_stateless_mbr/train.py | 82 ++++- .../transformer.py | 1 + 7 files changed, 330 insertions(+), 131 deletions(-) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/attention.py mode change 100644 => 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/label_smoothing.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/subsampling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless_mbr/transformer.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/attention.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/attention.py new file mode 120000 index 0000000000..e5f4a06447 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/attention.py @@ -0,0 +1 @@ +../conformer_ctc2/attention.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py deleted file mode 100644 index 5983098521..0000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py +++ 
/dev/null @@ -1,85 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear - -from icefall.utils import is_jit_tracing - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) - self.output_linear = ScaledLinear(joiner_dim, vocab_size) - self.output_linear_wer = ScaledLinear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - extra_output: bool = False, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - extra_output: - If true, return an extra tensor with shape (N, T, s_range, C) which - will be used to calculate the `delta_wer`. - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - If extra_output is False, return a tensor of shape (N, T, s_range, C). 
- If extra_output is True, return two tensors of the same shape - (N, T, s_range, C). The two tensors produced by two different - projection layers, one is used as the regular joiner output, the other - is used to calculate the `delta_wer` needed by MBR training. - """ - if not is_jit_tracing(): - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.shape == decoder_out.shape - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) - else: - logit = encoder_out + decoder_out - - logit = torch.tanh(logit) - - joiner_logit = self.output_linear(logit) - - if not extra_output: - return joiner_logit - else: - wer_logit = self.output_linear_wer(logit) - return joiner_logit, wer_logit diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py new file mode 120000 index 0000000000..815fd4bb6f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/label_smoothing.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/label_smoothing.py new file mode 120000 index 0000000000..6eadf50539 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc2/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py index d137fe4f23..99ea794be5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py @@ -15,17 +15,33 @@ # limitations under the License. 
-from typing import Tuple +from typing import Dict, List, Optional, Tuple import logging import k2 import torch import torch.nn as nn + from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledEmbedding, + ScaledLinear, +) from torch.distributions.categorical import Categorical -from scaling import ScaledLinear - -from icefall.utils import add_sos +from transformer import ( + PositionalEncoding, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, + decoder_padding_mask, + generate_square_subsequent_mask, +) +from icefall.utils import add_sos, add_eos, make_pad_mask def _roll_by_shifts( @@ -72,6 +88,9 @@ def __init__( encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + quasi_joiner: nn.Module, + transformer_lm: nn.Module, + embedding_enhancer: nn.Module, encoder_dim: int, decoder_dim: int, joiner_dim: int, @@ -101,6 +120,9 @@ def __init__( self.encoder = encoder self.decoder = decoder self.joiner = joiner + self.quasi_joiner = quasi_joiner + self.transformer_lm = transformer_lm + self.embedding_enhancer = embedding_enhancer self.decoder_dim = decoder_dim self.simple_am_proj = ScaledLinear( @@ -361,6 +383,8 @@ def forward( x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + sos_id: int, + eos_id: int, prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, @@ -419,6 +443,7 @@ def forward( blank_id = self.decoder.blank_id context_size = self.decoder.context_size + vocab_size = self.decoder.vocab_size sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. 
@@ -450,7 +475,7 @@ def forward( lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, - modified=True, + rnnt_type="constrained", reduction=reduction, return_grad=True, ) @@ -484,41 +509,233 @@ def forward( ranges=ranges, termination_symbol=blank_id, boundary=boundary, - modified=True, + rnnt_type="constrained", reduction=reduction, ) - # Get contexts for each frame according to the gradients, just like we - # do for getting pruning bounds. - (B, S, T1) = px_grad.shape - T = py_grad.shape[-1] - # shape : (B, S, T) - tot_grad = px_grad[:, :, :T] + py_grad[:, :S, :] - # shape : (B, T) - best_idx = torch.argmax(tot_grad, dim=1) - # shape : (B, T, context_size) - state_idx = best_idx.reshape((B, T, 1)).expand( - (B, T, context_size) - ) + torch.arange(context_size, device=px_grad.device) - # shape : (B, context_size) - init_context = torch.tensor( - [blank_id], dtype=torch.int64, device=px_grad.device - ).expand(B, context_size) - # shape : (B, S + context_size) - sos_y_padded = torch.cat([init_context, y_padded], dim=1) - init_context = torch.gather( - sos_y_padded.unsqueeze(1).expand(B, T, S + context_size), - dim=2, - index=state_idx, + text_embedding, text_embedding_key_padding_mask = self.transformer_lm( + y, sos_id=sos_id, eos_id=eos_id + ) + + embedding_key_padding_mask = make_pad_mask(x_lens) + enhanced_embedding = self.embedding_enhancer( + embedding=encoder_out.detach(), + text_embedding=text_embedding, + embedding_key_padding_mask=embedding_key_padding_mask, + text_embedding_key_padding_mask=text_embedding_key_padding_mask, + warmup=warmup, + ) + + l2_loss = ( + torch.sum(torch.pow(enhanced_embedding - encoder_out.detach(), 2)) + / vocab_size + ) + + if False: + # Get contexts for each frame according to the gradients, just like we + # do for getting pruning bounds. 
+ (B, S, T1) = px_grad.shape + T = py_grad.shape[-1] + # shape : (B, S, T) + tot_grad = px_grad[:, :, :T] + py_grad[:, :S, :] + # shape : (B, T) + best_idx = torch.argmax(tot_grad, dim=1) + # shape : (B, T, context_size) + state_idx = best_idx.reshape((B, T, 1)).expand( + (B, T, context_size) + ) + torch.arange(context_size, device=px_grad.device) + # shape : (B, context_size) + init_context = torch.tensor( + [blank_id], dtype=torch.int64, device=px_grad.device + ).expand(B, context_size) + # shape : (B, S + context_size) + sos_y_padded = torch.cat([init_context, y_padded], dim=1) + init_context = torch.gather( + sos_y_padded.unsqueeze(1).expand(B, T, S + context_size), + dim=2, + index=state_idx, + ) + + delta_wer, wer_diff, pred_wer_diff = self.delta_wer( + encoder_out=encoder_out, + encoder_out_lens=x_lens, + init_context=init_context, + y_padded=y_padded, + y_lens=y_lens, + path_length=path_length, + ) + + return ( + simple_loss, + pruned_loss, + delta_wer, + wer_diff, + pred_wer_diff, + ) + + return (simple_loss, pruned_loss, l2_loss) + + +class TransformerLM(nn.Module): + def __init__( + self, + num_classes: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 3, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + """ + Args: + num_classes: + The output dimension of the model. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model % nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + dropout: + Dropout in encoder/decoder. + layer_dropout (float): layer-dropout rate.
+ """ + super().__init__() + + self.num_classes = num_classes + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + self.embed = ScaledEmbedding( + num_embeddings=self.num_classes, embedding_dim=d_model + ) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + layer_dropout=layer_dropout, + ) + + self.encoder = TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + ) + + self.encoder_output_layer = ScaledLinear( + d_model, num_classes, bias=True ) - delta_wer, wer_diff, pred_wer_diff = self.delta_wer( - encoder_out=encoder_out, - encoder_out_lens=x_lens, - init_context=init_context, - y_padded=y_padded, - y_lens=y_lens, - path_length=path_length, + def forward( + self, + y: k2.RaggedTensor, + sos_id: int, + eos_id: int, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + Return a tuple containing 2 tensors: + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + """ + assert y.num_axes == 2, y.num_axes + + device = y.device + + sos_y = add_sos(y, sos_id=sos_id) + sos_y_padded = sos_y.pad(mode="constant", padding_value=sos_id) + y_eos = add_eos(y, eos_id=eos_id) + y_eos_padded = y_eos.pad(mode="constant", padding_value=eos_id) + + att_mask = generate_square_subsequent_mask(y_eos_padded.shape[-1]).to( + device + ) + + key_padding_mask = decoder_padding_mask(y_eos_padded, ignore_id=eos_id) + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. 
+ key_padding_mask[:, 0] = False + + x = self.embed(sos_y_padded) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + x = self.encoder( + x, + mask=att_mask, + src_key_padding_mask=key_padding_mask, + warmup=warmup, + ) # (T, N, C) + + return x, key_padding_mask + + +class EmbeddingEnhancer(nn.Module): + """ + Enhance the encoder embedding so that it "knows about" the text as well as the acoustics. + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + num_layers: int = 4, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ): + super().__init__() + self.encoder_pos = PositionalEncoding(d_model, dropout) + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + layer_dropout=layer_dropout, + ) + self.enhancer = TransformerDecoder(decoder_layer, num_layers) + + def forward( + self, + embedding: torch.Tensor, + text_embedding: torch.Tensor, + embedding_mask: Optional[torch.Tensor] = None, + text_embedding_mask: Optional[torch.Tensor] = None, + embedding_key_padding_mask: Optional[torch.Tensor] = None, + text_embedding_key_padding_mask: Optional[torch.Tensor] = None, + mask_proportion: float = 0.25, + warmup: float = 1.0, + ): + N, T, C = embedding.shape + mask = torch.rand((N, T, C), device=embedding.device) # uniform in [0, 1), so about mask_proportion of the entries are zeroed below + mask = mask > mask_proportion + masked_embedding = torch.masked_fill(embedding, ~mask, 0.0) + masked_embedding = self.encoder_pos(masked_embedding) + masked_embedding = masked_embedding.permute(1, 0, 2) + + enhanced_embedding = self.enhancer( + tgt=masked_embedding, + memory=text_embedding, + tgt_mask=embedding_mask, + memory_mask=text_embedding_mask, + tgt_key_padding_mask=embedding_key_padding_mask, + memory_key_padding_mask=text_embedding_key_padding_mask, + warmup=warmup, ) - return (simple_loss, pruned_loss, delta_wer, wer_diff, pred_wer_diff) + # shape: (N, T, C) + enhanced_embedding = enhanced_embedding.permute(1, 0, 2) +
return enhanced_embedding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/subsampling.py new file mode 120000 index 0000000000..bf0d9c9b99 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc2/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py index 70c861e7a1..7b2ca90a15 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py @@ -77,7 +77,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model import Transducer +from model import Transducer, TransformerLM, EmbeddingEnhancer from optim import Eden, Eve from torch import Tensor from torch.cuda.amp import GradScaler @@ -141,6 +141,20 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="How many left context can be seen in chunks when calculating attention.", ) + parser.add_argument( + "--num-lm-layers", + type=int, + default=3, + help="The number of layers for transformer language model", + ) + + parser.add_argument( + "--num-enhancer-layers", + type=int, + default=3, + help="The number of layers of embedding enhancer model", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -349,6 +363,13 @@ def get_parser(): help="The scale applying to delta_wer when it adds to the loss", ) + parser.add_argument( + "--l2-loss-scale", + type=float, + default=0.1, + help="The scale applying to l2_loss of embedding and enhanced_embedding", + ) + add_model_arguments(parser) return parser @@ -465,15 +486,42 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_transformer_lm(params: AttributeDict) -> nn.Module: + lm = TransformerLM( + 
num_classes=params.vocab_size, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_lm_layers, + ) + return lm + + +def get_embedding_enhancer(params: AttributeDict) -> nn.Module: + enhancer = EmbeddingEnhancer( + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_layers=params.num_enhancer_layers, + ) + return enhancer + + def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) + quasi_joiner = get_joiner_model(params) + transformer_lm = get_transformer_lm(params) + enhancer = get_embedding_enhancer(params) model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, + quasi_joiner=quasi_joiner, + transformer_lm=transformer_lm, + embedding_enhancer=enhancer, encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, @@ -642,10 +690,13 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, delta_wer, wer_diff, pred_wer_diff = model( + # simple_loss, pruned_loss, delta_wer, wer_diff, pred_wer_diff = model( + simple_loss, pruned_loss, l2_loss = model( x=feature, x_lens=feature_lens, y=y, + sos_id=params.sos_id, + eos_id=params.eos_id, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, @@ -678,9 +729,10 @@ def compute_loss( simple_loss = simple_loss.sum() pruned_loss = pruned_loss.sum() - delta_wer = delta_wer.sum() - wer_diff = wer_diff.sum() - pred_wer_diff = pred_wer_diff.sum() + # delta_wer = delta_wer.sum() + # wer_diff = wer_diff.sum() + # pred_wer_diff = pred_wer_diff.sum() + # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid # overwhelming the simple_loss and causing it to diverge, @@ -690,10 +742,17 @@ def compute_loss( 
if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + # l2_loss_scale = ( + # 1.0 + # if warmup < 1.0 + # else (0.1 if warmup > 1.0 and warmup < 2.0 else params.l2_loss_scale)) + l2_loss_scale = params.l2_loss_scale + loss = ( params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - + params.delta_wer_scale * delta_wer + + l2_loss_scale * l2_loss + # + params.delta_wer_scale * delta_wer ) assert loss.requires_grad == is_training @@ -722,9 +781,10 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - info["delta_wer"] = delta_wer.detach().cpu().item() - info["wer_diff"] = wer_diff.detach().cpu().item() - info["pred_wer_diff"] = pred_wer_diff.detach().cpu().item() + info["l2_loss"] = l2_loss.detach().cpu().item() + # info["delta_wer"] = delta_wer.detach().cpu().item() + # info["wer_diff"] = wer_diff.detach().cpu().item() + # info["pred_wer_diff"] = pred_wer_diff.detach().cpu().item() return loss, info @@ -953,8 +1013,10 @@ def run(rank, world_size, args): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # <blk> is defined in local/train_bpe_model.py + # <blk> and <sos/eos> are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("<blk>") + params.sos_id = sp.piece_to_id("<sos/eos>") + params.eos_id = sp.piece_to_id("<sos/eos>") params.vocab_size = sp.get_piece_size() if params.dynamic_chunk_training: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/transformer.py new file mode 120000 index 0000000000..afbb14aa9d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/transformer.py @@ -0,0 +1 @@ +../conformer_ctc2/transformer.py \ No newline at end of file From 4a572049c2858e5b5c5eac4fb0d40ce9de47af67 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 16 Nov 2022 10:05:55 +0800 Subject: [PATCH 4/7] Add predictor loss ---
.../pruned_transducer_stateless_mbr/model.py | 410 +++++++++++------- .../pruned_transducer_stateless_mbr/train.py | 46 +- 2 files changed, 286 insertions(+), 170 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py index 99ea794be5..3cdf614b1a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# Copyright 2022 Xiaomi Corp. (author: Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -130,15 +130,98 @@ def __init__( ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + def get_wer( + self, + sampled_paths: torch.Tensor, + y_padded: torch.Tensor, + y_lens: torch.Tensor, + blank_id: int = 0, + ): + batch_size, path_length = sampled_paths.shape + assert y_padded.size(0) == batch_size, (y_padded.shape, batch_size) + assert y_lens.size(0) == batch_size, (y_lens.shape, batch_size) + + device = sampled_paths.device + + px = k2.RaggedTensor(sampled_paths) + px = px.remove_values_eq(blank_id) + row_splits = px.shape.row_splits(1) + px_lens = row_splits[1:] - row_splits[:-1] + px = px.pad(mode="constant", padding_value=blank_id) + + boundary = torch.cat( + [ + torch.zeros((batch_size, 2), dtype=torch.int64, device=device), + px_lens.reshape(batch_size, 1), + y_lens.reshape(batch_size, 1), + ], + dim=1, + ) + + # wer : (batch_size, S, U) + wer = k2.levenshtein_distance( + px=px, py=y_padded.int(), boundary=boundary + ) + + wer = torch.gather( + wer, + 1, + boundary[:, 2] + .reshape(wer.size(0), 1, 1) + .expand(wer.size(0), 1, wer.size(2)), + ).squeeze(1) + + wer = torch.gather( + wer, 1, boundary[:, 3].reshape(batch_size, 1) + ).squeeze(1) + + # wer: (batch_size,) + return wer + + def get_init_contexts( + self, + px_grad: torch.Tensor, + py_grad: torch.Tensor, + 
y_padded: torch.Tensor, + ): + context_size = self.decoder.context_size + blank_id = self.decoder.blank_id + # Get contexts for each frame according to the gradients, just like we + # do for getting pruning bounds. + (B, S, T1) = px_grad.shape + T = py_grad.shape[-1] + # shape : (B, S, T) + tot_grad = px_grad[:, :, :T] + py_grad[:, :S, :] + # shape : (B, T) + best_idx = torch.argmax(tot_grad, dim=1) + # shape : (B, T, context_size) + state_idx = best_idx.reshape((B, T, 1)).expand( + (B, T, context_size) + ) + torch.arange(context_size, device=px_grad.device) + # shape : (B, context_size) + init_context = torch.tensor( + [blank_id], dtype=torch.int64, device=px_grad.device + ).expand(B, context_size) + # shape : (B, S + context_size) + sos_y_padded = torch.cat([init_context, y_padded], dim=1) + init_context = torch.gather( + sos_y_padded.unsqueeze(1).expand(B, T, S + context_size), + dim=2, + index=state_idx, + ) + return init_context + def delta_wer( self, encoder_out: torch.Tensor, + enhanced_encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, init_context: torch.Tensor, y_padded: torch.Tensor, y_lens: torch.Tensor, + num_pairs: int = 10, path_length: int = 20, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: encoder_out: @@ -164,6 +247,7 @@ def delta_wer( - The absolute value of pred_wer_diff, its shape is (batch_size,). """ batch_size, T, encoder_dim = encoder_out.shape + assert y_padded.size(0) == batch_size, (y_padded.shape, batch_size) device = encoder_out.device blank_id = self.decoder.blank_id @@ -171,69 +255,92 @@ def delta_wer( decoder_dim = self.decoder_dim vocab_size = self.decoder.vocab_size - max_frame = torch.max(encoder_out_lens).item() - min_frame = torch.min(encoder_out_lens).item() - - # t_index contains the frame ids we are sampling for each path. 
- # shape : (batch_size, 1) - t_index = torch.randint( - min_frame // 2, max_frame // 2 + 1, (batch_size, 1), device=device - ) + # t_index contains the frame ids we are sampling for each pair of paths. + # shape : (batch_size, num_pairs) + t_index = torch.arange( + 0, T + num_pairs, int(T / num_pairs), device=device + )[0:num_pairs] + t_index = t_index.reshape((1, num_pairs)).expand(batch_size, num_pairs) t_index = torch.remainder(t_index, encoder_out_lens.reshape(-1, 1)) - # we will sample two paths for each sequence - num_paths = 2 - # shape : (batch_size, num_paths) - t_index = t_index.expand(batch_size, num_paths) - # The max frame index for each path - # shape : (batch_size, num_paths) - t_index_max = encoder_out_lens.view(batch_size, 1).expand( - batch_size, num_paths + # shape : (batch_size, num_pairs) + t_index_max = encoder_out_lens.view(batch_size, 1, 1).expand( + batch_size, num_pairs, 2 ) # left_symbols contains the left contexts of decoder for each path - # shape : (batch_size, num_paths, context_size) + # shape : (batch_size, num_pairs, context_size) left_symbols = torch.gather( init_context, dim=1, index=t_index.unsqueeze(2).expand( - batch_size, num_paths, context_size + batch_size, num_pairs, context_size ), ) - left_symbols = left_symbols.view(batch_size * num_paths, context_size) - # It has a shape of (batch_size,) indicating whether having different + # we will sample two paths for each sequence + t_index = t_index.reshape((batch_size, num_pairs, 1)).expand( + batch_size, num_pairs, 2 + ) + left_symbols = left_symbols.reshape( + (batch_size, num_pairs, 1, context_size) + ).expand(batch_size, num_pairs, 2, context_size) + + left_symbols = left_symbols.reshape( + batch_size * num_pairs * 2, context_size + ) + + # It has a shape of (batch_size, num_pairs) indicating whether having different # paths for this sequence. 
- has_diff = torch.zeros((batch_size,), device=device).bool() - # It has a shape of (batch_size,) indicating whether reaching final + has_diff = torch.zeros((batch_size, num_pairs), device=device).bool() + # It has a shape of (batch_size, num_pairs) indicating whether reaching final # for this sequence - reach_final = torch.zeros((batch_size,), device=device).bool() + reach_final = torch.zeros((batch_size, num_pairs), device=device).bool() # The pred_wer, default zeros. If there is no different symbol for the # sampled paths, the pred_wers for the two paths are the same. - pred_output = torch.zeros((batch_size, num_paths), device=device) + pred_wer = torch.zeros((batch_size, num_pairs, 2), device=device) + + dummy_output = torch.zeros( + (batch_size, num_pairs, vocab_size), device=device + ) sampled_paths_list = [] + sampled_joiner_list = [] + sampled_quasi_joiner_list = [] while len(sampled_paths_list) < path_length: - # (B, num_paths, encoder_dim) + # (B, num_pairs, 2, encoder_dim) current_encoder_out = torch.gather( - encoder_out, + encoder_out.view(batch_size, T, 1, encoder_dim).expand( + batch_size, T, 2, encoder_dim + ), + dim=1, + index=t_index.unsqueeze(3).expand( + batch_size, num_pairs, 2, encoder_dim + ), + ) + + current_enhanced_encoder_out = torch.gather( + enhanced_encoder_out.view(batch_size, T, 1, encoder_dim).expand( + batch_size, T, 2, encoder_dim + ), dim=1, - index=t_index.unsqueeze(2).expand( - batch_size, num_paths, encoder_dim + index=t_index.unsqueeze(3).expand( + batch_size, num_pairs, 2, encoder_dim ), ) - # (B, num_paths, decoder_dim) + # (B, num_pairs, 2, decoder_dim) decoder_output = self.decoder(left_symbols, need_pad=False).view( - batch_size, num_paths, decoder_dim + batch_size, num_pairs, 2, decoder_dim ) - # joiner_output : (B, num_paths, V); - # wer_output : (B, num_paths, V) - joiner_output, wer_output = self.joiner( - current_encoder_out, decoder_output, extra_output=True + # joiner_output : (B, num_pairs, 2, V); + joiner_output 
= self.joiner(current_encoder_out, decoder_output) + # quasi_joiner_output : (B, num_pairs, 2, V) + quasi_joiner_output = self.quasi_joiner( + current_enhanced_encoder_out, decoder_output ) probs = torch.softmax(joiner_output, -1) @@ -241,142 +348,144 @@ def delta_wer( sampler = Categorical(probs=probs) # sample one symbol for each path - # index : (batch_size, num_paths) + # index : (batch_size, num_pairs, 2) index = sampler.sample() # The two paths have different symbols. - mask = index[:, 0] != index[:, 1] - # shape : (batch_size,), will only be True when the two paths have + mask = index[:, :, 0] != index[:, :, 1] + + # shape : (batch_size, num_pairs), will only be True when the two paths have # different symbols in the first time. meet_diff = mask & ~has_diff & ~reach_final has_diff |= mask + # wer_output: (B, num_pairs, 2) wer_output = torch.gather( - wer_output, dim=2, index=index.unsqueeze(2) - ).squeeze(2) + quasi_joiner_output, dim=3, index=index.unsqueeze(3) + ).squeeze(3) # we only get the pred_wer at the position where the two paths start # to have different symbols. - pred_output = torch.where( - meet_diff.reshape(batch_size, 1), wer_output, pred_output + pred_wer = torch.where( + meet_diff.reshape(batch_size, num_pairs, 1), + wer_output, + pred_wer, ) # update (t, s) for each path # index == 0 means the sampled symbol is blank + # t_mask : (B, num_pairs, 2) t_mask = index == 0 t_index = t_index + 1 # if reaching final, we will ignore the sampled symbols, just append # blank_id to the paths. 
index = torch.where( - reach_final.reshape(batch_size, 1).expand( - batch_size, num_paths + reach_final.reshape(batch_size, num_pairs, 1).expand( + batch_size, num_pairs, 2 ), blank_id, index, ) sampled_paths_list.append(index) + joiner_output = torch.where( + reach_final.reshape(batch_size, num_pairs, 1).expand( + batch_size, num_pairs, vocab_size + ), + dummy_output, + joiner_output[:, :, 0, :], + ) + sampled_joiner_list.append(joiner_output) + + quasi_joiner_output = torch.where( + reach_final.reshape(batch_size, num_pairs, 1).expand( + batch_size, num_pairs, vocab_size + ), + dummy_output, + quasi_joiner_output[:, :, 0, :], + ) + sampled_quasi_joiner_list.append(quasi_joiner_output) + final_mask = t_index >= t_index_max # Set reach_final to true when one of the paths reaching final. - reach_final |= final_mask[:, 0] | final_mask[:, 1] + reach_final = ( + reach_final | final_mask[:, :, 0] | final_mask[:, :, 1] + ) t_index.masked_fill_(final_mask, 0) left_symbols = left_symbols.view( - batch_size, num_paths, context_size + batch_size, num_pairs, 2, context_size ) current_symbols = torch.cat( [ left_symbols, - index.unsqueeze(2), + index.unsqueeze(3), ], - dim=2, + dim=3, ) # if the sampled symbol is blank, we only need to roll the history # symbols, if the sampled symbol is not blank, append the newly # sampled symbol. 
left_symbols = _roll_by_shifts( - current_symbols, t_mask.to(torch.int64) + current_symbols.view( + batch_size, num_pairs * 2, context_size + 1 + ), + t_mask.view(batch_size, num_pairs * 2).to(torch.int64), ) left_symbols = left_symbols[:, :, 1:] left_symbols = left_symbols.view( - batch_size * num_paths, context_size + batch_size * num_pairs * 2, context_size ) - # sampled_paths : (batch_size, num_paths, path_lengths) - sampled_paths = torch.stack(sampled_paths_list, dim=2).int() + # sampled_paths : (batch_size, num_pairs, 2, path_lengths) + sampled_paths = torch.stack(sampled_paths_list, dim=3).int() - px1 = k2.RaggedTensor(sampled_paths[:, 0, :]) - px1 = px1.remove_values_eq(blank_id) - row_splits = px1.shape.row_splits(1) - px1_lens = row_splits[1:] - row_splits[:-1] - px1 = px1.pad(mode="constant", padding_value=blank_id) + # sampled_joiner : (batch_size, num_pairs, path_lengths, vocab_size) + sampled_joiner = torch.stack(sampled_joiner_list, dim=2) + sampled_quasi_joiner = torch.stack(sampled_quasi_joiner_list, dim=2) - boundary = torch.cat( - [ - torch.zeros((batch_size, 2), dtype=torch.int64, device=device), - px1_lens.reshape(batch_size, 1), - y_lens.reshape(batch_size, 1), - ], - dim=1, + y_padded_expand = ( + y_padded.reshape(y_padded.size(0), 1, y_padded.size(1)) + .expand(y_padded.size(0), num_pairs, y_padded.size(1)) + .reshape(y_padded.size(0) * num_pairs, y_padded.size(1)) ) - - wer1 = k2.levenshtein_distance( - px=px1, py=y_padded.int(), boundary=boundary + y_expand_lens = ( + y_lens.reshape(y_lens.size(0), 1) + .expand(y_lens.size(0), num_pairs) + .reshape( + y_lens.size(0) * num_pairs, + ) ) - wer1 = torch.gather( - wer1, - 1, - boundary[:, 2] - .reshape(batch_size, 1, 1) - .expand(batch_size, 1, wer1.size(2)), - ).squeeze(1) - wer1 = torch.gather( - wer1, 1, boundary[:, 3].reshape(batch_size, 1) - ).squeeze(1) - - px2 = k2.RaggedTensor(sampled_paths[:, 1, :]) - px2 = px2.remove_values_eq(blank_id) - row_splits = px2.shape.row_splits(1) - 
px2_lens = row_splits[1:] - row_splits[:-1] - px2 = px2.pad(mode="constant", padding_value=blank_id) - boundary = torch.cat( - [ - torch.zeros((batch_size, 2), dtype=torch.int64, device=device), - px2_lens.reshape(batch_size, 1), - y_lens.reshape(batch_size, 1), - ], - dim=1, + wer1 = self.get_wer( + sampled_paths=sampled_paths[:, :, 0, :].view( + batch_size * num_pairs, path_length + ), + y_padded=y_padded_expand, + y_lens=y_expand_lens, + blank_id=blank_id, ) + wer1 = wer1.view(batch_size, num_pairs) - wer2 = k2.levenshtein_distance( - px=px2, py=y_padded.int(), boundary=boundary + wer2 = self.get_wer( + sampled_paths=sampled_paths[:, :, 1, :].view( + batch_size * num_pairs, path_length + ), + y_padded=y_padded_expand, + y_lens=y_expand_lens, + blank_id=blank_id, ) - wer2 = torch.gather( - wer2, - 1, - boundary[:, 2] - .reshape(batch_size, 1, 1) - .expand(batch_size, 1, wer2.size(2)), - ).squeeze(1) - wer2 = torch.gather( - wer2, 1, boundary[:, 3].reshape(batch_size, 1) - ).squeeze(1) + wer2 = wer2.view(batch_size, num_pairs) - delta_wer = torch.pow( - ((wer1 - wer2) - (pred_output[:, 0] - pred_output[:, 1])), 2 - ) + wer_diff = wer1 - wer2 + pred_wer_diff = pred_wer[:, :, 0] - pred_wer[:, :, 1] - return ( - delta_wer, - torch.abs(wer1 - wer2), - torch.abs(pred_output[:, 0] - pred_output[:, 1]), - ) + return (wer_diff, pred_wer_diff, sampled_joiner, sampled_quasi_joiner) def forward( self, @@ -388,6 +497,7 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, + num_pairs: int = 10, path_length: int = 20, warmup: float = 1.0, reduction: str = "sum", @@ -526,54 +636,46 @@ def forward( warmup=warmup, ) - l2_loss = ( - torch.sum(torch.pow(enhanced_embedding - encoder_out.detach(), 2)) - / vocab_size + l2_loss = torch.sum( + torch.pow(enhanced_embedding - encoder_out.detach(), 2) + ) / encoder_out.size(2) + + init_context = self.get_init_contexts( + px_grad=px_grad, py_grad=py_grad, y_padded=y_padded ) - if False: - # Get contexts for 
each frame according to the gradients, just like we - # do for getting pruning bounds. - (B, S, T1) = px_grad.shape - T = py_grad.shape[-1] - # shape : (B, S, T) - tot_grad = px_grad[:, :, :T] + py_grad[:, :S, :] - # shape : (B, T) - best_idx = torch.argmax(tot_grad, dim=1) - # shape : (B, T, context_size) - state_idx = best_idx.reshape((B, T, 1)).expand( - (B, T, context_size) - ) + torch.arange(context_size, device=px_grad.device) - # shape : (B, context_size) - init_context = torch.tensor( - [blank_id], dtype=torch.int64, device=px_grad.device - ).expand(B, context_size) - # shape : (B, S + context_size) - sos_y_padded = torch.cat([init_context, y_padded], dim=1) - init_context = torch.gather( - sos_y_padded.unsqueeze(1).expand(B, T, S + context_size), - dim=2, - index=state_idx, - ) + # wer_diff, pred_wer_diff : (B, num_pairs) + ( + wer_diff, + pred_wer_diff, + sampled_joiner, + sampled_quasi_joiner, + ) = self.delta_wer( + encoder_out=encoder_out, + enhanced_encoder_out=enhanced_embedding.detach(), + encoder_out_lens=x_lens, + init_context=init_context, + y_padded=y_padded, + y_lens=y_lens, + num_pairs=num_pairs, + path_length=path_length, + ) - delta_wer, wer_diff, pred_wer_diff = self.delta_wer( - encoder_out=encoder_out, - encoder_out_lens=x_lens, - init_context=init_context, - y_padded=y_padded, - y_lens=y_lens, - path_length=path_length, - ) + delta_wer = torch.pow(wer_diff - pred_wer_diff, 2) - return ( - simple_loss, - pruned_loss, - delta_wer, - wer_diff, - pred_wer_diff, - ) + delta_wer_loss = torch.sum(delta_wer) + + predictor_wer_loss = torch.sum( + sampled_joiner * sampled_quasi_joiner.detach() + ) / (num_pairs * path_length * sampled_joiner.size(-1)) - return (simple_loss, pruned_loss, l2_loss) + return ( + simple_loss, + pruned_loss, + delta_wer_loss, + l2_loss, + predictor_wer_loss, + ) class TransformerLM(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py 
b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py index 7b2ca90a15..e1510c4d54 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py @@ -370,6 +370,13 @@ def get_parser(): help="The scale applying to l2_loss of embedding and enhanced_embedding", ) + parser.add_argument( + "--predictor-loss-scale", + type=float, + default=0.1, + help="The scale applying to l2_loss of embedding and enhanced_embedding", + ) + add_model_arguments(parser) return parser @@ -441,7 +448,7 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 1000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -690,8 +697,13 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - # simple_loss, pruned_loss, delta_wer, wer_diff, pred_wer_diff = model( - simple_loss, pruned_loss, l2_loss = model( + ( + simple_loss, + pruned_loss, + delta_wer_loss, + l2_loss, + predictor_wer_loss, + ) = model( x=feature, x_lens=feature_lens, y=y, @@ -729,9 +741,10 @@ def compute_loss( simple_loss = simple_loss.sum() pruned_loss = pruned_loss.sum() - # delta_wer = delta_wer.sum() - # wer_diff = wer_diff.sum() - # pred_wer_diff = pred_wer_diff.sum() + + logging.info( + f"simple_loss : {simple_loss}, pruned_loss : {pruned_loss}, delta_wer_loss : {delta_wer_loss}, l2_loss : {l2_loss}, predictor_wer_loss : {predictor_wer_loss}" + ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid @@ -742,17 +755,18 @@ def compute_loss( if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) - # l2_loss_scale = ( - # 1.0 - # if warmup < 1.0 - # else (0.1 if warmup > 1.0 and warmup < 2.0 else params.l2_loss_scale)) - l2_loss_scale = params.l2_loss_scale + l2_loss_scale = 
0.0 if warmup < 1.0 else params.l2_loss_scale + delta_wer_scale = 0.0 if warmup < 1.0 else params.delta_wer_scale + predictor_loss_scale = ( + 0.0 if warmup < 2.0 else params.predictor_loss_scale + ) loss = ( params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + l2_loss_scale * l2_loss - # + params.delta_wer_scale * delta_wer + + delta_wer_scale * delta_wer_loss + + predictor_loss_scale * predictor_wer_loss ) assert loss.requires_grad == is_training @@ -782,9 +796,8 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() info["l2_loss"] = l2_loss.detach().cpu().item() - # info["delta_wer"] = delta_wer.detach().cpu().item() - # info["wer_diff"] = wer_diff.detach().cpu().item() - # info["pred_wer_diff"] = pred_wer_diff.detach().cpu().item() + info["delta_wer_loss"] = delta_wer_loss.detach().cpu().item() + info["predictor_wer_loss"] = predictor_wer_loss.detach().cpu().item() return loss, info @@ -1046,7 +1059,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = Eve(model.parameters(), lr=params.initial_lr) @@ -1221,6 +1234,7 @@ def main(): torch.set_num_threads(1) torch.set_num_interop_threads(1) +torch.autograd.set_detect_anomaly(True) if __name__ == "__main__": main() From d73b1e640d779206ec45aaea613d5b1c2ab8512f Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 16 Nov 2022 14:38:11 +0800 Subject: [PATCH 5/7] refine predictor loss; docs --- .../pruned_transducer_stateless_mbr/model.py | 170 +++++++++++++++--- .../pruned_transducer_stateless_mbr/train.py | 7 +- 2 files changed, 147 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py index 3cdf614b1a..207bfc1390 100644 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py @@ -112,6 +112,16 @@ def __init__( (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. + quasi_joiner: + It is another joiner to predict expected word error rate, its inputs + and output are the same as ``joiner``. + transformer_lm: + It is a transformer encoder that converts the texts into text + embeddings. + embedding_enhancer: + It is a transformer decoder (has self attention from acoustic + embedding and cross-attention from text embedding) that outputs an + enhanced embedding "knowing" both acoustics and texts. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) @@ -136,7 +146,26 @@ def get_wer( y_padded: torch.Tensor, y_lens: torch.Tensor, blank_id: int = 0, - ): + ) -> torch.Tensor: + """Get levenshtein distances between sampled_paths and transcripts + (y_padded). + + Args: + sampled_paths: + A torch Tensor has a shape of (batch_size, path_length), it contains + the paths sampled from transducer. + y_padded: + The padded (with blank_id) transcripts, its shape is the same as + sampled_paths. + y_lens: + The real lens of transcripts before padding, its shape is + (batch_size,). + blank_id: + The blank ID. + Return: + Return a Tensor with the shape of (batch_size,) containing the + levenshtein distance between sampled_paths and transcripts. 
+ """ batch_size, path_length = sampled_paths.shape assert y_padded.size(0) == batch_size, (y_padded.shape, batch_size) assert y_lens.size(0) == batch_size, (y_lens.shape, batch_size) @@ -163,6 +192,7 @@ def get_wer( px=px, py=y_padded.int(), boundary=boundary ) + # wer : (batch_size, U) wer = torch.gather( wer, 1, @@ -171,11 +201,11 @@ def get_wer( .expand(wer.size(0), 1, wer.size(2)), ).squeeze(1) + # wer : (batch_size,) wer = torch.gather( wer, 1, boundary[:, 3].reshape(batch_size, 1) ).squeeze(1) - # wer: (batch_size,) return wer def get_init_contexts( @@ -183,7 +213,21 @@ def get_init_contexts( px_grad: torch.Tensor, py_grad: torch.Tensor, y_padded: torch.Tensor, - ): + ) -> torch.Tensor: + """Get initial left contexts for each frame according to the gradients + of ``px`` and ``py``. + + Args: + px_grad: + The gradients of ``px``, returned by `rnnt_loss_smoothed`. + py_grad: + The gradients of ``py``, returned by `rnnt_loss_smoothed`. + y_padded: + The padded labels for each sequence. + Return: + Return a tensor with the shape of (N, T, context_size) containing the + left contexts for each frame. + """ context_size = self.decoder.context_size blank_id = self.decoder.blank_id # Get contexts for each frame according to the gradients, just like we @@ -222,10 +266,18 @@ def delta_wer( num_pairs: int = 10, path_length: int = 20, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ + """Sample paths from transducer joiner and calculate related variables + needed by delta_wer_loss and predictor_loss. + + TODO:(Wei Kang) Add more docs describing what this function actually + does. + Args: encoder_out: The output of the encoder whose shape is (batch_size, T, encoder_dim) + enhanced_encoder_out: + The enhanced_embedding "knowing" both acoustics and texts, it has + the same shape of ``encoder_out``. encoder_out_lens: A tensor of shape (batch_size,) containing the number of frames before padding. 
@@ -233,18 +285,23 @@ def delta_wer( A tensor of shape (batch_size, T, context_size) containing the initial history symbols for each frame. y_padded: - The transcripts whose shape is (batch_size, S). + The padded labels of each sequence whose shape is (batch_size, S). y_lens: A tensor of shape (batch_size,) containing the number of symbols before padding. + num_pairs: + The number of pairs of paths to sample for each sequence. path_length: The length of the sampled paths. Returns: - Return three tensors, - - The delta_wer, its shape is (batch_size,) - - The absolute value of wer_diff, its shape is (batch_size,) - - The absolute value of pred_wer_diff, its shape is (batch_size,). + Return four tensors, + - The levenshtein wer difference, its shape is (batch_size, num_pairs) + - The prediction wer difference, its shape is (batch_size, num_pairs) + - The sampled joiner output, its shape is + (batch_size, num_pairs, path_length, vocab_size) + - The sampled quasi joiner output(to predict delta_wer), its shape is + (batch_size, num_pairs, path_length, vocab_size) """ batch_size, T, encoder_dim = encoder_out.shape assert y_padded.size(0) == batch_size, (y_padded.shape, batch_size) @@ -257,13 +314,16 @@ def delta_wer( # t_index contains the frame ids we are sampling for each pair of paths. # shape : (batch_size, num_pairs) + # t_index starts from diverse positions of a sequence. t_index = torch.arange( 0, T + num_pairs, int(T / num_pairs), device=device )[0:num_pairs] t_index = t_index.reshape((1, num_pairs)).expand(batch_size, num_pairs) + # t_index (the start frame idx) can not be larger than sequence length. t_index = torch.remainder(t_index, encoder_out_lens.reshape(-1, 1)) - # The max frame index for each path + # The max frame index for each path, from which we can know whether we + # reaching final frame. 
# shape : (batch_size, num_pairs) t_index_max = encoder_out_lens.view(batch_size, 1, 1).expand( batch_size, num_pairs, 2 @@ -279,7 +339,8 @@ def delta_wer( ), ) - # we will sample two paths for each sequence + # we will sample two paths for each pair, the two paths start from the + # same frame idx. t_index = t_index.reshape((batch_size, num_pairs, 1)).expand( batch_size, num_pairs, 2 ) @@ -291,17 +352,19 @@ def delta_wer( batch_size * num_pairs * 2, context_size ) - # It has a shape of (batch_size, num_pairs) indicating whether having different - # paths for this sequence. + # It has a shape of (batch_size, num_pairs) indicating whether having + # different paths for this pair. has_diff = torch.zeros((batch_size, num_pairs), device=device).bool() - # It has a shape of (batch_size, num_pairs) indicating whether reaching final - # for this sequence + # It has a shape of (batch_size, num_pairs) indicating whether reaching + # final for this pair reach_final = torch.zeros((batch_size, num_pairs), device=device).bool() - # The pred_wer, default zeros. If there is no different symbol for the - # sampled paths, the pred_wers for the two paths are the same. + # The pred_wer, default zeros. If there is no different symbols for the + # sampled pair of paths, the pred_wers for the two paths are the same. pred_wer = torch.zeros((batch_size, num_pairs, 2), device=device) + # dummy_output is used to fill sampled_joiner and sampled_quasi_joiner + # at padding positions. dummy_output = torch.zeros( (batch_size, num_pairs, vocab_size), device=device ) @@ -354,14 +417,14 @@ def delta_wer( # The two paths have different symbols. mask = index[:, :, 0] != index[:, :, 1] - # shape : (batch_size, num_pairs), will only be True when the two paths have - # different symbols in the first time. + # shape : (batch_size, num_pairs), will only be True when the two + # paths have different symbols at the first time. 
meet_diff = mask & ~has_diff & ~reach_final has_diff |= mask # wer_output: (B, num_pairs, 2) - wer_output = torch.gather( + pred_wer_output = torch.gather( quasi_joiner_output, dim=3, index=index.unsqueeze(3) ).squeeze(3) @@ -369,7 +432,7 @@ def delta_wer( # to have different symbols. pred_wer = torch.where( meet_diff.reshape(batch_size, num_pairs, 1), - wer_output, + pred_wer_output, pred_wer, ) @@ -427,6 +490,7 @@ def delta_wer( ], dim=3, ) + # if the sampled symbol is blank, we only need to roll the history # symbols, if the sampled symbol is not blank, append the newly # sampled symbol. @@ -442,7 +506,7 @@ def delta_wer( batch_size * num_pairs * 2, context_size ) - # sampled_paths : (batch_size, num_pairs, 2, path_lengths) + # sampled_paths : (batch_size, num_pairs, 2, path_length) sampled_paths = torch.stack(sampled_paths_list, dim=3).int() # sampled_joiner : (batch_size, num_pairs, path_lengths, vocab_size) @@ -512,6 +576,10 @@ def forward( y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. + sos_id: + The id of start of sequence for transformer language model. + eos_id: + The id of end of sequence for transformer language model. prune_range: The prune range for rnnt loss, it means how many symbols(context) we are considering for each frame to compute the loss. @@ -521,6 +589,10 @@ def forward( lm_scale: The scale to smooth the loss with lm (output of predictor network) part + num_pairs: + The number of pairs of paths that used to train quasi_joiner. + path_length: + The length of sampling path. warmup: A value warmup >= 0 that determines which modules are active, values warmup > 1 "are fully warmed up" and all modules will be active. @@ -529,7 +601,7 @@ def forward( "none" to return the loss in a 1-D tensor for each utterance in the batch. Returns: - Return the transducer loss. + Return the transducer loss, delta_wer_loss, l2_loss and predictor_loss. 
Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -645,6 +717,7 @@ def forward( ) # wer_diff, pred_wer_diff : (B, num_pairs) + # sampled_joiner, sampled_quasi_joiner : (B, num_pairs, path_length, V) ( wer_diff, pred_wer_diff, @@ -661,9 +734,7 @@ def forward( path_length=path_length, ) - delta_wer = torch.pow(wer_diff - pred_wer_diff, 2) - - delta_wer_loss = torch.sum(delta_wer) + delta_wer_loss = torch.sum(torch.pow(wer_diff - pred_wer_diff, 2)) predictor_wer_loss = torch.sum( sampled_joiner * sampled_quasi_joiner.detach() @@ -742,6 +813,13 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + sos_id: + The id of start of sequence. + eos_id: + The id of end of sequence. warmup: A floating point value that gradually increases from 0 throughout training; when it is >= 1.0 we are "fully warmed up". It is used @@ -799,6 +877,21 @@ def __init__( dropout: float = 0.1, layer_dropout: float = 0.075, ): + """ + Args: + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + layer_dropout (float): layer-dropout rate. + """ super().__init__() self.encoder_pos = PositionalEncoding(d_model, dropout) decoder_layer = TransformerDecoderLayer( @@ -820,7 +913,30 @@ def forward( text_embedding_key_padding_mask: Optional[torch.Tensor] = None, mask_proportion: float = 0.25, warmup: float = 1.0, - ): + ) -> torch.Tensor: + """ + Args: + embedding: + The acoustic embedding produced by transducer encoder. + Shape: (N, T, encoder_dim) + text_embedding: + The text embedding with the shape of (S, N, E) + embedding_mask: + The mask for the embedding. + text_embedding_mask: + The mask for the text embedding. 
+ embedding_key_padding_mask: + The mask for the embedding keys per batch. + text_embedding_key_padding_mask: + The mask for the text embedding keys per batch. + mask_proportion: + The proportion used to mask embedding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + """ + N, T, C = embedding.shape mask = torch.randn((N, T, C), device=embedding.device) mask = mask > mask_proportion diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py index e1510c4d54..9ac91034a4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py @@ -360,21 +360,22 @@ def get_parser(): "--delta-wer-scale", type=float, default=0.1, - help="The scale applying to delta_wer when it adds to the loss", + help="The scale applying to delta_wer_loss when it adds to the loss", ) parser.add_argument( "--l2-loss-scale", type=float, default=0.1, - help="The scale applying to l2_loss of embedding and enhanced_embedding", + help="""The scale applying to l2_loss which predicts enhanced_embedding + to make it knows about acoustics and texts""", ) parser.add_argument( "--predictor-loss-scale", type=float, default=0.1, - help="The scale applying to l2_loss of embedding and enhanced_embedding", + help="The scale applying to predictor_wer_loss", ) add_model_arguments(parser) From cf9607d237203db9cc7b58929cf31faab2309618 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 8 Dec 2022 13:29:07 +0800 Subject: [PATCH 6/7] Add softmax to sampled_joiner output --- .../pruned_transducer_stateless_mbr/decode.py | 59 +++++++++-------- .../pruned_transducer_stateless_mbr/model.py | 40 ++++++------ .../pruned_transducer_stateless_mbr/train.py | 64 ++++++++----------- 3 files changed, 78 insertions(+), 85 deletions(-) diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py index 8431492e6d..0c1341ba58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/decode.py @@ -225,6 +225,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -250,7 +251,7 @@ def get_parser(): search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, + fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, ) @@ -260,8 +261,8 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. + Used only when --decoding_method is fast_beam_search_nbest_LG or + fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -270,8 +271,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -279,8 +280,8 @@ def get_parser(): type=int, default=64, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -375,9 +376,10 @@ def decode_one_batch( word_table: The word symbol table. 
decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -392,14 +394,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -414,7 +415,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -424,8 +428,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -523,8 +531,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += 
f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -668,6 +676,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -691,8 +700,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -812,7 +821,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py index 207bfc1390..14bee7b277 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py @@ -15,22 +15,13 @@ # limitations under the License. 
-from typing import Dict, List, Optional, Tuple +from typing import Optional, Tuple -import logging import k2 import torch import torch.nn as nn - from encoder_interface import EncoderInterface -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv1d, - ScaledEmbedding, - ScaledLinear, -) +from scaling import ScaledEmbedding, ScaledLinear from torch.distributions.categorical import Categorical from transformer import ( PositionalEncoding, @@ -41,7 +32,8 @@ decoder_padding_mask, generate_square_subsequent_mask, ) -from icefall.utils import add_sos, add_eos, make_pad_mask + +from icefall.utils import add_eos, add_sos, make_pad_mask def _roll_by_shifts( @@ -308,8 +300,8 @@ def delta_wer( device = encoder_out.device blank_id = self.decoder.blank_id - context_size = self.decoder.context_size decoder_dim = self.decoder_dim + context_size = self.decoder.context_size vocab_size = self.decoder.vocab_size # t_index contains the frame ids we are sampling for each pair of paths. @@ -624,8 +616,6 @@ def forward( y_lens = row_splits[1:] - row_splits[:-1] blank_id = self.decoder.blank_id - context_size = self.decoder.context_size - vocab_size = self.decoder.vocab_size sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. 
@@ -708,9 +698,14 @@ def forward( warmup=warmup, ) - l2_loss = torch.sum( - torch.pow(enhanced_embedding - encoder_out.detach(), 2) - ) / encoder_out.size(2) + enhanced_embedding = enhanced_embedding.masked_fill( + embedding_key_padding_mask.unsqueeze(2), 0 + ) + encoder_out2 = encoder_out.masked_fill( + embedding_key_padding_mask.unsqueeze(2), 0 + ).detach() + + l2_loss = torch.sum(torch.pow(enhanced_embedding - encoder_out2, 2)) init_context = self.get_init_contexts( px_grad=px_grad, py_grad=py_grad, y_padded=y_padded @@ -736,16 +731,17 @@ def forward( delta_wer_loss = torch.sum(torch.pow(wer_diff - pred_wer_diff, 2)) - predictor_wer_loss = torch.sum( - sampled_joiner * sampled_quasi_joiner.detach() - ) / (num_pairs * path_length * sampled_joiner.size(-1)) + sampled_joiner = sampled_joiner.softmax(dim=3) + predictor_loss = torch.sum( + (sampled_joiner * sampled_quasi_joiner.detach()).sum(dim=3) + ) return ( simple_loss, pruned_loss, delta_wer_loss, l2_loss, - predictor_wer_loss, + predictor_loss, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py index 9ac91034a4..4bd8054637 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/train.py @@ -22,38 +22,13 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless4/train.py \ +./pruned_transducer_stateless_mbr/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless_mbr/exp \ --full-libri 1 \ --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless4/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 550 - -# train a streaming model -./pruned_transducer_stateless4/train.py \ - --world-size 4 \ - 
--num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless4/exp \ - --full-libri 1 \ - --dynamic-chunk-training 1 \ - --causal-convolution 1 \ - --short-chunk-size 25 \ - --num-left-chunks 4 \ - --max-duration 300 - """ import argparse @@ -77,7 +52,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model import Transducer, TransformerLM, EmbeddingEnhancer +from model import EmbeddingEnhancer, Transducer, TransformerLM from optim import Eden, Eve from torch import Tensor from torch.cuda.amp import GradScaler @@ -145,7 +120,8 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--num-lm-layers", type=int, default=3, - help="The number of layers for transformer language model", + help="""The number of layers for transformer language model (to get the + enhanced_embedding).""", ) parser.add_argument( @@ -356,6 +332,13 @@ def get_parser(): help="The length of the sampled paths for MBR training.", ) + parser.add_argument( + "--num-path-pairs", + type=int, + default=10, + help="The number of pairs of paths use to calculate the delta_wer_loss.", + ) + parser.add_argument( "--delta-wer-scale", type=float, @@ -375,7 +358,7 @@ def get_parser(): "--predictor-loss-scale", type=float, default=0.1, - help="The scale applying to predictor_wer_loss", + help="The scale applying to predictor_loss", ) add_model_arguments(parser) @@ -703,7 +686,7 @@ def compute_loss( pruned_loss, delta_wer_loss, l2_loss, - predictor_wer_loss, + predictor_loss, ) = model( x=feature, x_lens=feature_lens, @@ -713,6 +696,7 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + num_pairs=params.num_path_pairs, path_length=params.path_length, warmup=warmup, reduction="none", @@ -743,10 +727,6 @@ def compute_loss( simple_loss = simple_loss.sum() pruned_loss = pruned_loss.sum() - logging.info( - f"simple_loss : {simple_loss}, pruned_loss : 
{pruned_loss}, delta_wer_loss : {delta_wer_loss}, l2_loss : {l2_loss}, predictor_wer_loss : {predictor_wer_loss}" - ) - # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid # overwhelming the simple_loss and causing it to diverge, @@ -756,10 +736,18 @@ def compute_loss( if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + l2_loss_scale = 0.0 if warmup < 1.0 else params.l2_loss_scale delta_wer_scale = 0.0 if warmup < 1.0 else params.delta_wer_scale + predictor_loss_scale = ( - 0.0 if warmup < 2.0 else params.predictor_loss_scale + 0.0 + if warmup < 1.0 + else ( + params.predictor_loss_scale * 0.1 + if warmup > 1.0 and warmup < 2.0 + else params.predictor_loss_scale + ) ) loss = ( @@ -767,7 +755,7 @@ def compute_loss( + pruned_loss_scale * pruned_loss + l2_loss_scale * l2_loss + delta_wer_scale * delta_wer_loss - + predictor_loss_scale * predictor_wer_loss + + predictor_loss_scale * predictor_loss ) assert loss.requires_grad == is_training @@ -798,7 +786,7 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() info["l2_loss"] = l2_loss.detach().cpu().item() info["delta_wer_loss"] = delta_wer_loss.detach().cpu().item() - info["predictor_wer_loss"] = predictor_wer_loss.detach().cpu().item() + info["predictor_loss"] = predictor_loss.detach().cpu().item() return loss, info From 9252276b4fef0771d3d1c8901e3b45fb9a50c112 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 8 Dec 2022 15:19:42 +0800 Subject: [PATCH 7/7] Add some comments --- .../ASR/pruned_transducer_stateless_mbr/model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py index 14bee7b277..6f90ad8f7e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_mbr/model.py @@ -261,8 
+261,13 @@ def delta_wer( """Sample paths from transducer joiner and calculate related variables needed by delta_wer_loss and predictor_loss. - TODO:(Wei Kang) Add more docs describing what this function actually - does. + The sampling process resembles frame-synchronous decoding. For each + sequence, `2 * num_pairs` paths are sampled. Paths do not span frame 0 + to frame T; instead, `num_pairs` start points are chosen (evenly + distributed between 0 and T) and two paths are sampled from each start + point. A path has at most `path_length` steps (some are shorter + because they reach the final frame). + Args: encoder_out: @@ -292,7 +297,7 @@ def delta_wer( - The prediction wer difference, its shape is (batch_size, num_pairs) - The sampled joiner output, its shape is (batch_size, num_pairs, path_length, vocab_size) - - The sampled quasi joiner output(to predict delta_wer), its shape is + - The sampled quasi joiner output (the expected wer), its shape is (batch_size, num_pairs, path_length, vocab_size) """ batch_size, T, encoder_dim = encoder_out.shape @@ -470,6 +475,9 @@ def delta_wer( reach_final | final_mask[:, :, 0] | final_mask[:, :, 1] ) + # When a path reaches the final frame, reset its sampling frame to 0. + # This is only so the code keeps running normally; the symbols sampled + # after that point are not used. t_index.masked_fill_(final_mask, 0) left_symbols = left_symbols.view(