Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
Expand All @@ -39,6 +38,7 @@
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_recognition import K2SpeechRecognitionDataset
from torch.utils.data import DataLoader

from icefall.utils import str2bool
Expand Down Expand Up @@ -232,8 +232,11 @@ def train_dataloaders(
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

# We use probability 1.0 here so that musan augmentation is
# always performed
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=1.0, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
Expand Down
224 changes: 224 additions & 0 deletions egs/librispeech/ASR/zipformer/speech_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix


class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech recognition task using k2 library.

    This dataset expects to be queried with a :class:`CutSet` holding the cuts
    of one mini-batch, for which it loads features and automatically
    collates/batches them.

    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SimpleCutSampler` sampler.

    When ``cut_transforms`` is non-empty, the returned batch is made of the
    sampled cuts left untouched, followed by the result of running the
    transforms on those cuts repeated twice.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': {
                'sequence_idx': Tensor[int] of shape (S,)
                'text': List[str] of len S

                # For feature input strategies
                'start_frame': Tensor[int] of shape (S,)
                'num_frames': Tensor[int] of shape (S,)

                # For audio input strategies
                'start_sample': Tensor[int] of shape (S,)
                'num_samples': Tensor[int] of shape (S,)

                # Optionally, when return_cuts=True
                'cut': List[AnyCut] of len S

                # Optionally, when every supervision has a "word" alignment
                'word': List[List[str]] of len S (word symbols)
                'word_start': List[List] of len S (from compute_num_frames)
                'word_end': List[List] of len S (from compute_num_frames)
            }
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features

    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: Union[List[Callable[[CutSet], CutSet]], None] = None,
        input_transforms: Union[
            List[Callable[[torch.Tensor], torch.Tensor]], None
        ] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        """
        k2 ASR Dataset constructor.

        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
            ``None`` is treated as an empty list.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features. Each transform is called as
            ``tnfm(inputs, supervision_segments=segments)``.
            Examples: normalization, SpecAugment, etc.
            ``None`` is treated as an empty list.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
            NOTE: the default object is created once, when the class is defined,
            and is therefore shared by every dataset that relies on the default.
        """
        super().__init__()
        # Initialize the fields
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Collate the cuts of one mini-batch into a batch dict.

        Which cuts end up in ``cuts`` is decided by the caller; this method only
        validates them, optionally augments them, loads the inputs and gathers
        the supervision information.

        :param cuts: the cuts selected for this batch.
        :return: a dict with the keys ``"inputs"`` and ``"supervisions"``, laid out
            as described in the class docstring.
        :raises ValueError: when word alignments are present but the frame shift
            can be found neither in the cuts nor in the input strategy.
        """
        validate_for_asr(cuts)

        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        if self.cut_transforms:
            # Keep the sampled cuts as they are and run the transforms on two
            # extra copies, so the batch holds originals + transformed copies.
            # NOTE(review): the combined set is not sorted by duration again
            # after this point -- confirm that downstream code does not rely
            # on the ordering established above.
            orig_cuts = cuts

            cuts = cuts.repeat(times=2)

            for tnfm in self.cut_transforms:
                cuts = tnfm(cuts)

            cuts = orig_cuts + cuts

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    {
                        "text": supervision.text,
                    }
                    for cut in cuts
                    for supervision in cut.supervisions
                ]
            ),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            # One entry per supervision, so the list lines up with "text".
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for _ in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment
            for c in cuts
            for s in c.supervisions
        )
        if has_word_alignments:
            # TODO: might need to refactor BatchIO API to move the following conditional logic
            #       into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
            #       that returns either num_frames or num_samples depending on the strategy).
            words, starts, ends = [], [], []
            # The first cut's frame shift and sampling rate are used for the whole batch.
            frame_shift = cuts[0].frame_shift
            sampling_rate = cuts[0].sampling_rate
            if frame_shift is None:
                # Fall back to the feature extractor of the input strategy, if it has one.
                try:
                    frame_shift = self.input_strategy.extractor.frame_shift
                except AttributeError as err:
                    # Chain the original error so the missing attribute stays visible.
                    raise ValueError(
                        "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
                    ) from err
            for c in cuts:
                for s in c.supervisions:
                    # Look the alignment up once; it feeds all three lists.
                    word_alignment = s.alignment["word"]
                    words.append([aliword.symbol for aliword in word_alignment])
                    starts.append(
                        [
                            compute_num_frames(
                                aliword.start,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in word_alignment
                        ]
                    )
                    ends.append(
                        [
                            compute_num_frames(
                                aliword.end,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in word_alignment
                        ]
                    )
            batch["supervisions"]["word"] = words
            batch["supervisions"]["word_start"] = starts
            batch["supervisions"]["word_end"] = ends

        return batch


def validate_for_asr(cuts: CutSet, tol: float = 2e-3) -> None:
    """
    Check that ``cuts`` can be used for ASR.

    Runs the generic ``validate`` on the cut set and then asserts that every
    supervision lies inside its cut, allowing ``tol`` of slack on both sides.

    :param cuts: the cuts to check.
    :param tol: allowed overshoot of a supervision beyond the cut boundaries,
        in the same time unit as ``supervision.start`` and ``cut.duration``
        (the default, 2e-3, corresponds to 2 ms when those are in seconds).
    :raises AssertionError: when a supervision starts before its cut or ends
        after it by more than ``tol``. Being ``assert`` statements, these checks
        are skipped when Python runs with ``-O``.
    """
    validate(cuts)
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                "Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to Cut ...
            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
            #
            # 'supervision.end' is end of supervision inside the Cut
            assert supervision.end <= cut.duration + tol, (
                "Supervisions ending after the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )
Loading