diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index 1b6d0026ff..23e9b5c161 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -44,6 +44,7 @@ RUN pip install --no-cache-dir \ espnet_tts_frontend \ graphviz \ kaldi-decoder \ + kaldi_native_fbank \ kaldi_native_io \ kaldialign \ kaldifst \ @@ -61,6 +62,7 @@ RUN pip install --no-cache-dir \ piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \ pypinyin==0.50.0 \ pytest \ + rknn_toolkit2 \ sentencepiece>=0.1.96 \ six \ tensorboard \ diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ed0e62330c..6f2e3d0365 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -102,6 +102,7 @@ jobs: cd ../transducer_lstm pytest -v -s + pip install kaldi_native_fbank rknn_toolkit2 cd ../zipformer pytest -v -s diff --git a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py index 4a7e2b1c10..aedd37765d 120000 --- a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py +++ b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py @@ -1 +1 @@ -/exp/ahussein/tmp/icefall/egs/iwslt22_ta/ST/local/prepare_transcripts.py \ No newline at end of file +../../ST/local/prepare_transcripts.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/prepare_transcripts.py b/egs/iwslt22_ta/ST/local/prepare_transcripts.py index c4e1398299..7805fc489d 100755 --- a/egs/iwslt22_ta/ST/local/prepare_transcripts.py +++ b/egs/iwslt22_ta/ST/local/prepare_transcripts.py @@ -1,66 +1,66 @@ -# Copyright 2023 Johns Hopkins University (Amir Hussein) - -#!/usr/bin/python -""" -This script prepares transcript_words.txt from cutset -""" - -from lhotse import CutSet -import argparse -import logging -import pdb -from pathlib import Path -import os - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--cut", - type=str, - default="", - help="Cutset file", - ) - parser.add_argument( - "--src-langdir", - type=str, - default="", - help="name of the source lang-dir", - ) - parser.add_argument( - "--tgt-langdir", - type=str, - default=None, - help="name of the target lang-dir", - ) - return parser - - -def main(): - - parser = get_parser() - args = parser.parse_args() - - logging.info("Reading the cuts") - cuts = CutSet.from_file(args.cut) - if args.tgt_langdir != None: - logging.info("Target dir is not None") - langdirs = [Path(args.src_langdir), Path(args.tgt_langdir)] - else: - langdirs = [Path(args.src_langdir)] - - for langdir in langdirs: - if not os.path.exists(langdir): - os.makedirs(langdir) - - with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt: - for c in cuts: - src_txt = c.supervisions[0].text - tgt_txt = c.supervisions[0].custom['translated_text']['eng'] - src.write(src_txt + '\n') - tgt.write(tgt_txt + '\n') - -if __name__ == "__main__": - main() \ No newline at end of file +# Copyright 2023 Johns Hopkins University (Amir Hussein) + +#!/usr/bin/python +""" +This script prepares transcript_words.txt from cutset +""" + +from lhotse import CutSet +import argparse +import logging +import pdb +from pathlib import Path +import os + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--cut", + type=str, + default="", + help="Cutset file", + ) + parser.add_argument( + "--src-langdir", + type=str, + default="", + help="name of the source lang-dir", + ) + parser.add_argument( + "--tgt-langdir", + type=str, + default=None, + help="name of the target lang-dir", + ) + return parser + + +def main(): + + parser = get_parser() + args = parser.parse_args() + + logging.info("Reading the cuts") + cuts = CutSet.from_file(args.cut) + if args.tgt_langdir != None: + logging.info("Target dir is not None") + langdirs = [Path(args.src_langdir), Path(args.tgt_langdir)] + else: + langdirs = [Path(args.src_langdir)] + + for langdir in langdirs: + if not os.path.exists(langdir): + os.makedirs(langdir) + + with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt: + for c in cuts: + src_txt = c.supervisions[0].text + tgt_txt = c.supervisions[0].custom['translated_text']['eng'] + src.write(src_txt + '\n') + tgt.write(tgt_txt + '\n') + +if __name__ == "__main__": + main() diff --git a/egs/iwslt22_ta/ST/zipformer/beam_search.py b/egs/iwslt22_ta/ST/zipformer/beam_search.py index 1eaa380497..b15a6dbf72 100644 --- a/egs/iwslt22_ta/ST/zipformer/beam_search.py +++ b/egs/iwslt22_ta/ST/zipformer/beam_search.py @@ -1033,7 +1033,7 @@ def modified_beam_search( nb_shift = logp_b - logits[..., 0] nb_shift = nb_shift.unsqueeze(-1) log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) log_probs.add_(ys_log_probs) else: log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) @@ -1203,7 +1203,7 @@ def modified_beam_search_hat( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - + # For blank symbol, log-prob is log-sigmoid of the score logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) # Additionally, to ensure the the probs of blank and non-blank sum to 1, we @@ -1211,9 +1211,9 @@ def modified_beam_search_hat( # is equivalent to log(1 - sigmoid(logits[..., 0])). breakpoint() nb_shift = logp_b - logits[..., 0] - + nb_shift = nb_shift.unsqueeze(-1) log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b, log_probs), dim=-1) + log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) diff --git a/egs/iwslt22_ta/ST/zipformer/profile.py b/egs/iwslt22_ta/ST/zipformer/profile.py index c93adbd143..70af99d1f1 120000 --- a/egs/iwslt22_ta/ST/zipformer/profile.py +++ b/egs/iwslt22_ta/ST/zipformer/profile.py @@ -1 +1 @@ -../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file +../../ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/train.py b/egs/iwslt22_ta/ST/zipformer/train.py index 4b661d93ef..e669a3b758 100755 --- a/egs/iwslt22_ta/ST/zipformer/train.py +++ b/egs/iwslt22_ta/ST/zipformer/train.py @@ -335,7 +335,7 @@ def get_parser(): files, e.g., checkpoints, log, etc, are saved """, ) - + parser.add_argument( "--bpe-tgt-model", type=str, @@ -792,9 +792,9 @@ def compute_loss( texts = batch["supervisions"]["text"] tgt_texts = batch["supervisions"]["tgt_text"]['eng'] - y = sp.encode(texts, out_type=int) + #y = sp.encode(texts, out_type=int) y_tgt = sp_tgt.encode(tgt_texts, out_type=int) - y = k2.RaggedTensor(y).to(device) + #y = k2.RaggedTensor(y).to(device) y_tgt = k2.RaggedTensor(y_tgt).to(device) with torch.set_grad_enabled(is_training): @@ -817,7 +817,7 @@ def compute_loss( f"simple_loss: {simple_loss}\n" f"pruned_loss: {pruned_loss}" ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) simple_loss = simple_loss[simple_loss_is_finite] pruned_loss = pruned_loss[pruned_loss_is_finite] @@ -985,7 +985,7 @@ def save_bad_model(suffix: str = ""): continue except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) raise if params.print_diagnostics and batch_idx == 5: @@ -1314,7 +1314,7 @@ def remove_short_and_long_text(c: Cut): def display_and_save_batch( batch: dict, params: AttributeDict, - sp: spm.SentencePieceProcessor, + #sp: spm.SentencePieceProcessor, sp_tgt: spm.SentencePieceProcessor, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1339,7 +1339,7 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") - y = sp.encode(supervisions["text"], out_type=int) + #y = sp.encode(supervisions["text"], out_type=int) y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) num_tokens = sum(len(i) for i in y_tgt) logging.info(f"num tokens: {num_tokens}") @@ -1380,7 +1380,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) + display_and_save_batch(batch, params=params, sp_tgt=sp_tgt) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" diff --git a/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py b/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py index 075065f49a..02e4a10d43 100755 --- a/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py +++ b/egs/multi_conv_zh_es_ta/ST/hent_srt/export.py @@ -168,7 +168,7 @@ ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/multi_conv_zh_es_ta/ST - + ./hent_srt/decode.py \ --epoch 9999 --avg 1 --use-averaged-model 0 \ --beam-size 20 \ @@ -208,7 +208,7 @@ ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/multi_conv_zh_es_ta/ST - + ./hent_srt/decode.py \ --epoch 9999 --avg 1 --use-averaged-model 0 \ --causal 1 \ @@ -240,7 +240,7 @@ --st-blank-penalty 2 \ --chunk-size 64 \ --left-context-frames 128 \ - --use-hat False --max-sym-per-frame 20 + --use-hat False --max-sym-per-frame 20 Note: If you don't want to train a model from scratch, we have provided one for you. You can get it at @@ -389,7 +389,7 @@ def forward( features: (N, T, C) feature_lengths: (N,) """ - encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = model.forward_encoder(feature, feature_lengths) + encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = self.model.forward_encoder(features, feature_lengths) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) st_encoder_out = st_encoder_out.permute(1, 0, 2) return encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens @@ -596,7 +596,7 @@ def main(): filename_start=filename_start, filename_end=filename_end, device=device, - + ), strict=False ) diff --git a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py index b0dc751582..34a9572e3f 100755 --- a/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py +++ b/egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py @@ -27,7 +27,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" -# For non-streaming model training: +# For non-streaming model training: ./zipformer_hat_st/train.py \ --base-lr 0.045 \ @@ -50,7 +50,7 @@ --warm-step 10000 \ --lr-epochs 6 \ --use-hat False - + # With Cr-CTC ./zipformer_hat_st/train.py \ --base-lr 0.045 \ @@ -989,7 +989,7 @@ def load_checkpoint_if_available( return None assert filename.is_file(), f"{filename} does not exist!" - + saved_params = load_checkpoint( filename, model=model, @@ -1094,7 +1094,7 @@ def compute_loss( spec_augment: The SpecAugment instance used only when use_cr_ctc is True. """ - + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) @@ -1116,13 +1116,13 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) if params.st_scale != 1: - alpha_st = params.st_scale - alpha_asr = 1-params.st_scale + alpha_st = params.st_scale + alpha_asr = 1-params.st_scale else: alpha_st, alpha_asr = 1, 1 use_asr_cr_ctc, use_st_cr_ctc = params.use_asr_cr_ctc, params.use_st_cr_ctc use_spec_aug = (use_asr_cr_ctc or use_st_cr_ctc) and is_training - + if use_spec_aug: supervision_intervals = batch["supervisions"] supervision_segments = torch.stack( @@ -1218,7 +1218,7 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_st_joiner: info["st_simple_loss"] = st_simple_loss.detach().cpu().item() - info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item() + info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.use_asr_cr_ctc: @@ -1573,7 +1573,7 @@ def run(rank, world_size, args): ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - + # if checkpoints and "optimizer" in checkpoints: # logging.info("Loading optimizer state dict") # optimizer.load_state_dict(checkpoints["optimizer"]) @@ -1608,7 +1608,7 @@ def remove_short_and_long_utt(c: Cut): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration =< 0.1 or c.duration >= 30.0: + if c.duration <= 0.1 or c.duration >= 30.0: # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) @@ -1646,7 +1646,7 @@ def remove_short_and_long_utt(c: Cut): # f"Number of tokens: {len(st_tokens)}" # ) return False - + if params.use_asr_cr_ctc: T = ((c.num_frames - 7) // 2 + 1) // 2 tokens = sp.encode(c.supervisions[0].text, out_type=str) diff --git a/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py b/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py deleted file mode 120000 index 8c203406b8..0000000000 --- a/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py \ No newline at end of file diff --git a/icefall/utils.py b/icefall/utils.py index 0d4e24db53..8612e2e774 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -648,7 +648,7 @@ def store_translations( hyp_list = [] ref_list = [] dir_ = os.path.dirname(filename) - reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) + reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename))) refsrc = os.path.join(dir_, "refsrc-"+str(os.path.basename(filename))) hyp = os.path.join(dir_, "hyp-"+str( os.path.basename(filename))) bleu_file = os.path.join(dir_, "bleu-"+str( os.path.basename(filename))) @@ -661,7 +661,7 @@ def store_translations( print(f"{cut_id}: ref_tgt {ref_tgt}", file=f) print(f"{cut_id}: hyp {hyp}", file=f) print("\n", file=f) - + print(f"{ref}", file=f_src) print(f"{ref_tgt}", file=f_tgt) @@ -673,7 +673,7 @@ def store_translations( with open(bleu_file, 'w') as b: print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b) print(f"BLEU signiture: {str(bleu.get_signature())}", file=b) - + logging.info( f"[{bleu.corpus_score(hyp_list, [ref_list])}] " f"BLEU signiture: {str(bleu.get_signature())}"