Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/scripts/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ RUN pip install --no-cache-dir \
espnet_tts_frontend \
graphviz \
kaldi-decoder \
kaldi_native_fbank \
kaldi_native_io \
kaldialign \
kaldifst \
Expand All @@ -61,6 +62,7 @@ RUN pip install --no-cache-dir \
piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
pypinyin==0.50.0 \
pytest \
rknn_toolkit2 \
sentencepiece>=0.1.96 \
six \
tensorboard \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ jobs:
cd ../transducer_lstm
pytest -v -s

pip install kaldi_native_fbank rknn_toolkit2
cd ../zipformer
pytest -v -s

Expand Down
2 changes: 1 addition & 1 deletion egs/iwslt22_ta/ASR/local/prepare_transcripts.py
132 changes: 66 additions & 66 deletions egs/iwslt22_ta/ST/local/prepare_transcripts.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,66 @@
# Copyright 2023 Johns Hopkins University (Amir Hussein)
#!/usr/bin/python
"""
This script prepares transcript_words.txt from cutset
"""
from lhotse import CutSet
import argparse
import logging
import pdb
from pathlib import Path
import os
def get_parser():
    """Build the command-line parser for the transcript preparation script.

    The parser accepts three string options:

      --cut           Cutset file (default: empty string).
      --src-langdir   Name of the source lang-dir (default: empty string).
      --tgt-langdir   Name of the target lang-dir (default: None, i.e. the
                      option was not supplied).

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--cut",
        type=str,
        default="",
        help="Cutset file",
    )

    parser.add_argument(
        "--src-langdir",
        type=str,
        default="",
        help="name of the source lang-dir",
    )

    # Left as None when omitted so callers can tell "no target requested"
    # apart from an empty path.
    parser.add_argument(
        "--tgt-langdir",
        type=str,
        default=None,
        help="name of the target lang-dir",
    )

    return parser
def main():
    """Write ``transcript_words.txt`` from the supervisions of a cutset.

    Reads the cutset given by ``--cut`` and writes one line per cut, taken
    from the cut's first supervision:

      * ``<src-langdir>/transcript_words.txt`` receives the supervision text.
      * ``<tgt-langdir>/transcript_words.txt`` receives
        ``custom['translated_text']['eng']`` and is only written when
        ``--tgt-langdir`` is given.

    Missing lang-dirs are created. Files are written as UTF-8 so the output
    does not depend on the platform's default locale encoding.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.info("Reading the cuts")
    cuts = CutSet.from_file(args.cut)

    src_langdir = Path(args.src_langdir)
    tgt_langdir = None
    if args.tgt_langdir is not None:
        logging.info("Target dir is not None")
        tgt_langdir = Path(args.tgt_langdir)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    src_langdir.mkdir(parents=True, exist_ok=True)
    if tgt_langdir is not None:
        tgt_langdir.mkdir(parents=True, exist_ok=True)

    src_path = src_langdir / "transcript_words.txt"

    # Source-only mode: there is no target directory to index into, so only
    # the source transcript is produced.
    if tgt_langdir is None:
        with open(src_path, "w", encoding="utf-8") as src:
            for c in cuts:
                src.write(c.supervisions[0].text + "\n")
        return

    tgt_path = tgt_langdir / "transcript_words.txt"
    with open(src_path, "w", encoding="utf-8") as src, open(
        tgt_path, "w", encoding="utf-8"
    ) as tgt:
        for c in cuts:
            supervision = c.supervisions[0]
            src_txt = supervision.text
            # NOTE(review): presumably the English translation of the
            # utterance -- confirm against the manifest schema.
            tgt_txt = supervision.custom["translated_text"]["eng"]
            src.write(src_txt + "\n")
            tgt.write(tgt_txt + "\n")
if __name__ == "__main__":
main()
# Copyright 2023 Johns Hopkins University (Amir Hussein)

#!/usr/bin/python
Comment on lines +1 to +3

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Shebang should be at the beginning of the file.

The shebang (#!/usr/bin/python) on line 3 should come before the copyright comment on line 1 to be recognized by the shell.

Proposed fix
+#!/usr/bin/python
 # Copyright 2023 Johns Hopkins University  (Amir Hussein)
-
-#!/usr/bin/python
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Copyright 2023 Johns Hopkins University (Amir Hussein)
#!/usr/bin/python
#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
🧰 Tools
🪛 Ruff (0.14.13)

3-3: Shebang should be at the beginning of the file

(EXE005)

🤖 Prompt for AI Agents
In `@egs/iwslt22_ta/ST/local/prepare_transcripts.py` around lines 1 - 3, Move the
shebang line (#!/usr/bin/python) to the very top of the file with no preceding
bytes or comments so the shell can recognize it; place the copyright/comment
lines (e.g., "Copyright 2023 Johns Hopkins University  (Amir Hussein)")
immediately after the shebang and ensure there are no blank lines or other
characters before the shebang.

"""
This script prepares transcript_words.txt from cutset
"""

from lhotse import CutSet
import argparse
import logging
import pdb
from pathlib import Path
import os


def get_parser():
    """Create the argument parser used by this script.

    Options (all parsed as strings):

      --cut           Cutset file; defaults to the empty string.
      --src-langdir   Name of the source lang-dir; defaults to the empty
                      string.
      --tgt-langdir   Name of the target lang-dir; defaults to None, which
                      means the option was not given.

    Returns:
      An ``argparse.ArgumentParser`` whose help output also shows defaults.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--cut",
        type=str,
        default="",
        help="Cutset file",
    )

    parser.add_argument(
        "--src-langdir",
        type=str,
        default="",
        help="name of the source lang-dir",
    )

    # None (rather than "") signals that no target lang-dir was requested.
    parser.add_argument(
        "--tgt-langdir",
        type=str,
        default=None,
        help="name of the target lang-dir",
    )

    return parser


def main():
    """Dump per-cut transcripts from a cutset into ``transcript_words.txt``.

    For every cut in the cutset given by ``--cut``, the first supervision
    supplies one output line:

      * its ``text`` goes to ``<src-langdir>/transcript_words.txt``;
      * its ``custom['translated_text']['eng']`` goes to
        ``<tgt-langdir>/transcript_words.txt``, but only when
        ``--tgt-langdir`` was supplied.

    Lang-dirs that do not exist yet are created. Output is always UTF-8,
    independent of the platform's default locale encoding.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.info("Reading the cuts")
    cuts = CutSet.from_file(args.cut)

    src_langdir = Path(args.src_langdir)
    tgt_langdir = None
    if args.tgt_langdir is not None:
        logging.info("Target dir is not None")
        tgt_langdir = Path(args.tgt_langdir)

    # mkdir(exist_ok=True) replaces the racy exists()-then-makedirs() pair.
    src_langdir.mkdir(parents=True, exist_ok=True)
    if tgt_langdir is not None:
        tgt_langdir.mkdir(parents=True, exist_ok=True)

    src_path = src_langdir / "transcript_words.txt"

    # Without a target lang-dir only the source transcript can be written;
    # indexing a second directory here would raise IndexError.
    if tgt_langdir is None:
        with open(src_path, "w", encoding="utf-8") as src:
            for c in cuts:
                src.write(c.supervisions[0].text + "\n")
        return

    tgt_path = tgt_langdir / "transcript_words.txt"
    with open(src_path, "w", encoding="utf-8") as src, open(
        tgt_path, "w", encoding="utf-8"
    ) as tgt:
        for c in cuts:
            supervision = c.supervisions[0]
            src_txt = supervision.text
            # NOTE(review): presumably the English translation of the
            # utterance -- confirm against the manifest schema.
            tgt_txt = supervision.custom["translated_text"]["eng"]
            src.write(src_txt + "\n")
            tgt.write(tgt_txt + "\n")
Comment on lines +51 to +63

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

IndexError when tgt_langdir is None.

When args.tgt_langdir is None, langdirs contains only one element (line 52), but line 58 unconditionally accesses langdirs[1], causing an IndexError.

Proposed fix: handle single langdir case
-    with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt:
-        for c in cuts:
-            src_txt = c.supervisions[0].text
-            tgt_txt = c.supervisions[0].custom['translated_text']['eng']
-            src.write(src_txt + '\n')
-            tgt.write(tgt_txt + '\n')
+    if args.tgt_langdir is not None:
+        with open(langdirs[0] / "transcript_words.txt", 'w') as src, \
+             open(langdirs[1] / "transcript_words.txt", 'w') as tgt:
+            for c in cuts:
+                src_txt = c.supervisions[0].text
+                tgt_txt = c.supervisions[0].custom['translated_text']['eng']
+                src.write(src_txt + '\n')
+                tgt.write(tgt_txt + '\n')
+    else:
+        with open(langdirs[0] / "transcript_words.txt", 'w') as src:
+            for c in cuts:
+                src_txt = c.supervisions[0].text
+                src.write(src_txt + '\n')
🤖 Prompt for AI Agents
In `@egs/iwslt22_ta/ST/local/prepare_transcripts.py` around lines 51 - 63, The
code builds langdirs from args.src_langdir and optionally args.tgt_langdir but
then always accesses langdirs[1], causing an IndexError when args.tgt_langdir is
None; update the block that opens transcript files to handle the single-langdir
case: if len(langdirs) == 1, open only langdirs[0] / "transcript_words.txt" and
write only src_txt for each cut (using c.supervisions[0].text), otherwise open
both langdirs[0] and langdirs[1] and write src_txt and tgt_txt
(c.supervisions[0].custom['translated_text']['eng']) respectively. Ensure you
reference the same variables (langdirs, cuts, src_txt, tgt_txt,
args.tgt_langdir) so the change integrates cleanly.


if __name__ == "__main__":
main()
8 changes: 4 additions & 4 deletions egs/iwslt22_ta/ST/zipformer/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def modified_beam_search(
nb_shift = logp_b - logits[..., 0]
nb_shift = nb_shift.unsqueeze(-1)
log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1)
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
log_probs.add_(ys_log_probs)
else:
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
Expand Down Expand Up @@ -1203,17 +1203,17 @@ def modified_beam_search_hat(

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)


# For blank symbol, log-prob is log-sigmoid of the score
logp_b = torch.nn.functional.logsigmoid(logits[..., 0])
# Additionally, to ensure that the probs of blank and non-blank sum to 1, we
# need to add the following term to the log-probs of non-blank symbols. This
# is equivalent to log(1 - sigmoid(logits[..., 0])).
breakpoint()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Remove debug breakpoint() before merging.

This breakpoint() will halt execution and drop into the debugger in production, breaking any decoding that uses modified_beam_search_hat.

🐛 Proposed fix
-        breakpoint()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
breakpoint()
🧰 Tools
🪛 Ruff (0.14.13)

1212-1212: Trace found: breakpoint used

(T100)

🤖 Prompt for AI Agents
In `@egs/iwslt22_ta/ST/zipformer/beam_search.py` at line 1212, Remove the stray
debug call `breakpoint()` inside the beam search implementation (the one found
in `modified_beam_search_hat`), delete that line so decoding won't drop into a
debugger in production, and scan the file for any other leftover `breakpoint()`
or debugging statements to remove; then run the unit/decoding tests or lint to
confirm no runtime interruptions remain.

nb_shift = logp_b - logits[..., 0]

nb_shift = nb_shift.unsqueeze(-1)
log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1)
log_probs = torch.cat((logp_b, log_probs), dim=-1)
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
log_probs.add_(ys_log_probs)

vocab_size = log_probs.size(-1)
Expand Down
2 changes: 1 addition & 1 deletion egs/iwslt22_ta/ST/zipformer/profile.py
16 changes: 8 additions & 8 deletions egs/iwslt22_ta/ST/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def get_parser():
files, e.g., checkpoints, log, etc, are saved
""",
)

parser.add_argument(
"--bpe-tgt-model",
type=str,
Expand Down Expand Up @@ -792,9 +792,9 @@ def compute_loss(

texts = batch["supervisions"]["text"]
tgt_texts = batch["supervisions"]["tgt_text"]['eng']
y = sp.encode(texts, out_type=int)
#y = sp.encode(texts, out_type=int)
y_tgt = sp_tgt.encode(tgt_texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
#y = k2.RaggedTensor(y).to(device)
y_tgt = k2.RaggedTensor(y_tgt).to(device)

with torch.set_grad_enabled(is_training):
Expand All @@ -817,7 +817,7 @@ def compute_loss(
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt)
display_and_save_batch(batch, params=params, sp_tgt=sp_tgt)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]

Expand Down Expand Up @@ -985,7 +985,7 @@ def save_bad_model(suffix: str = ""):
continue
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
display_and_save_batch(batch, params=params, sp_tgt=sp_tgt)
raise

if params.print_diagnostics and batch_idx == 5:
Expand Down Expand Up @@ -1314,7 +1314,7 @@ def remove_short_and_long_text(c: Cut):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
#sp: spm.SentencePieceProcessor,
sp_tgt: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Expand All @@ -1339,7 +1339,7 @@ def display_and_save_batch(

logging.info(f"features shape: {features.shape}")

y = sp.encode(supervisions["text"], out_type=int)
#y = sp.encode(supervisions["text"], out_type=int)
y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int)
num_tokens = sum(len(i) for i in y_tgt)
logging.info(f"num tokens: {num_tokens}")
Expand Down Expand Up @@ -1380,7 +1380,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt)
display_and_save_batch(batch, params=params, sp_tgt=sp_tgt)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
Expand Down
10 changes: 5 additions & 5 deletions egs/multi_conv_zh_es_ta/ST/hent_srt/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
ln -s pretrained.pt epoch-9999.pt

cd /path/to/egs/multi_conv_zh_es_ta/ST

./hent_srt/decode.py \
--epoch 9999 --avg 1 --use-averaged-model 0 \
--beam-size 20 \
Expand Down Expand Up @@ -208,7 +208,7 @@
ln -s pretrained.pt epoch-9999.pt

cd /path/to/egs/multi_conv_zh_es_ta/ST

./hent_srt/decode.py \
--epoch 9999 --avg 1 --use-averaged-model 0 \
--causal 1 \
Expand Down Expand Up @@ -240,7 +240,7 @@
--st-blank-penalty 2 \
--chunk-size 64 \
--left-context-frames 128 \
--use-hat False --max-sym-per-frame 20
--use-hat False --max-sym-per-frame 20

Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
Expand Down Expand Up @@ -389,7 +389,7 @@ def forward(
features: (N, T, C)
feature_lengths: (N,)
"""
encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = model.forward_encoder(feature, feature_lengths)
encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens = self.model.forward_encoder(features, feature_lengths)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
st_encoder_out = st_encoder_out.permute(1, 0, 2)
return encoder_out, encoder_out_lens, st_encoder_out, st_encoder_out_lens
Expand Down Expand Up @@ -596,7 +596,7 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,

), strict=False
)

Expand Down
22 changes: 11 additions & 11 deletions egs/multi_conv_zh_es_ta/ST/zipformer_multijoiner_st/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

export CUDA_VISIBLE_DEVICES="0,1,2,3"

# For non-streaming model training:
# For non-streaming model training:

./zipformer_hat_st/train.py \
--base-lr 0.045 \
Expand All @@ -50,7 +50,7 @@
--warm-step 10000 \
--lr-epochs 6 \
--use-hat False

# With Cr-CTC
./zipformer_hat_st/train.py \
--base-lr 0.045 \
Expand Down Expand Up @@ -989,7 +989,7 @@ def load_checkpoint_if_available(
return None

assert filename.is_file(), f"{filename} does not exist!"

saved_params = load_checkpoint(
filename,
model=model,
Expand Down Expand Up @@ -1094,7 +1094,7 @@ def compute_loss(
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
"""

device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
Expand All @@ -1116,13 +1116,13 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
if params.st_scale != 1:
alpha_st = params.st_scale
alpha_asr = 1-params.st_scale
alpha_st = params.st_scale
alpha_asr = 1-params.st_scale
else:
alpha_st, alpha_asr = 1, 1
use_asr_cr_ctc, use_st_cr_ctc = params.use_asr_cr_ctc, params.use_st_cr_ctc
use_spec_aug = (use_asr_cr_ctc or use_st_cr_ctc) and is_training

if use_spec_aug:
supervision_intervals = batch["supervisions"]
supervision_segments = torch.stack(
Expand Down Expand Up @@ -1218,7 +1218,7 @@ def compute_loss(
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_st_joiner:
info["st_simple_loss"] = st_simple_loss.detach().cpu().item()
info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item()
info["st_pruned_loss"] = st_pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.use_asr_cr_ctc:
Expand Down Expand Up @@ -1573,7 +1573,7 @@ def run(rank, world_size, args):
)

scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

# if checkpoints and "optimizer" in checkpoints:
# logging.info("Loading optimizer state dict")
# optimizer.load_state_dict(checkpoints["optimizer"])
Expand Down Expand Up @@ -1608,7 +1608,7 @@ def remove_short_and_long_utt(c: Cut):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration =< 0.1 or c.duration >= 30.0:
if c.duration <= 0.1 or c.duration >= 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
Expand Down Expand Up @@ -1646,7 +1646,7 @@ def remove_short_and_long_utt(c: Cut):
# f"Number of tokens: {len(st_tokens)}"
# )
return False

if params.use_asr_cr_ctc:
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
Expand Down
Loading
Loading