From 905bf9983baa0f6fb7d4770f9acbb85180798b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 1 May 2026 16:20:57 -0700 Subject: [PATCH 01/30] First draft of indexed lhotse datasets integration + checkpointable dataloader MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- CLAUDE.md | 2 + examples/speechlm2/salm_train.py | 3 ++ nemo/collections/common/data/lhotse/cutset.py | 4 +- .../common/data/lhotse/dataloader.py | 48 +++++++++++-------- nemo/collections/speechlm2/data/datamodule.py | 31 +++++++++--- 5 files changed, 59 insertions(+), 29 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 7a1cd849f9e7..74e05265fc92 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -24,6 +24,8 @@ Requires Python 3.10+, PyTorch 2.6+. - Check: `python setup.py style --scope ` - Fix: `python setup.py style --scope --fix` - **Incremental reformatting**: most collections are excluded from black (see `extend-exclude` in pyproject.toml). The files are reformatted when somebody makes changes to avoid a single big reformatting PR. Do not reformat files outside your changes. +- **Helper placement**: keep public APIs and top-level classes/functions near the top of a file; place private + helpers and utilities at the bottom of the file unless a local module convention requires otherwise. 
## Testing diff --git a/examples/speechlm2/salm_train.py b/examples/speechlm2/salm_train.py index 4fddb61985de..05a013c69d86 100644 --- a/examples/speechlm2/salm_train.py +++ b/examples/speechlm2/salm_train.py @@ -51,6 +51,9 @@ def train(cfg): trainer.fit(model, datamodule) + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + if __name__ == "__main__": train() diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 8e613091b7e1..48e69001bccc 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -715,7 +715,9 @@ def read_lhotse_manifest(config) -> tuple[CutSet, bool]: else: # Regular Lhotse manifest points to individual audio files (like native NeMo manifest). path = config.cuts_path - cuts = CutSet.from_file(path).map(partial(resolve_relative_paths, manifest_path=path)) + cuts = CutSet.from_file(path, indexed=config.get("indexed", None)).map( + partial(resolve_relative_paths, manifest_path=path) + ) return cuts, is_tarred diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 5af5f5d004d7..f062978cc151 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -254,6 +254,16 @@ class LhotseDataLoadingConfig: # The first K examples will actually be read and then discarded, incurring the IO cost, due to # our support of object stores and gzipped files that generally don't have indexes of byte offsets per line. slice_length: Optional[int] = None + # Forwarded to ``CutSet.from_file(path, indexed=...)`` for plain JSONL ``cuts_path`` inputs. + # ``None`` = lhotse auto-detect (uses .idx if present, falls back to streaming). + # ``True`` = require indexed reads (errors if .idx is missing). + # ``False`` = streaming reads only. 
+ indexed: Optional[bool] = None + # When True, build the dataloader with ``torchdata.stateful_dataloader.StatefulDataLoader`` + # instead of ``torch.utils.data.DataLoader``. Combined with a checkpointable lhotse sampler + # (DynamicBucketingSampler / DynamicCutSampler), this enables exact resume from the next batch + # within the current epoch via the standard PyTorch state_dict / load_state_dict protocol. + use_stateful_dataloader: bool = False def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: @@ -265,6 +275,18 @@ def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfi return use_iterable_dataset +def _build_dataloader(use_stateful_dataloader: bool, **kwargs) -> torch.utils.data.DataLoader: + """ + Construct a DataLoader, optionally using ``torchdata.stateful_dataloader.StatefulDataLoader`` + so that resume picks up at the exact next batch via ``state_dict()`` / ``load_state_dict()``. + """ + if use_stateful_dataloader: + from torchdata.stateful_dataloader import StatefulDataLoader + + return StatefulDataLoader(**kwargs) + return torch.utils.data.DataLoader(**kwargs) + + def get_lhotse_dataloader_from_config( config: Union[dict, DictConfig], global_rank: int, @@ -369,7 +391,8 @@ def get_lhotse_dataloader_from_single_config( # reads only light-weight JSON objects; it samples mini-batches and passes # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method. 
dloader_kwargs = dict(dataset=dataset, sampler=sampler) - dloader = torch.utils.data.DataLoader( + dloader = _build_dataloader( + use_stateful_dataloader=config.use_stateful_dataloader, **dloader_kwargs, batch_size=None, num_workers=config.num_workers, @@ -420,6 +443,7 @@ def gather_shared_opts(): "multi_config", "metadata_only", "force_finite", + "use_stateful_dataloader", ] defaults = OmegaConf.structured(LhotseDataLoadingConfig) top_level_config["seed"] = resolve_seed(top_level_config["seed"]) @@ -493,7 +517,8 @@ def gather_shared_opts(): # reads only light-weight JSON objects; it samples mini-batches and passes # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method. dloader_kwargs = dict(dataset=dataset, sampler=sampler) - dloader = torch.utils.data.DataLoader( + dloader = _build_dataloader( + use_stateful_dataloader=shared_opts.use_stateful_dataloader, **dloader_kwargs, batch_size=None, num_workers=shared_opts.num_workers, @@ -519,9 +544,6 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # Resample as a safeguard; it's a no-op when SR is already OK cuts = cuts.map(partial(resample, sampling_rate=config.sample_rate), apply_fn=None) - # Expands cuts if multiple translations are provided. - cuts = CutSet(LazyFlattener(cuts.map(_flatten_alt_text, apply_fn=None))) - if config.use_multimodal_sampling: assert tokenizer is not None, ( "You must pass a tokenizer to `get_lhotse_dataloader_from_config` in order to" @@ -938,22 +960,6 @@ def _merge_supervisions(cuts: CutSet) -> CutSet: return cuts.merge_supervisions() -def _flatten_alt_text(cut) -> list: - ans = [cut] - if not isinstance(cut, Cut) or cut.custom is None or cut.custom.get("alt_text") is None: - return ans - cut = cut.move_to_memory(audio_format="wav") # performs I/O once and holds audio in memory from now on - # Popping to ease eyesight on debug. 
- paired_text = cut.custom.pop("alt_text") - for data in paired_text.values(): - # Copy to avoid lazy dataloading issues - data = data.copy() - text_instance = cut.map_supervisions(lambda s: fastcopy(s, text=data["text"], language=data["lang"])) - text_instance.custom = {"text": data.pop("text"), "lang": data.pop("lang"), **data} - ans.append(text_instance) - return ans - - def maybe_set_cuda_expandable_segments(enabled: bool): """ Configures PyTorch memory allocator to expand existing allocated segments diff --git a/nemo/collections/speechlm2/data/datamodule.py b/nemo/collections/speechlm2/data/datamodule.py index 0e95542e4ede..fd5364bdab05 100644 --- a/nemo/collections/speechlm2/data/datamodule.py +++ b/nemo/collections/speechlm2/data/datamodule.py @@ -68,17 +68,34 @@ def __init__(self, cfg, tokenizer: TokenizerSpec, dataset: torch.utils.data.Data getattr(self.cfg, k).force_map_dataset = True self.tokenizer = tokenizer self.dataset = dataset + self._train_dl = None def train_dataloader(self): if "train_ds" not in self.cfg: return None - return get_lhotse_dataloader_from_config( - config=self.cfg.train_ds, - global_rank=self._get_dp_rank(), - world_size=self._get_world_size(), - dataset=FallbackDataset(self.dataset), - tokenizer=self.tokenizer, - ) + if self._train_dl is None: + self._train_dl = get_lhotse_dataloader_from_config( + config=self.cfg.train_ds, + global_rank=self._get_dp_rank(), + world_size=self._get_world_size(), + dataset=FallbackDataset(self.dataset), + tokenizer=self.tokenizer, + ) + return self._train_dl + + def state_dict(self) -> dict: + # Persist the train dataloader state when it's stateful (e.g. torchdata's StatefulDataLoader + # paired with a checkpointable lhotse sampler). This enables exact-batch resume. 
+ if self._train_dl is not None and hasattr(self._train_dl, "state_dict"): + return {"train_dataloader": self._train_dl.state_dict()} + return {} + + def load_state_dict(self, state_dict: dict) -> None: + if "train_dataloader" not in state_dict: + return + dl = self.train_dataloader() + if dl is not None and hasattr(dl, "load_state_dict"): + dl.load_state_dict(state_dict["train_dataloader"]) def val_dataloader(self): if "validation_ds" not in self.cfg: From 48818f5375f6f8b4c0c8e78ee6d16025decab292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 5 May 2026 09:16:02 -0700 Subject: [PATCH 02/30] Support new Lhotse's indexed iterators across NeMo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 23 +- .../common/data/lhotse/indexed_adapters.py | 122 +++ .../common/data/lhotse/nemo_adapters.py | 694 +++++++++++++++--- .../common/data/lhotse/text_adapters.py | 617 +++++++++++++++- .../speechlm2/models/salm_automodel.py | 23 +- 5 files changed, 1379 insertions(+), 100 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 48e69001bccc..84b74804e4fd 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -285,6 +285,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: "force_map_dataset": config.get("force_map_dataset", False), "force_iterable_dataset": config.get("force_iterable_dataset", False), "slice_length": config.get("slice_length", None), + "indexed": config.get("indexed", False), # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa. 
"reweight_temperature": config.get("reweight_temperature", None), } @@ -348,6 +349,7 @@ def read_txt_jsonl_paths(config: DictConfig) -> tuple[CutSet, bool]: text_field=config.text_field, shuffle_shards=config.shuffle, shard_seed=config.shard_seed, + indexed=config.get("indexed", False), ) ) if not config.get("force_finite", False): @@ -384,6 +386,7 @@ def read_nemo_sft_jsonl(config: DictConfig) -> tuple[CutSet, bool]: language=config.get("language"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, + indexed=config.get("indexed", False), ) ) if not config.get("force_finite", False): @@ -405,6 +408,7 @@ def read_multimodal_conversation_jsonl(config: DictConfig) -> tuple[CutSet, bool system_prompt=config.get("tags", {}).get("system_prompt"), context=config.get("tags", {}).get("context"), slice_length=config.get("slice_length"), + indexed=config.get("indexed", False), ) ) if not config.get("force_finite", False): @@ -426,6 +430,7 @@ def read_share_gpt_as_conversation(config) -> tuple[CutSet, bool]: shuffle_shards=config.shuffle, shard_seed=config.shard_seed, slice_length=config.get("slice_length"), + indexed=config.get("indexed", False), ) ) if not config.get("force_finite", False): @@ -444,6 +449,7 @@ def read_share_gpt_webdataset_as_conversation(config) -> tuple[CutSet, bool]: token_equivalent_duration=config.get("token_equivalent_duration"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, + indexed=config.get("indexed", False), ) ) # When force_finite is False (default), repeat the dataset infinitely so that @@ -751,6 +757,7 @@ def read_parquet_manifest(config: DictConfig) -> tuple[CutSet, bool]: # Extract shuffling options (CRITICAL for distributed training) shuffle_shards = config.get("shuffle", False) shard_seed = config.get("shard_seed", "trng") + indexed = config.get("indexed", False) # 3. 
Create Iterators for each file iterators = [] @@ -763,6 +770,7 @@ def read_parquet_manifest(config: DictConfig) -> tuple[CutSet, bool]: duration_field=duration_field, lang_field=lang_field, sampling_rate=sampling_rate, + indexed=indexed, ) iterators.append(adapter) @@ -1461,6 +1469,8 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: common_kwargs["shuffle_shards"] = config[key] else: common_kwargs[key] = config[key] + indexed = config.get("indexed", False) + notar_kwargs_extra = {"indexed": indexed} if indexed else {} # The option below is to allow a special case of NeMo manifest iteration as Lhotse CutSet # without performing any I/O. NeMo manifests typically don't have sampling_rate information required by Lhotse, # so lhotse has to look up the headers of audio files to fill it on-the-fly. @@ -1470,6 +1480,7 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: metadata_only = config.get("metadata_only", False) force_finite = config.get("force_finite", False) notar_kwargs = {"metadata_only": metadata_only} + tar_kwargs_extra = {"indexed": indexed} if indexed else {} is_tarred = config.get("tarred_audio_filepaths") is not None if isinstance(config.manifest_filepath, (str, Path)): if is_tarred and not metadata_only: @@ -1479,13 +1490,18 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: tar_paths=config.tarred_audio_filepaths, skip_missing_manifest_entries=config.get("skip_missing_manifest_entries", False), slice_length=config.get("slice_length", None), + **tar_kwargs_extra, **common_kwargs, ) ) if not force_finite: cuts = cuts.repeat(preserve_id=True) else: - cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs)) + cuts = CutSet( + LazyNeMoIterator( + config.manifest_filepath, **notar_kwargs, **notar_kwargs_extra, **common_kwargs + ) + ) else: # Format option 1: # Assume it's [[path1], [path2], ...] (same for tarred_audio_filepaths). 
@@ -1519,10 +1535,13 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: tar_paths=tar_path, skip_missing_manifest_entries=config.get("skip_missing_manifest_entries", False), slice_length=config.get("slice_length", None), + **tar_kwargs_extra, **common_kwargs, ) else: - nemo_iter = LazyNeMoIterator(manifest_path, **notar_kwargs, **common_kwargs) + nemo_iter = LazyNeMoIterator( + manifest_path, **notar_kwargs, **notar_kwargs_extra, **common_kwargs + ) # Then, determine the weight or use one provided if isinstance(manifest_info, str) or len(manifest_info) == 1: weight = len(nemo_iter) diff --git a/nemo/collections/common/data/lhotse/indexed_adapters.py b/nemo/collections/common/data/lhotse/indexed_adapters.py index 831edf0b1f54..edb7e6e86400 100644 --- a/nemo/collections/common/data/lhotse/indexed_adapters.py +++ b/nemo/collections/common/data/lhotse/indexed_adapters.py @@ -245,6 +245,128 @@ def __getitem__(self, idx): return _split_json_audio_pair(name_a, bytes_a, name_b, bytes_b) +class IndexedTarMemberReader: + """ + Random access to a NeMo-style tar archive that stores **one regular member + per sample** (e.g. ``.flac`` per line of an external NeMo manifest). + + Uses the same ``.idx`` format as :class:`IndexedJSONLReader` and + :class:`IndexedTarSampleReader`: little-endian uint64 byte offsets, with + a sentinel equal to the tar file size at the end. Each entry points at + one tar header, and the corresponding payload starts ``512`` bytes later. + + Two access patterns: + + * Positional: ``reader[idx]`` returns ``(member_name, payload_bytes)``. + * Name-keyed: ``reader.get(name)`` returns just the payload bytes. The + name → position map is built lazily on first use by walking the tar + headers (no payload reads), then cached for subsequent calls. 
+ """ + + def __init__( + self, + tar_path: str | Path, + idx_path: str | Path | None = None, + auto_create_index: bool = True, + ): + self.data_path = str(tar_path) + resolved_idx = str(idx_path) if idx_path else self.data_path + ".idx" + if auto_create_index and not os.path.exists(resolved_idx): + create_tar_index(self.data_path, resolved_idx) + self.offsets, self._len = _load_index(self.data_path, resolved_idx) + self._fh = None + self._name_to_idx: dict[str, int] | None = None + + def _ensure_open(self): + if self._fh is None: + self._fh = open(self.data_path, "rb") + + def close(self): + if self._fh is not None: + self._fh.close() + self._fh = None + + def __del__(self): + self.close() + + def __getstate__(self): + s = self.__dict__.copy() + s["_fh"] = None # file handles are not picklable + return s + + def __setstate__(self, state): + self.__dict__.update(state) + + def __len__(self) -> int: + return self._len + + def __getitem__(self, idx: int) -> tuple[str, bytes]: + idx = _resolve_idx(idx, self._len) + offset = int(self.offsets[idx]) + self._ensure_open() + self._fh.seek(offset) + try: + name, data = _read_tar_member(self._fh) + except (EOFError, tarfile.TarError) as e: + raise type(e)( + f"{e} — reading sample {idx}/{self._len} at offset {offset} " + f"in {self.data_path}" + ) from e + return name, data + + def _build_name_index(self) -> dict[str, int]: + """Walk the tar headers once to build a name → sample-index map. + + Reads only the 512-byte tar headers (no payloads), so this is + relatively cheap even on remote storage. Done lazily on first + :meth:`get` call. + + ``tar.add`` writes a PAX extended header (``@PaxHeader``) before any + member with a long path or extended attributes. We skip those and + record the *regular* file's name at each indexed offset. 
+ """ + name_to_idx: dict[str, int] = {} + self._ensure_open() + for i in range(self._len): + offset = int(self.offsets[i]) + self._fh.seek(offset) + while True: + header = self._fh.read(512) + if len(header) < 512 or header == b"\0" * 512: + break + info = tarfile.TarInfo.frombuf( + header, tarfile.ENCODING, "surrogateescape" + ) + if info.type in (tarfile.REGTYPE, tarfile.AREGTYPE): + name_to_idx[info.name] = i + break + # Non-regular (PAX header, GNU long-name, etc.): + # skip its data + 512-byte padding and continue. + size_blocks = (info.size + 511) // 512 * 512 + self._fh.seek(size_blocks, 1) + return name_to_idx + + def get(self, name: str) -> bytes: + """Return the payload bytes of the tar member named ``name``.""" + if self._name_to_idx is None: + self._name_to_idx = self._build_name_index() + try: + idx = self._name_to_idx[name] + except KeyError as e: + raise KeyError( + f"Tar {self.data_path} has no member named '{name}'. " + f"The .idx may be stale or the manifest is referencing a " + f"different tar." + ) from e + _, data = self[idx] + return data + + def __contains__(self, name: str) -> bool: + if self._name_to_idx is None: + self._name_to_idx = self._build_name_index() + return name in self._name_to_idx + + def _read_tar_member(f): """Read the next regular-file tar member, skipping non-regular entries (PAX headers, GNU long-name headers, directory entries, etc.). diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index 69ca3d66c041..e506e8077324 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -13,6 +13,8 @@ # limitations under the License. 
"""Lhotse adapters for NeMo datasets including Parquet support.""" +import bisect +import json import os import random import re @@ -34,7 +36,13 @@ from lhotse.audio.backend import LibsndfileBackend from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed -from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator +from lhotse.lazy import ( + IteratorNode, + LazyIteratorChain, + LazyJsonlIterator, + attach_graph_origin, + normalize_graph_token, +) from lhotse.serialization import open_best from lhotse.utils import compute_num_samples, ifnone @@ -43,7 +51,7 @@ from nemo.utils.data_utils import is_datastore_path -class LazyNeMoIterator: +class LazyNeMoIterator(IteratorNode): """ ``LazyNeMoIterator`` reads a NeMo (non-tarred) JSON manifest and converts it on the fly to an ``Iterable[Cut]``. It's used to create a ``lhotse.CutSet``. @@ -85,6 +93,24 @@ class LazyNeMoIterator: ... "nemo_manifests/train.json", ... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}], ... )) + + Indexed mode (``indexed=True``) + ------------------------------- + + When the underlying manifest is uncompressed JSONL, set ``indexed=True`` to enable + O(1) random access and exact graph-token checkpointing through + :class:`lhotse.indexing.IndexedJsonlReader`. In indexed mode this iterator becomes + an indexed ``IteratorNode`` that can be combined with ``StatefulDataLoader`` for + bit-exact mid-epoch resume. + + Indexed mode requires: + + * the manifest path(s) to use ``.jsonl`` extension and be uncompressed; + * ``extra_fields`` to be unset (lookup-based fields are positional and cannot be + reproduced after a Feistel-permuted random access). + + Sharded indexed inputs are composed via :class:`lhotse.lazy.LazyIteratorChain`, + which picks a Feistel cross-shard permutation for true item-level shuffling. 
""" def __init__( @@ -96,62 +122,126 @@ def __init__( shuffle_shards: bool = False, shard_seed: int | Literal["randomized", "trng"] = "trng", extra_fields: list[dict[str, str]] | None = None, + indexed: bool = False, ) -> None: self.path = path self.shuffle_shards = shuffle_shards self.shard_seed = shard_seed - paths = expand_sharded_filepaths(path) - - if len(paths) == 1: - self.source = LazyJsonlIterator(paths[0]) - else: - self.source = LazyIteratorChain( - *(LazyJsonlIterator(p) for p in paths), shuffle_iters=self.shuffle_shards, seed=self.shard_seed - ) self.text_field = text_field self.lang_field = lang_field self.metadata_only = metadata_only self.extra_fields = extra_fields + self.indexed = indexed validate_extra_fields(self.extra_fields) + paths = expand_sharded_filepaths(path) + + if indexed: + if extra_fields: + raise ValueError( + "LazyNeMoIterator(indexed=True) does not support 'extra_fields' because " + "their values are positional/streaming and cannot be reconstructed under " + "graph-token random access." 
+ ) + seed = resolve_seed(shard_seed) if shard_seed not in (None, "trng", "randomized") else 0 + indexed_sources = [_LazyIndexedJsonlDictNode(p) for p in paths] + if len(indexed_sources) == 1: + self.source = indexed_sources[0] + else: + self.source = LazyIteratorChain( + *indexed_sources, shuffle_iters=shuffle_shards, seed=seed + ) + else: + if len(paths) == 1: + self.source = LazyJsonlIterator(paths[0]) + else: + self.source = LazyIteratorChain( + *(LazyJsonlIterator(p) for p in paths), + shuffle_iters=self.shuffle_shards, + seed=self.shard_seed, + ) + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed def __iter__(self) -> Generator[Cut, None, None]: seed = resolve_seed(self.shard_seed) # Propagate the random seed extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()] for data in self.source: + graph_token = getattr(data, "_graph_origin", None) if self.indexed else None # filter out entries with valid "_skipme" values. 
if data.get("_skipme", False): continue - audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path), force_cache=False) - duration = data.pop("duration") - offset = data.pop("offset", None) - cut = self._create_cut( - audio_path=audio_path, offset=offset, duration=duration, sampling_rate=data.pop("sampling_rate", None) - ) - # Note that start=0 and not start=offset because supervision's start if relative to the - # start of the cut; and cut.start is already set to offset - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=cut.duration, - channel=cut.channel, - text=data.get(self.text_field), - language=data.get(self.lang_field), - ) - ) - cut.custom = data + cut = self._build_cut_from_dict(data) for extra_field in extra_fields: extra_field.attach_to(cut) + if graph_token is not None: + attach_graph_origin(cut, graph_token) yield cut + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "LazyNeMoIterator only supports __getitem__ when constructed with indexed=True." + ) + token = normalize_graph_token(token) + data = self.source[token] + cut = self._build_cut_from_dict(data) + return attach_graph_origin(cut, token) + def __len__(self) -> int: return len(self.source) def __add__(self, other): return LazyIteratorChain(self, other) + def state_dict(self) -> dict: + if not self.indexed: + return {} + return {"source": self.source.state_dict()} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + if "source" in sd: + self.source.load_state_dict(sd["source"]) + + def _build_cut_from_dict(self, data: dict) -> Cut: + # Note: ``data`` may be reused across calls in indexed mode (the reader returns + # a fresh dict each time, but we still avoid mutating the inner object). 
+ data = dict(data) + audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path), force_cache=False) + duration = data.pop("duration") + offset = data.pop("offset", None) + cut = self._create_cut( + audio_path=audio_path, + offset=offset, + duration=duration, + sampling_rate=data.pop("sampling_rate", None), + ) + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + channel=cut.channel, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + cut.custom = data + return cut + def _create_cut( self, audio_path: str, @@ -210,7 +300,67 @@ def _create_recording( return Recording.from_file(audio_path) -class LazyNeMoTarredIterator: +class _GraphOriginDict(dict): + """``dict`` subclass that can carry runtime attributes (e.g. ``_graph_origin``).""" + + __slots__ = ("_graph_origin",) + + +class _LazyIndexedJsonlDictNode(IteratorNode): + """ + Internal helper: a graph-restorable indexed JSONL reader that yields raw dicts + (not Cuts). Built on top of :class:`lhotse.indexing.IndexedJsonlReader`. + + Used as the source iterator for :class:`LazyNeMoIterator` (and other adapters) + when ``indexed=True``. Yielded items carry ``_graph_origin`` set to their + integer line index, which allows downstream nodes (e.g. ``LazyIteratorChain``, + ``LazyShuffler``) to compose graph tokens for exact restore. 
+ """ + + is_checkpointable = True + is_indexed = True + has_constant_time_access = True + + def __init__(self, path: str | Path) -> None: + from lhotse.indexing import IndexedJsonlReader + + self.path = path + self._reader = IndexedJsonlReader(path) + self._position = 0 + self._restored = False + + def __getitem__(self, idx): + idx = int(normalize_graph_token(idx)) + item = _GraphOriginDict(self._reader[idx]) + return attach_graph_origin(item, idx) + + def __len__(self) -> int: + return len(self._reader) + + def __iter__(self): + start = self._position if self._restored else 0 + self._restored = False + n = len(self._reader) + for i in range(start, n): + self._position = i + 1 + item = _GraphOriginDict(self._reader[i]) + attach_graph_origin(item, i) + yield item + + def state_dict(self) -> dict: + return {"position": self._position} + + def load_state_dict(self, sd: dict) -> None: + self._position = sd["position"] + self._restored = True + + +# NeMo-tar indexed access is delegated to ``IndexedTarMemberReader`` from +# ``indexed_adapters`` — the same canonical .idx format (uint64 LE offsets + +# sentinel) used everywhere else in NeMo and lhotse for indexed access. + + +class LazyNeMoTarredIterator(IteratorNode): r""" ``LazyNeMoTarredIterator`` reads a NeMo tarred JSON manifest and converts it on the fly to an ``Iterable[Cut]``. It's used to create a ``lhotse.CutSet``. @@ -294,19 +444,27 @@ def __init__( skip_missing_manifest_entries: bool = False, extra_fields: list[dict[str, str]] | None = None, slice_length: int = None, + indexed: bool = False, ) -> None: self.skip_missing_manifest_entries = skip_missing_manifest_entries + self.indexed = indexed self.shard_id_to_manifest: dict[int, Iterable[dict]] self.paths = expand_sharded_filepaths(manifest_path) if len(self.paths) == 1: - logging.warning( - f"You are using Lhotse dataloading for tarred audio with a non-sharded manifest. " - f"This will incur significant memory overhead. 
To prevent this, please shard file " - f"'{self.paths[0]}' using 'scripts/speech_recognition/convert_to_tarred_audio_dataset.py' " - f"WITHOUT '--no_shard_manifest'" - ) + if not indexed: + logging.warning( + f"You are using Lhotse dataloading for tarred audio with a non-sharded manifest. " + f"This will incur significant memory overhead. To prevent this, please shard file " + f"'{self.paths[0]}' using 'scripts/speech_recognition/convert_to_tarred_audio_dataset.py' " + f"WITHOUT '--no_shard_manifest'" + ) self.source = LazyJsonlIterator(self.paths[0]) - self.shard_id_to_manifest = groupby("shard_id", self.source) + if indexed: + # In indexed mode we will not consume self.source for grouping — the per-shard + # IndexedJsonlReaders below take over, keyed by the position-derived shard_id 0. + self.shard_id_to_manifest = {0: self.source} + else: + self.shard_id_to_manifest = groupby("shard_id", self.source) else: json_pattern = re.compile(r"manifest[^/]*_(\d+)[^/]*\.json") shard_ids = [] @@ -342,6 +500,74 @@ def __init__( self._validate() self.use_ais_get_batch = os.environ.get("USE_AIS_GET_BATCH", "False").lower() == "true" + if indexed: + self._init_indexed() + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def _init_indexed(self) -> None: + """Build per-shard IndexedJsonlReaders + audio-tar index for indexed/random access.""" + from lhotse.indexing import IndexedJsonlReader + + from nemo.collections.common.data.lhotse.indexed_adapters import ( + IndexedTarMemberReader, + ) + + if self.extra_fields: + raise ValueError( + "LazyNeMoTarredIterator(indexed=True) does not support 'extra_fields' " + "because their values are positional and cannot be reproduced under " + "graph-token random access." 
+ ) + if self.slice_length is not None: + raise ValueError( + "LazyNeMoTarredIterator(indexed=True) does not support 'slice_length'." + ) + + # Order shards by their integer shard_id so that global indices are stable. + self._sorted_shard_ids = sorted(self.shard_id_to_tar_path.keys()) + self._cuts_readers: dict[int, IndexedJsonlReader] = {} + # In USE_AIS_GET_BATCH mode we never open the tar files locally — audio is + # fetched lazily via URL/file AudioSource by AudioSamples (typically batched). + self._tar_readers: dict[int, IndexedTarMemberReader] = {} + + # Map shard_id → manifest path (single or multi-file). + if len(self.paths) == 1: + shard_id_to_manifest_path = {sid: self.paths[0] for sid in self._sorted_shard_ids} + else: + json_pattern = re.compile(r"manifest[^/]*_(\d+)[^/]*\.json") + shard_id_to_manifest_path = {} + for p in self.paths: + m = json_pattern.search(p) + assert m is not None + shard_id_to_manifest_path[int(m.group(1))] = p + + cum = 0 + cum_lens = [0] + for sid in self._sorted_shard_ids: + jsonl_path = shard_id_to_manifest_path[sid] + tar_path = self.shard_id_to_tar_path[sid] + self._cuts_readers[sid] = IndexedJsonlReader(jsonl_path) + if not self.use_ais_get_batch: + self._tar_readers[sid] = IndexedTarMemberReader(tar_path) + cum += len(self._cuts_readers[sid]) + cum_lens.append(cum) + self._cum_lens = cum_lens + self._total_len = cum + self._position = 0 + self._restored = False + self._offset_pattern = re.compile(r'^(?P.+)(?P-sub\d+)(?P\.\w+)?$') + def to_shards(self) -> List["LazyNeMoTarredIterator"]: """Convert this iterator to a list of separate iterators for each shard.""" if len(self.paths) == 1: @@ -362,6 +588,13 @@ def to_shards(self) -> List["LazyNeMoTarredIterator"]: ] def _validate(self) -> None: + if self.indexed: + # Indexed mode keys shards by the tar path's shard_id and pairs them with + # the jsonl manifest of the same numeric id (see ``_init_indexed``); the + # streaming-time shard_id consistency check below would otherwise 
reject + # single-file inputs when the jsonl groups by a different shard_id field. + validate_extra_fields(self.extra_fields) + return shard_ids_tars = set(self.shard_id_to_tar_path) shard_ids_manifest = set(self.shard_id_to_manifest) assert shard_ids_tars == shard_ids_manifest, ( @@ -496,7 +729,184 @@ def _iter_sequential( f"Cannot locate JSON entry for tar file '{tar_info.name}'" ) from e + # ---------------------------------------------------------------------- indexed + def _resolve_global_idx(self, idx: int) -> tuple[int, int]: + if idx < 0: + idx += self._total_len + if idx < 0 or idx >= self._total_len: + raise IndexError( + f"index {idx} out of range for LazyNeMoTarredIterator with {self._total_len} cuts" + ) + shard_pos = bisect.bisect_right(self._cum_lens, idx) - 1 + sid = self._sorted_shard_ids[shard_pos] + return sid, idx - self._cum_lens[shard_pos] + + def _build_indexed_cut(self, data: dict, audio_bytes: bytes, manifest_path: str, tar_path: str) -> Cut | None: + """Decode a single (manifest_entry, audio_bytes) pair into a Cut, mirroring the streaming path.""" + if data.get("_skipme", False): + return None + try: + meta = soundfile.info(BytesIO(audio_bytes)) + except Exception: + logging.warning(f"Skipped corrupted audio member referenced by '{data.get('audio_filepath')}' in {tar_path=}.") + return None + recording = Recording( + id=str(data["audio_filepath"]), + sources=[ + AudioSource(type="memory", channels=list(range(meta.channels)), source=audio_bytes) + ], + sampling_rate=int(meta.samplerate), + num_samples=meta.frames, + duration=meta.duration, + ) + cut = make_cut_with_subset_inmemory_recording( + recording, offset=data.get("offset", 0.0), duration=data.get("duration") + ) + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + cut.custom = _to_custom_attr_dict(data) + cut.manifest_origin = 
manifest_path + cut.tar_origin = tar_path + return cut + + def _audio_member_name_from_entry(self, entry: dict) -> str: + af = entry["audio_filepath"] + m = self._offset_pattern.match(af) + if m is None: + return af + return m.group("stem") + ifnone(m.group("ext"), "") + + def _build_indexed_url_cut(self, data: dict, manifest_path: str, tar_path: str) -> Cut | None: + """ + AIS GetBatch counterpart of ``_build_indexed_cut``: produces a Cut backed + by a URL/file AudioSource (no audio bytes loaded), so that + ``AudioSamples(use_batch_loader=True)`` can fetch the entire minibatch in + a single AIS GetBatch request. Mirrors the streaming path in + ``_iter_batch_for_ais_get_batch``. + """ + if data.get("_skipme", False): + return None + duration = data.get("duration") + if duration is None: + logging.warning( + f"Skipping '{data.get('audio_filepath')}' - missing duration in manifest" + ) + return None + audio_filename = self._audio_member_name_from_entry(data) + audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}" + # Mirror the streaming path's convention: use type="url" since open_best() + # transparently handles both local paths and remote URLs (ais://, http(s)://, ...). + # AudioSamples' GetBatch loader inspects the URL scheme to dispatch to AIS. 
+ source_type = "url" if "://" in tar_path else "file" + offset = data.get("offset", 0.0) + sampling_rate = data.get("sampling_rate", 16000) + recording = Recording( + id=audio_filename, + sources=[AudioSource(type=source_type, channels=[0], source=audio_url)], + sampling_rate=sampling_rate, + num_samples=compute_num_samples(duration, sampling_rate), + duration=duration, + ) + cut = recording.to_cut() + if offset > 0: + cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) + cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + cut.custom = _to_custom_attr_dict(data) + cut.manifest_origin = manifest_path + cut.tar_origin = tar_path + return cut + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "LazyNeMoTarredIterator only supports __getitem__ when constructed with indexed=True." + ) + idx = int(normalize_graph_token(token)) + sid, local_idx = self._resolve_global_idx(idx) + data = self._cuts_readers[sid][local_idx] + manifest_path = self._cuts_readers[sid].path + tar_path = self.shard_id_to_tar_path[sid] + if self.use_ais_get_batch: + cut = self._build_indexed_url_cut(data, manifest_path, tar_path) + else: + member_name = self._audio_member_name_from_entry(data) + audio_bytes = self._tar_readers[sid].get(member_name) + cut = self._build_indexed_cut(data, audio_bytes, manifest_path, tar_path) + if cut is None: + raise RuntimeError( + f"Cut at global index {idx} (shard {sid}, local {local_idx}) is not decodable; " + f"cannot satisfy random-access __getitem__." 
+ ) + return attach_graph_origin(cut, idx) + + def __len__(self) -> int: + if self.indexed: + return self._total_len + return len(self.source) + + def state_dict(self) -> dict: + if not self.indexed: + return {} + return {"position": self._position, "epoch": self.epoch} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self.epoch = sd.get("epoch", 0) + self._restored = True + + def _iter_indexed(self) -> Generator[Cut, None, None]: + start = self._position if self._restored else 0 + self._restored = False + n = self._total_len + for i in range(start, n): + self._position = i + 1 + sid, local_idx = self._resolve_global_idx(i) + data = self._cuts_readers[sid][local_idx] + manifest_path = self._cuts_readers[sid].path + tar_path = self.shard_id_to_tar_path[sid] + if self.use_ais_get_batch: + cut = self._build_indexed_url_cut(data, manifest_path, tar_path) + else: + member_name = self._audio_member_name_from_entry(data) + try: + audio_bytes = self._tar_readers[sid].get(member_name) + except KeyError: + if self.skip_missing_manifest_entries: + continue + raise + cut = self._build_indexed_cut(data, audio_bytes, manifest_path, tar_path) + if cut is None: + continue + attach_graph_origin(cut, i) + yield cut + self.epoch += 1 + + # ---------------------------------------------------------------- streaming def __iter__(self) -> Generator[Cut, None, None]: + if self.indexed: + yield from self._iter_indexed() + return + shard_ids = self.shard_ids seed = self._get_seed() @@ -579,9 +989,6 @@ def basename(d: dict) -> str: self.epoch += 1 - def __len__(self) -> int: - return len(self.source) - def __add__(self, other): return LazyIteratorChain(self, other) @@ -737,7 +1144,7 @@ def _to_custom_attr_dict(d: dict, _excluded_fields: set[str] = {"duration", "aud return {k: v for k, v in d.items() if k not in _excluded_fields} -class LazyParquetIterator: +class LazyParquetIterator(IteratorNode): """ 
LazyParquetIterator reads a Parquet file (local or remote) and yields Lhotse Cut objects. It streams data using PyArrow's iter_batches to avoid loading the full file into memory. @@ -749,6 +1156,13 @@ class LazyParquetIterator: duration_field (str): Name of the column containing duration (default: "duration"). lang_field (str): Name of the column containing language (default: "lang"). sampling_rate (int): Fallback sampling rate if not found in metadata (default: 16000). + indexed (bool): When True, enable O(1) random access via row-group lookup + and graph-token checkpointing. Requires the parquet file to expose + row-group statistics (the default for files written by pyarrow/pandas). + + Indexed mode reads one row group at a time on demand and caches the most + recently used row group, so unshuffled or locality-friendly access patterns + avoid repeated decompression. """ def __init__( @@ -759,6 +1173,7 @@ def __init__( duration_field: str = "duration", lang_field: str = "lang", sampling_rate: int = 16000, + indexed: bool = False, ) -> None: # SAFETY CHECK: Ensure pyarrow is actually installed if not HAVE_PYARROW: @@ -772,8 +1187,153 @@ def __init__( self.duration_field = duration_field self.lang_field = lang_field self.sampling_rate = sampling_rate + self.indexed = indexed + self._row_group_offsets: list[int] | None = None + self._cached_row_group_idx: int | None = None + self._cached_row_group: list[dict] | None = None + self._position = 0 + self._restored = False + if indexed: + self._init_indexed() + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def _init_indexed(self) -> None: + try: + parquet_file = pq.ParquetFile(self.path) + except Exception as e: + raise RuntimeError(f"Failed to open Parquet file: {self.path}") from e + offsets = [0] + for i in 
range(parquet_file.num_row_groups): + offsets.append(offsets[-1] + parquet_file.metadata.row_group(i).num_rows) + self._row_group_offsets = offsets + self._num_row_groups = parquet_file.num_row_groups + self._total_rows = offsets[-1] + del parquet_file # close handle; reopened lazily in workers + + def _load_row_group(self, rg_idx: int) -> list[dict]: + if self._cached_row_group_idx == rg_idx and self._cached_row_group is not None: + return self._cached_row_group + parquet_file = pq.ParquetFile(self.path) + try: + df = parquet_file.read_row_group(rg_idx).to_pandas() + finally: + del parquet_file + rows = df.to_dict("records") + self._cached_row_group_idx = rg_idx + self._cached_row_group = rows + return rows + + def _resolve_row_group(self, idx: int) -> tuple[int, int]: + # Find row group containing global ``idx`` via simple linear/bisect lookup. + offsets = self._row_group_offsets + # Linear scan is fine because num_row_groups is typically small. + for rg_idx in range(self._num_row_groups): + if idx < offsets[rg_idx + 1]: + return rg_idx, idx - offsets[rg_idx] + raise IndexError(f"index {idx} out of range for parquet file with {self._total_rows} rows") + + def _build_cut_from_row(self, row: dict, fallback_idx: int) -> Cut | None: + audio_data = row.get(self.audio_field) + if isinstance(audio_data, dict) and 'bytes' in audio_data: + audio_bytes = audio_data['bytes'] + elif isinstance(audio_data, bytes): + audio_bytes = audio_data + else: + logging.warning( + f"Skipping row {fallback_idx}: Audio column '{self.audio_field}' format unrecognized." + ) + return None + + text = row.get(self.text_field, "") + language = row.get(self.lang_field, None) + row_id = str(row.get('id', f"{Path(self.path).stem}_{fallback_idx}")) + try: + recording = Recording.from_bytes(data=audio_bytes, recording_id=row_id) + except (RuntimeError, ValueError, TypeError) as e: + logging.warning(f"Skipping row {row_id}: Failed to decode audio bytes. 
{e}") + return None + cut = recording.to_cut() + cut.supervisions.append( + SupervisionSegment( + id=row_id, + recording_id=row_id, + start=0.0, + duration=cut.duration, + channel=0, + text=text, + language=language, + ) + ) + cut.custom = {k: v for k, v in row.items() if k != self.audio_field} + return cut + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "LazyParquetIterator only supports __getitem__ when constructed with indexed=True." + ) + idx = int(normalize_graph_token(token)) + if idx < 0: + idx += self._total_rows + if idx < 0 or idx >= self._total_rows: + raise IndexError(f"index {token} out of range for parquet file with {self._total_rows} rows") + rg_idx, local_idx = self._resolve_row_group(idx) + rows = self._load_row_group(rg_idx) + cut = self._build_cut_from_row(rows[local_idx], fallback_idx=idx) + if cut is None: + raise RuntimeError( + f"Row {idx} in {self.path} is not decodable; cannot satisfy random-access __getitem__." + ) + return attach_graph_origin(cut, idx) + + def __len__(self) -> int: + if self.indexed: + return self._total_rows + raise TypeError("LazyParquetIterator has unknown length unless constructed with indexed=True.") + + def state_dict(self) -> dict: + if not self.indexed: + return {} + return {"position": self._position} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self._restored = True def __iter__(self) -> Generator[Cut, None, None]: + if self.indexed: + yield from self._iter_indexed() + else: + yield from self._iter_streaming() + + def _iter_indexed(self) -> Generator[Cut, None, None]: + start = self._position if self._restored else 0 + self._restored = False + n = self._total_rows + for i in range(start, n): + self._position = i + 1 + rg_idx, local_idx = self._resolve_row_group(i) + rows = self._load_row_group(rg_idx) + cut = self._build_cut_from_row(rows[local_idx], fallback_idx=i) + if cut is None: + continue 
+ attach_graph_origin(cut, i) + yield cut + + def _iter_streaming(self) -> Generator[Cut, None, None]: # Open Parquet file in streaming mode inside __iter__ # This ensures each DataLoader worker gets its own file handle. try: @@ -786,53 +1346,7 @@ def __iter__(self) -> Generator[Cut, None, None]: df = batch.to_pandas() for idx, row in df.iterrows(): - # 1. Extract Audio Bytes - # Handle HuggingFace format: {'bytes': b'...', 'path': '...'} or raw bytes - audio_data = row.get(self.audio_field) - if isinstance(audio_data, dict) and 'bytes' in audio_data: - audio_bytes = audio_data['bytes'] - elif isinstance(audio_data, bytes): - audio_bytes = audio_data - else: - logging.warning(f"Skipping row {idx}: Audio column '{self.audio_field}' format unrecognized.") - continue - - # 2. Extract Metadata - text = row.get(self.text_field, "") - language = row.get(self.lang_field, None) - - # 3. Create Unique ID - # Use 'id' column if exists, else combine filename + index - row_id = str(row.get('id', f"{Path(self.path).stem}_{idx}")) - - # 4. Create Lhotse Recording - try: - recording = Recording.from_bytes( - data=audio_bytes, - recording_id=row_id, - ) - except (RuntimeError, ValueError, TypeError) as e: - logging.warning(f"Skipping row {row_id}: Failed to decode audio bytes. {e}") + cut = self._build_cut_from_row(row, fallback_idx=idx) + if cut is None: continue - - # 5. 
Create Cut - cut = recording.to_cut() - - # Add Supervision (Transcript) - cut.supervisions.append( - SupervisionSegment( - id=row_id, - recording_id=row_id, - start=0.0, - duration=cut.duration, - channel=0, - text=text, - language=language, - ) - ) - - # Attach any extra metadata from the row to cut.custom - # (Exclude the heavy audio bytes to save RAM) - cut.custom = {k: v for k, v in row.items() if k != self.audio_field} - yield cut diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 8022e9c9e61e..f2d161528bb3 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -33,6 +33,8 @@ from lhotse.shar import AudioTarWriter, JsonlShardWriter from lhotse.utils import Pathlike, compute_num_samples, is_valid_url +from lhotse.lazy import IteratorNode, attach_graph_origin, normalize_graph_token + from nemo.collections.common.data.lhotse.indexed_adapters import ( IndexedJSONLReader, IndexedTarSampleReader, @@ -132,10 +134,13 @@ def __iter__(self) -> Iterator[TextExample]: @dataclass -class LhotseTextJsonlAdapter: +class LhotseTextJsonlAdapter(IteratorNode): """ ``LhotseTextJsonlAdapter`` is used to read a JSONL file and wrap the text field of each line into a ``TextExample``. + + Set ``indexed=True`` to enable O(1) random access plus graph-token + checkpointing (requires uncompressed ``.jsonl`` paths). 
""" paths: Union[Pathlike, list[Pathlike]] @@ -143,11 +148,97 @@ class LhotseTextJsonlAdapter: text_field: str = "text" shuffle_shards: bool = False shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" + indexed: bool = False def __post_init__(self): self.paths = expand_sharded_filepaths(self.paths) + self._readers: list = [] + self._cum_lens: list[int] = [] + self._position = 0 + self._restored = False + if self.indexed: + from lhotse.indexing import IndexedJsonlReader + + for p in self.paths: + self._readers.append(IndexedJsonlReader(p)) + cum = 0 + self._cum_lens.append(cum) + for r in self._readers: + cum += len(r) + self._cum_lens.append(cum) + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def __len__(self) -> int: + if not self.indexed: + raise TypeError("LhotseTextJsonlAdapter has unknown length unless constructed with indexed=True.") + return self._cum_lens[-1] if self._cum_lens else 0 + + def _resolve(self, idx: int) -> tuple[int, int]: + if idx < 0: + idx += self._cum_lens[-1] + for s in range(len(self._readers)): + if idx < self._cum_lens[s + 1]: + return s, idx - self._cum_lens[s] + raise IndexError(idx) + + def _data_to_example(self, data: dict) -> TextExample | None: + if self.text_field not in data: + return None + return TextExample(data[self.text_field], language=self.language) + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError("LhotseTextJsonlAdapter only supports __getitem__ when indexed=True.") + idx = int(normalize_graph_token(token)) + shard_idx, local_idx = self._resolve(idx) + ex = self._data_to_example(self._readers[shard_idx][local_idx]) + if ex is None: + raise RuntimeError( + f"Index {idx} in {self.paths[shard_idx]} has no '{self.text_field}' field; " + f"cannot satisfy random-access __getitem__." 
+ ) + return attach_graph_origin(ex, idx) + + def state_dict(self) -> dict: + return {"position": self._position} if self.indexed else {} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self._restored = True def __iter__(self) -> Iterator[TextExample]: + if self.indexed: + yield from self._iter_indexed() + else: + yield from self._iter_streaming() + + def _iter_indexed(self) -> Iterator[TextExample]: + start = self._position if self._restored else 0 + self._restored = False + n = self._cum_lens[-1] if self._cum_lens else 0 + for i in range(start, n): + self._position = i + 1 + shard_idx, local_idx = self._resolve(i) + ex = self._data_to_example(self._readers[shard_idx][local_idx]) + if ex is None: + continue + attach_graph_origin(ex, i) + yield ex + + def _iter_streaming(self) -> Iterator[TextExample]: paths = self.paths if self.shuffle_shards: seed = resolve_seed(self.shard_seed) @@ -296,7 +387,7 @@ def default_sft_prompt_format_fn(example: NeMoSFTExample, prompt): @dataclass -class NeMoSFTJsonlAdapter: +class NeMoSFTJsonlAdapter(IteratorNode): """ ``NeMoSFTJsonlAdapter`` is used to read a NeMo LM SFT Chat JSONL file and yield objects of type ``NeMoSFTExample`` that can be sampled with Lhotse. @@ -318,17 +409,94 @@ class NeMoSFTJsonlAdapter: "dataset": str, "category": str, } + + Set ``indexed=True`` to enable O(1) random access plus graph-token + checkpointing (requires uncompressed ``.jsonl`` paths). 
""" paths: Union[Pathlike, list[Pathlike]] language: str | None = None shuffle_shards: bool = False shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" + indexed: bool = False def __post_init__(self): self.paths = expand_sharded_filepaths(self.paths) + self._readers: list = [] + self._cum_lens: list[int] = [] + self._position = 0 + self._restored = False + if self.indexed: + from lhotse.indexing import IndexedJsonlReader + + for p in self.paths: + self._readers.append(IndexedJsonlReader(p)) + cum = 0 + self._cum_lens.append(cum) + for r in self._readers: + cum += len(r) + self._cum_lens.append(cum) + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def __len__(self) -> int: + if not self.indexed: + raise TypeError("NeMoSFTJsonlAdapter has unknown length unless constructed with indexed=True.") + return self._cum_lens[-1] if self._cum_lens else 0 + + def _resolve(self, idx: int) -> tuple[int, int]: + if idx < 0: + idx += self._cum_lens[-1] + for s in range(len(self._readers)): + if idx < self._cum_lens[s + 1]: + return s, idx - self._cum_lens[s] + raise IndexError(idx) + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError("NeMoSFTJsonlAdapter only supports __getitem__ when indexed=True.") + idx = int(normalize_graph_token(token)) + shard_idx, local_idx = self._resolve(idx) + ex = NeMoSFTExample(self._readers[shard_idx][local_idx], language=self.language) + return attach_graph_origin(ex, idx) + + def state_dict(self) -> dict: + return {"position": self._position} if self.indexed else {} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self._restored = True def __iter__(self) -> Iterator[NeMoSFTExample]: + if self.indexed: + yield from self._iter_indexed() + else: + yield from 
self._iter_streaming() + + def _iter_indexed(self) -> Iterator[NeMoSFTExample]: + start = self._position if self._restored else 0 + self._restored = False + n = self._cum_lens[-1] if self._cum_lens else 0 + for i in range(start, n): + self._position = i + 1 + shard_idx, local_idx = self._resolve(i) + ex = NeMoSFTExample(self._readers[shard_idx][local_idx], language=self.language) + attach_graph_origin(ex, i) + yield ex + + def _iter_streaming(self) -> Iterator[NeMoSFTExample]: paths = self.paths if self.shuffle_shards: seed = resolve_seed(self.shard_seed) @@ -596,7 +764,7 @@ def _make_url_cut( @dataclass -class NeMoMultimodalConversationJsonlAdapter: +class NeMoMultimodalConversationJsonlAdapter(IteratorNode): """ ``NeMoMultimodalConversationJsonlAdapter`` is used to read a NeMo multimodal conversation JSONL and yield objects of type ``NeMoMultimodalConversation`` that can be sampled with Lhotse. @@ -615,6 +783,11 @@ class NeMoMultimodalConversationJsonlAdapter: ... ], } + + Set ``indexed=True`` to enable O(1) random access plus graph-token + checkpointing. Indexed mode requires uncompressed JSONL manifests; for the + tarred path it additionally requires uncompressed tar shards (the canonical + ``.idx`` sidecars are built lazily on first construction). 
""" manifest_filepath: str | list[str] @@ -626,6 +799,7 @@ class NeMoMultimodalConversationJsonlAdapter: system_prompt: str | None = None context: str | None = None slice_length: int | None = None + indexed: bool = False def __post_init__(self): self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath) @@ -635,13 +809,226 @@ def __post_init__(self): self.tarred_audio_filepaths ), f"{len(self.manifest_filepath)} != {len(self.tarred_audio_filepaths)}" self.epoch = 0 + self._cuts_readers: list = [] + self._tar_readers: list = [] + self._cum_lens: list[int] = [] + self._total_len = 0 + self._position = 0 + self._restored = False + if self.indexed: + self._init_indexed() + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def _init_indexed(self) -> None: + from lhotse.indexing import IndexedJsonlReader + + if self.slice_length is not None: + raise ValueError( + "NeMoMultimodalConversationJsonlAdapter(indexed=True) does not support slice_length." + ) + for p in self.manifest_filepath: + self._cuts_readers.append(IndexedJsonlReader(p)) + if self.tarred_audio_filepaths is not None: + from nemo.collections.common.data.lhotse.indexed_adapters import IndexedTarMemberReader + + for p in self.tarred_audio_filepaths: + self._tar_readers.append(IndexedTarMemberReader(p)) + cum = 0 + self._cum_lens.append(cum) + for r in self._cuts_readers: + cum += len(r) + self._cum_lens.append(cum) + self._total_len = cum + + def __len__(self) -> int: + if self.indexed: + return self._total_len + raise TypeError( + "NeMoMultimodalConversationJsonlAdapter has unknown length unless constructed with indexed=True." 
+ ) + + def _resolve(self, idx: int) -> tuple[int, int]: + if idx < 0: + idx += self._total_len + for s in range(len(self._cuts_readers)): + if idx < self._cum_lens[s + 1]: + return s, idx - self._cum_lens[s] + raise IndexError(idx) + + def state_dict(self) -> dict: + return {"position": self._position, "epoch": self.epoch} if self.indexed else {} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self.epoch = sd.get("epoch", 0) + self._restored = True + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "NeMoMultimodalConversationJsonlAdapter only supports __getitem__ when indexed=True." + ) + idx = int(normalize_graph_token(token)) + shard_idx, local_idx = self._resolve(idx) + data = self._cuts_readers[shard_idx][local_idx] + if self._tar_readers: + convo = self._build_conversation_tarred( + data, + tar_reader=self._tar_readers[shard_idx], + tar_path=self.tarred_audio_filepaths[shard_idx], + ) + else: + convo = self._build_conversation_local( + data, manifest_path=self._cuts_readers[shard_idx].path + ) + if convo is None: + raise RuntimeError( + f"Conversation at index {idx} (shard {shard_idx}, local {local_idx}) " + f"could not be built; cannot satisfy random-access __getitem__." 
+ ) + return attach_graph_origin(convo, idx) + + def _build_conversation_local(self, data: dict, manifest_path: str) -> NeMoMultimodalConversation | None: + if self._should_skip(data): + return None + turns = [ + ( + TextTurn( + value=turn["value"], + role=turn["from"].lower(), + ) + if turn["type"] == "text" + else AudioTurn( + cut=( + cut := Recording.from_file(get_full_path(turn["value"], manifest_path)) + .to_cut() + .truncate(offset=turn.get("offset", 0.0), duration=turn.get("duration")) + ).with_id(self._make_cut_id(cut, turn)), + text=cut.supervisions[0].text if cut.supervisions else None, + role=turn["from"].lower(), + audio_locator_tag=self.audio_locator_tag, + ) + ) + for turn in data["conversations"] + ] + if self.context is not None and turns[0].role == "user" and isinstance(turns[0], AudioTurn): + turns = [TextTurn(role="user", value=self.context)] + turns + if self.system_prompt is not None and turns[0].role != "system": + turns = [TextTurn(role="system", value=self.system_prompt)] + turns + return NeMoMultimodalConversation( + id=data["id"], + turns=turns, + token_equivalent_duration=self.token_equivalent_duration, + custom=data.get("custom"), + ) + + def _build_conversation_tarred( + self, data: dict, tar_reader, tar_path: str + ) -> NeMoMultimodalConversation | None: + import io as _io + + import soundfile as _sf + from lhotse import AudioSource as _AudioSource + from lhotse import Recording as _Recording + + if self._should_skip(data): + return None + cuts: list = [] + for turn in data["conversations"]: + if turn["type"] != "audio": + continue + audio_bytes = tar_reader.get(turn["value"]) + try: + meta = _sf.info(_io.BytesIO(audio_bytes)) + except Exception: + logging.warning(f"Skipped corrupted audio member '{turn['value']}' in {tar_path=}.") + return None + recording = _Recording( + id=turn["value"], + sources=[_AudioSource(type="memory", channels=list(range(meta.channels)), source=audio_bytes)], + sampling_rate=int(meta.samplerate), + 
num_samples=meta.frames, + duration=meta.duration, + ) + cut = recording.to_cut().truncate( + offset=turn.get("offset", 0.0), duration=turn.get("duration") + ) + cut = cut.with_id(self._make_cut_id(cut, turn)) + cuts.append(cut) + cuts = deque(cuts) + turns = [ + ( + TextTurn( + value=turn["value"], + role=turn["from"].lower(), + ) + if turn["type"] == "text" + else AudioTurn( + cut=(c := cuts.popleft()), + text=c.supervisions[0].text if c.supervisions else None, + role=turn["from"].lower(), + audio_locator_tag=self.audio_locator_tag, + ) + ) + for turn in data["conversations"] + ] + if self.context is not None and turns[0].role == "user" and isinstance(turns[0], AudioTurn): + turns = [TextTurn(role="user", value=self.context)] + turns + if self.system_prompt is not None and turns[0].role != "system": + turns = [TextTurn(role="system", value=self.system_prompt)] + turns + return NeMoMultimodalConversation( + id=data["id"], + turns=turns, + token_equivalent_duration=self.token_equivalent_duration, + custom=data.get("custom"), + ) def __iter__(self) -> Iterator[NeMoMultimodalConversation]: + if self.indexed: + yield from self._iter_indexed() + return if self.tarred_audio_filepaths is not None: yield from self._iter_tar() else: yield from self._iter_jsonl() + def _iter_indexed(self) -> Iterator[NeMoMultimodalConversation]: + start = self._position if self._restored else 0 + self._restored = False + n = self._total_len + for i in range(start, n): + self._position = i + 1 + shard_idx, local_idx = self._resolve(i) + data = self._cuts_readers[shard_idx][local_idx] + if self._tar_readers: + convo = self._build_conversation_tarred( + data, + tar_reader=self._tar_readers[shard_idx], + tar_path=self.tarred_audio_filepaths[shard_idx], + ) + else: + convo = self._build_conversation_local( + data, manifest_path=self._cuts_readers[shard_idx].path + ) + if convo is None: + continue + attach_graph_origin(convo, i) + yield convo + self.epoch += 1 + def _should_skip(self, example: 
dict) -> bool: custom = example.get("custom") if custom is None: @@ -845,7 +1232,7 @@ def _create_sharegpt_turns(audio_locator_tag: str, conversations: list[dict], re @dataclass -class NeMoMultimodalConversationShareGPTJsonlAdapter: +class NeMoMultimodalConversationShareGPTJsonlAdapter(IteratorNode): """ ``NeMoMultimodalConversationShareGPTJsonlAdapter`` is used to read a ShareGPT format multimodal conversation JSONL and yield objects of type ``NeMoMultimodalConversation`` that can be sampled with Lhotse. @@ -878,6 +1265,7 @@ class NeMoMultimodalConversationShareGPTJsonlAdapter: shuffle_shards: bool = False shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" slice_length: int | None = None + indexed: bool = False def __post_init__(self): self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath) @@ -889,8 +1277,132 @@ def __post_init__(self): self.audio_placeholders = _normalize_audio_placeholders(self.audio_placeholders) self._has_index = all(Path(p + ".idx").exists() for p in self.manifest_filepath) self.epoch = 0 + self._cuts_readers: list = [] + self._tar_readers: list = [] + self._cum_lens: list[int] = [] + self._total_len = 0 + self._position = 0 + self._restored = False + if self.indexed: + self._init_indexed() + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def _init_indexed(self) -> None: + from lhotse.indexing import IndexedJsonlReader + + if self.slice_length is not None: + raise ValueError( + "NeMoMultimodalConversationShareGPTJsonlAdapter(indexed=True) does not support slice_length." 
+ ) + for p in self.manifest_filepath: + self._cuts_readers.append(IndexedJsonlReader(p)) + if self.tarred_audio_filepaths is not None: + from nemo.collections.common.data.lhotse.indexed_adapters import IndexedTarMemberReader + + for p in self.tarred_audio_filepaths: + self._tar_readers.append(IndexedTarMemberReader(p)) + cum = 0 + self._cum_lens.append(cum) + for r in self._cuts_readers: + cum += len(r) + self._cum_lens.append(cum) + self._total_len = cum + + def __len__(self) -> int: + if self.indexed: + return self._total_len + raise TypeError( + "NeMoMultimodalConversationShareGPTJsonlAdapter has unknown length unless constructed with indexed=True." + ) + + def _resolve(self, idx: int) -> tuple[int, int]: + if idx < 0: + idx += self._total_len + for s in range(len(self._cuts_readers)): + if idx < self._cum_lens[s + 1]: + return s, idx - self._cum_lens[s] + raise IndexError(idx) + + def state_dict(self) -> dict: + return {"position": self._position, "epoch": self.epoch} if self.indexed else {} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self.epoch = sd.get("epoch", 0) + self._restored = True + + def _build_one(self, data: dict, shard_idx: int) -> NeMoMultimodalConversation: + conversations = _transform_sharegpt(self.audio_placeholders, data) + if self._tar_readers: + tar_reader = self._tar_readers[shard_idx] + tar_path = self.tarred_audio_filepaths[shard_idx] + return NeMoMultimodalConversation( + id=data.get("id", "missing-example-id"), + turns=_create_sharegpt_turns( + self.audio_locator_tag, + conversations, + lambda t: self._resolve_cut_from_indexed_tar(t, tar_reader, tar_path), + ), + token_equivalent_duration=self.token_equivalent_duration, + ) + manifest_path = self._cuts_readers[shard_idx].path + return NeMoMultimodalConversation( + id=data.get("id", "missing-example-id"), + turns=_create_sharegpt_turns( + self.audio_locator_tag, + conversations, + lambda t, _p=manifest_path: 
self._resolve_cut_from_path(t, _p), + ), + token_equivalent_duration=self.token_equivalent_duration, + ) + + def _resolve_cut_from_indexed_tar(self, turn, tar_reader, tar_path): + import io as _io + + import soundfile as _sf + from lhotse import AudioSource as _AudioSource + from lhotse import Recording as _Recording + + audio_bytes = tar_reader.get(turn["value"]) + meta = _sf.info(_io.BytesIO(audio_bytes)) + recording = _Recording( + id=turn["value"], + sources=[_AudioSource(type="memory", channels=list(range(meta.channels)), source=audio_bytes)], + sampling_rate=int(meta.samplerate), + num_samples=meta.frames, + duration=meta.duration, + ) + cut = recording.to_cut().truncate(offset=turn.get("offset", 0.0), duration=turn.get("duration")) + return cut.with_id(self._make_cut_id(cut, turn)) + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "NeMoMultimodalConversationShareGPTJsonlAdapter only supports __getitem__ when indexed=True." + ) + idx = int(normalize_graph_token(token)) + shard_idx, local_idx = self._resolve(idx) + data = self._cuts_readers[shard_idx][local_idx] + convo = self._build_one(data, shard_idx) + return attach_graph_origin(convo, idx) def __iter__(self) -> Iterator[NeMoMultimodalConversation]: + if self.indexed: + yield from self._iter_indexed_node() + return if self.tarred_audio_filepaths is not None: yield from self._iter_tar() elif self.shuffle_shards and self._has_index: @@ -898,6 +1410,19 @@ def __iter__(self) -> Iterator[NeMoMultimodalConversation]: else: yield from self._iter_jsonl() + def _iter_indexed_node(self) -> Iterator[NeMoMultimodalConversation]: + start = self._position if self._restored else 0 + self._restored = False + n = self._total_len + for i in range(start, n): + self._position = i + 1 + shard_idx, local_idx = self._resolve(i) + data = self._cuts_readers[shard_idx][local_idx] + convo = self._build_one(data, shard_idx) + attach_graph_origin(convo, i) + yield convo + self.epoch += 1 + def 
_get_rng(self) -> random.Random: return random.Random(resolve_seed(self.shard_seed) + self.epoch) @@ -1024,7 +1549,7 @@ def _iter_jsonl_indexed(self): @dataclass -class NeMoMultimodalConversationShareGPTWebdatasetAdapter: +class NeMoMultimodalConversationShareGPTWebdatasetAdapter(IteratorNode): """ ``NeMoMultimodalConversationShareGPTWebdatasetAdapter`` reads ShareGPT format multimodal conversations from WebDataset tar archives and yields ``NeMoMultimodalConversation`` objects. @@ -1059,6 +1584,7 @@ class NeMoMultimodalConversationShareGPTWebdatasetAdapter: token_equivalent_duration: float = None shuffle_shards: bool = False shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" + indexed: bool = False def __post_init__(self): import json as _json @@ -1075,13 +1601,94 @@ def __post_init__(self): self.audio_placeholders = _normalize_audio_placeholders(self.audio_placeholders) self._has_index = all(Path(p + ".idx").exists() for p in self._shard_paths) self.epoch = 0 + self._tar_readers: list = [] + self._cum_lens: list[int] = [] + self._total_len = 0 + self._position = 0 + self._restored = False + if self.indexed: + self._init_indexed() + + @property + def is_checkpointable(self) -> bool: + return self.indexed + + @property + def is_indexed(self) -> bool: + return self.indexed + + @property + def has_constant_time_access(self) -> bool: + return self.indexed + + def _init_indexed(self) -> None: + for p in self._shard_paths: + self._tar_readers.append(IndexedTarSampleReader(p)) + cum = 0 + self._cum_lens.append(cum) + for r in self._tar_readers: + cum += len(r) + self._cum_lens.append(cum) + self._total_len = cum + + def __len__(self) -> int: + if self.indexed: + return self._total_len + raise TypeError( + "NeMoMultimodalConversationShareGPTWebdatasetAdapter has unknown length unless constructed with indexed=True." 
+ ) + + def _resolve(self, idx: int) -> tuple[int, int]: + if idx < 0: + idx += self._total_len + for s in range(len(self._tar_readers)): + if idx < self._cum_lens[s + 1]: + return s, idx - self._cum_lens[s] + raise IndexError(idx) + + def state_dict(self) -> dict: + return {"position": self._position, "epoch": self.epoch} if self.indexed else {} + + def load_state_dict(self, sd: dict) -> None: + if not self.indexed: + return + self._position = sd.get("position", 0) + self.epoch = sd.get("epoch", 0) + self._restored = True + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "NeMoMultimodalConversationShareGPTWebdatasetAdapter only supports __getitem__ when indexed=True." + ) + idx = int(normalize_graph_token(token)) + shard_idx, local_idx = self._resolve(idx) + json_data, audio_bytes, audio_name = self._tar_readers[shard_idx][local_idx] + convo = self._yield_from_sample(json_data, audio_bytes, audio_name) + return attach_graph_origin(convo, idx) def __iter__(self) -> Iterator[NeMoMultimodalConversation]: + if self.indexed: + yield from self._iter_indexed_node() + return if self.shuffle_shards and self._has_index: yield from self._iter_indexed() else: yield from self._iter_sequential() + def _iter_indexed_node(self) -> Iterator[NeMoMultimodalConversation]: + start = self._position if self._restored else 0 + self._restored = False + n = self._total_len + for i in range(start, n): + self._position = i + 1 + shard_idx, local_idx = self._resolve(i) + json_data, audio_bytes, audio_name = self._tar_readers[shard_idx][local_idx] + convo = self._yield_from_sample(json_data, audio_bytes, audio_name) + attach_graph_origin(convo, i) + yield convo + self.epoch += 1 + def _get_rng(self) -> random.Random: return random.Random(resolve_seed(self.shard_seed) + self.epoch) diff --git a/nemo/collections/speechlm2/models/salm_automodel.py b/nemo/collections/speechlm2/models/salm_automodel.py index f759ac01bcc7..ac9000404a94 100644 --- 
a/nemo/collections/speechlm2/models/salm_automodel.py +++ b/nemo/collections/speechlm2/models/salm_automodel.py @@ -234,7 +234,22 @@ def on_fit_start(self) -> None: averaging (see ``_configure_moe_aux_loss_scaler``).""" self._configure_moe_aux_loss_scaler() - def training_step(self, batch: dict, batch_idx: int): + def training_step(self, dataloader_iter): + # Use the explicit ``dataloader_iter`` signature so Lightning selects + # ``_DataLoaderIterDataFetcher`` (no upfront prefetch). With + # ``_PrefetchDataFetcher`` Lightning re-primes one batch from the + # dataloader every time iteration starts (including on resume), which + # advances the StatefulDataLoader past the saved snapshot point and + # breaks bit-identical resumption. The dataloader_iter path consumes + # one batch per training step, so save/restore captures the exact + # next-batch position. + batch, batch_idx, _ = next(dataloader_iter) + # Move to device + apply precision conversions normally done by Lightning + # for the prefetch fetcher path. + batch = self.trainer.precision_plugin.convert_input(batch) + batch = self._on_before_batch_transfer(batch, dataloader_idx=0) + batch = self.trainer.strategy.batch_to_device(batch, dataloader_idx=0) + self._current_batch_idx = batch_idx for m in (self.perception.preprocessor, self.perception.encoder, self.llm): if is_frozen(m): @@ -286,8 +301,10 @@ def training_step(self, batch: dict, batch_idx: int): "target_to_input_ratio": num_frames / (B * T), "padding_ratio": (batch["input_ids"] != self.text_pad_id).long().sum() / batch["input_ids"].numel(), } - self.log("loss", loss_display, on_step=True, prog_bar=True) - self.log_dict({k: v for k, v in ans.items() if k != "loss"}, on_step=True) + # batch_size kwarg is required by Lightning when training_step uses + # the ``dataloader_iter`` signature (it can't auto-infer otherwise). 
+ self.log("loss", loss_display, on_step=True, prog_bar=True, batch_size=B) + self.log_dict({k: v for k, v in ans.items() if k != "loss"}, on_step=True, batch_size=B) self.maybe_log_moe_metrics(batch_idx) return ans From 8a482e437ccfde7fb8222174a2c5a075fac30d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 5 May 2026 09:23:10 -0700 Subject: [PATCH 03/30] refactor read_batch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speechlm2/models/salm_automodel.py | 20 ++------ nemo/core/utils/lightning_utils.py | 49 +++++++++++++++++++ 2 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 nemo/core/utils/lightning_utils.py diff --git a/nemo/collections/speechlm2/models/salm_automodel.py b/nemo/collections/speechlm2/models/salm_automodel.py index ac9000404a94..8cf8ee7b3eec 100644 --- a/nemo/collections/speechlm2/models/salm_automodel.py +++ b/nemo/collections/speechlm2/models/salm_automodel.py @@ -40,6 +40,7 @@ update_perception_output_dim, ) from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType +from nemo.core.utils.lightning_utils import read_batch class SALMAutomodel(LightningModule, HFHubMixin): @@ -235,21 +236,10 @@ def on_fit_start(self) -> None: self._configure_moe_aux_loss_scaler() def training_step(self, dataloader_iter): - # Use the explicit ``dataloader_iter`` signature so Lightning selects - # ``_DataLoaderIterDataFetcher`` (no upfront prefetch). With - # ``_PrefetchDataFetcher`` Lightning re-primes one batch from the - # dataloader every time iteration starts (including on resume), which - # advances the StatefulDataLoader past the saved snapshot point and - # breaks bit-identical resumption. The dataloader_iter path consumes - # one batch per training step, so save/restore captures the exact - # next-batch position. 
- batch, batch_idx, _ = next(dataloader_iter) - # Move to device + apply precision conversions normally done by Lightning - # for the prefetch fetcher path. - batch = self.trainer.precision_plugin.convert_input(batch) - batch = self._on_before_batch_transfer(batch, dataloader_idx=0) - batch = self.trainer.strategy.batch_to_device(batch, dataloader_idx=0) - + # ``dataloader_iter`` signature → Lightning selects + # ``_DataLoaderIterDataFetcher`` (no prefetch) which is required for + # bit-identical checkpoint resumption. See ``read_batch`` docstring. + batch, batch_idx = read_batch(dataloader_iter, self) self._current_batch_idx = batch_idx for m in (self.perception.preprocessor, self.perception.encoder, self.llm): if is_frozen(m): diff --git a/nemo/core/utils/lightning_utils.py b/nemo/core/utils/lightning_utils.py new file mode 100644 index 000000000000..77c88942ac9f --- /dev/null +++ b/nemo/core/utils/lightning_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Helpers for working with PyTorch Lightning's ``training_step``.""" +from typing import Any, Iterator, Tuple + +import lightning.pytorch as pl + + +def read_batch(dataloader_iter: Iterator, model: pl.LightningModule) -> Tuple[Any, int]: + """Pull the next batch from a Lightning ``dataloader_iter`` and apply the + device/precision conversions that ``_PrefetchDataFetcher`` would have + applied for the default ``training_step(batch, batch_idx)`` signature. + + Use this from a ``training_step(self, dataloader_iter)``-style step. That + signature makes Lightning select ``_DataLoaderIterDataFetcher`` (no + prefetch), which is required for bit-identical checkpoint resumption with + a stateful dataloader: the default ``_PrefetchDataFetcher`` re-primes one + batch on every iter init (including on resume), advancing the stateful + dataloader past the saved snapshot point and giving the resumed run a + one-batch drift versus the continuous run. + + Args: + dataloader_iter: The iterator passed by Lightning into a + ``training_step(self, dataloader_iter)`` (an instance of + ``_DataFetcherWrapper``). Yields ``(batch, batch_idx, dataloader_idx)``. + model: The ``LightningModule`` whose ``trainer`` carries the precision + plugin and strategy used to move the batch to device. + + Returns: + ``(batch, batch_idx)`` — batch is already converted to the right + precision and moved to the model's device, ready for forward. 
+ """ + batch, batch_idx, dataloader_idx = next(dataloader_iter) + trainer = model.trainer + batch = trainer.precision_plugin.convert_input(batch) + batch = model._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx) + batch = trainer.strategy.batch_to_device(batch, dataloader_idx=dataloader_idx) + return batch, batch_idx From 086f0e3491e263a5f9b2a5c9d5f5dcded644b7de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 5 May 2026 16:53:32 -0700 Subject: [PATCH 04/30] refactor/cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/indexed_adapters.py | 49 +++-- .../common/data/lhotse/nemo_adapters.py | 205 ++++++------------ 2 files changed, 100 insertions(+), 154 deletions(-) diff --git a/nemo/collections/common/data/lhotse/indexed_adapters.py b/nemo/collections/common/data/lhotse/indexed_adapters.py index edb7e6e86400..597e6c1f4726 100644 --- a/nemo/collections/common/data/lhotse/indexed_adapters.py +++ b/nemo/collections/common/data/lhotse/indexed_adapters.py @@ -24,6 +24,10 @@ # Knuth's multiplicative hash constant (golden-ratio derived, 32-bit). _KNUTH_HASH = 2654435761 +# Tar block size + the all-zeros block that marks end-of-archive in tar. 
+_TAR_BLOCK_SIZE = 512 +_TAR_ZERO_BLOCK = b'\0' * _TAR_BLOCK_SIZE + class LazyShuffledRange: """ @@ -184,8 +188,8 @@ def _validate_index(self): last = int(self.offsets[self._len - 1]) with open(self.data_path, 'rb') as f: f.seek(last) - buf = f.read(512) - if len(buf) < 512 or buf == b'\0' * 512: + buf = f.read(_TAR_BLOCK_SIZE) + if len(buf) < _TAR_BLOCK_SIZE or buf == _TAR_ZERO_BLOCK: self._len -= 1 else: break @@ -193,13 +197,13 @@ def _validate_index(self): def _check_offset_is_tar_header(self, offset: int, label: str = ""): with open(self.data_path, 'rb') as f: f.seek(offset) - buf = f.read(512) - if len(buf) < 512: + buf = f.read(_TAR_BLOCK_SIZE) + if len(buf) < _TAR_BLOCK_SIZE: raise ValueError( f"Tar index for {self.data_path}: {label} offset {offset} " f"is too close to EOF (file size {self._data_size})." ) - if buf == b'\0' * 512: + if buf == _TAR_ZERO_BLOCK: raise ValueError( f"Tar index for {self.data_path}: {label} offset {offset} " f"points to a zero block (end-of-archive marker), not a tar header. " @@ -328,11 +332,10 @@ def _build_name_index(self) -> dict[str, int]: name_to_idx: dict[str, int] = {} self._ensure_open() for i in range(self._len): - offset = int(self.offsets[i]) - self._fh.seek(offset) + self._fh.seek(int(self.offsets[i])) while True: - header = self._fh.read(512) - if len(header) < 512 or header == b"\0" * 512: + header = self._fh.read(_TAR_BLOCK_SIZE) + if len(header) < _TAR_BLOCK_SIZE or header == _TAR_ZERO_BLOCK: break info = tarfile.TarInfo.frombuf( header, tarfile.ENCODING, "surrogateescape" @@ -340,9 +343,8 @@ def _build_name_index(self) -> dict[str, int]: if info.type in (tarfile.REGTYPE, tarfile.AREGTYPE): name_to_idx[info.name] = i break - # Non-regular (PAX header, GNU long-name, etc.): - # skip its data + 512-byte padding and continue. - size_blocks = (info.size + 511) // 512 * 512 + # Skip non-regular member (PAX/GNU long-name) data + padding. 
+ size_blocks = -(-info.size // _TAR_BLOCK_SIZE) * _TAR_BLOCK_SIZE self._fh.seek(size_blocks, 1) return name_to_idx @@ -378,16 +380,16 @@ def _read_tar_member(f): arbitrary byte offset and read just the members we need in O(1). """ while True: - header_buf = f.read(512) - if len(header_buf) < 512 or header_buf == b'\0' * 512: + header_buf = f.read(_TAR_BLOCK_SIZE) + if len(header_buf) < _TAR_BLOCK_SIZE or header_buf == _TAR_ZERO_BLOCK: raise EOFError("End of tar archive or unexpected EOF") info = tarfile.TarInfo.frombuf(header_buf, tarfile.ENCODING, "surrogateescape") data = f.read(info.size) if len(data) < info.size: raise EOFError("Unexpected end of tar file while reading data") - remainder = info.size % 512 + remainder = info.size % _TAR_BLOCK_SIZE if remainder: - f.seek(512 - remainder, 1) + f.seek(_TAR_BLOCK_SIZE - remainder, 1) if info.type not in (tarfile.REGTYPE, tarfile.AREGTYPE): continue return info.name, data @@ -399,10 +401,14 @@ def create_index(jsonl_path, idx_path): Format: sequence of little-endian uint64 values ``[Offset_0, Offset_1, ..., Offset_N, File_Size]`` + + Written atomically (tmp + ``os.replace``) so concurrent writers can't + observe a half-written ``.idx``. """ # Flush the write buffer every 8 MiB to limit memory usage on large files. flush_threshold = 8 * 1024 * 1024 - with open(jsonl_path, 'rb') as f_in, open(idx_path, 'wb') as f_out: + tmp_path = f"{idx_path}.tmp.{os.getpid()}" + with open(jsonl_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out: current_offset = 0 write_buffer = bytearray() write_buffer.extend(struct.pack('.+)(?P-sub\d+)(?P\.\w+)?$') + class LazyNeMoIterator(IteratorNode): """ @@ -143,7 +150,9 @@ def __init__( "graph-token random access." 
) seed = resolve_seed(shard_seed) if shard_seed not in (None, "trng", "randomized") else 0 - indexed_sources = [_LazyIndexedJsonlDictNode(p) for p in paths] + indexed_sources = [ + LazyIndexedManifestIterator(p, decode=GraphOriginDict) for p in paths + ] if len(indexed_sources) == 1: self.source = indexed_sources[0] else: @@ -300,66 +309,6 @@ def _create_recording( return Recording.from_file(audio_path) -class _GraphOriginDict(dict): - """``dict`` subclass that can carry runtime attributes (e.g. ``_graph_origin``).""" - - __slots__ = ("_graph_origin",) - - -class _LazyIndexedJsonlDictNode(IteratorNode): - """ - Internal helper: a graph-restorable indexed JSONL reader that yields raw dicts - (not Cuts). Built on top of :class:`lhotse.indexing.IndexedJsonlReader`. - - Used as the source iterator for :class:`LazyNeMoIterator` (and other adapters) - when ``indexed=True``. Yielded items carry ``_graph_origin`` set to their - integer line index, which allows downstream nodes (e.g. ``LazyIteratorChain``, - ``LazyShuffler``) to compose graph tokens for exact restore. 
- """ - - is_checkpointable = True - is_indexed = True - has_constant_time_access = True - - def __init__(self, path: str | Path) -> None: - from lhotse.indexing import IndexedJsonlReader - - self.path = path - self._reader = IndexedJsonlReader(path) - self._position = 0 - self._restored = False - - def __getitem__(self, idx): - idx = int(normalize_graph_token(idx)) - item = _GraphOriginDict(self._reader[idx]) - return attach_graph_origin(item, idx) - - def __len__(self) -> int: - return len(self._reader) - - def __iter__(self): - start = self._position if self._restored else 0 - self._restored = False - n = len(self._reader) - for i in range(start, n): - self._position = i + 1 - item = _GraphOriginDict(self._reader[i]) - attach_graph_origin(item, i) - yield item - - def state_dict(self) -> dict: - return {"position": self._position} - - def load_state_dict(self, sd: dict) -> None: - self._position = sd["position"] - self._restored = True - - -# NeMo-tar indexed access is delegated to ``IndexedTarMemberReader`` from -# ``indexed_adapters`` — the same canonical .idx format (uint64 LE offsets + -# sentinel) used everywhere else in NeMo and lhotse for indexed access. - - class LazyNeMoTarredIterator(IteratorNode): r""" ``LazyNeMoTarredIterator`` reads a NeMo tarred JSON manifest and converts it on the fly to an ``Iterable[Cut]``. 
@@ -566,7 +515,6 @@ def _init_indexed(self) -> None: self._total_len = cum self._position = 0 self._restored = False - self._offset_pattern = re.compile(r'^(?P.+)(?P-sub\d+)(?P\.\w+)?$') def to_shards(self) -> List["LazyNeMoTarredIterator"]: """Convert this iterator to a list of separate iterators for each shard.""" @@ -741,6 +689,31 @@ def _resolve_global_idx(self, idx: int) -> tuple[int, int]: sid = self._sorted_shard_ids[shard_pos] return sid, idx - self._cum_lens[shard_pos] + def _audio_member_name_from_entry(self, entry: dict) -> str: + af = entry["audio_filepath"] + m = _OFFSET_PATTERN.match(af) + if m is None: + return af + return m.group("stem") + ifnone(m.group("ext"), "") + + def _attach_supervision_and_metadata( + self, cut: Cut, data: dict, manifest_path: str, tar_path: str + ) -> Cut: + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + cut.custom = _to_custom_attr_dict(data) + cut.manifest_origin = manifest_path + cut.tar_origin = tar_path + return cut + def _build_indexed_cut(self, data: dict, audio_bytes: bytes, manifest_path: str, tar_path: str) -> Cut | None: """Decode a single (manifest_entry, audio_bytes) pair into a Cut, mirroring the streaming path.""" if data.get("_skipme", False): @@ -762,35 +735,14 @@ def _build_indexed_cut(self, data: dict, audio_bytes: bytes, manifest_path: str, cut = make_cut_with_subset_inmemory_recording( recording, offset=data.get("offset", 0.0), duration=data.get("duration") ) - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=cut.duration, - text=data.get(self.text_field), - language=data.get(self.lang_field), - ) - ) - cut.custom = _to_custom_attr_dict(data) - cut.manifest_origin = manifest_path - cut.tar_origin = tar_path - return cut - - def _audio_member_name_from_entry(self, entry: dict) 
-> str: - af = entry["audio_filepath"] - m = self._offset_pattern.match(af) - if m is None: - return af - return m.group("stem") + ifnone(m.group("ext"), "") + return self._attach_supervision_and_metadata(cut, data, manifest_path, tar_path) def _build_indexed_url_cut(self, data: dict, manifest_path: str, tar_path: str) -> Cut | None: """ AIS GetBatch counterpart of ``_build_indexed_cut``: produces a Cut backed by a URL/file AudioSource (no audio bytes loaded), so that ``AudioSamples(use_batch_loader=True)`` can fetch the entire minibatch in - a single AIS GetBatch request. Mirrors the streaming path in - ``_iter_batch_for_ais_get_batch``. + a single AIS GetBatch request. Mirrors ``_iter_batch_for_ais_get_batch``. """ if data.get("_skipme", False): return None @@ -802,9 +754,8 @@ def _build_indexed_url_cut(self, data: dict, manifest_path: str, tar_path: str) return None audio_filename = self._audio_member_name_from_entry(data) audio_url = f"{tar_path.rstrip('/')}/{audio_filename.lstrip('/')}" - # Mirror the streaming path's convention: use type="url" since open_best() - # transparently handles both local paths and remote URLs (ais://, http(s)://, ...). - # AudioSamples' GetBatch loader inspects the URL scheme to dispatch to AIS. + # ``open_best`` handles ais://, http(s)://, and local paths uniformly; + # the AIS GetBatch loader still keys off the URL scheme. 
source_type = "url" if "://" in tar_path else "file" offset = data.get("offset", 0.0) sampling_rate = data.get("sampling_rate", 16000) @@ -819,41 +770,40 @@ def _build_indexed_url_cut(self, data: dict, manifest_path: str, tar_path: str) if offset > 0: cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) cut.id = f"{cut.id}-{round(offset * 1e2):06d}-{round(duration * 1e2):06d}" - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=cut.duration, - text=data.get(self.text_field), - language=data.get(self.lang_field), - ) - ) - cut.custom = _to_custom_attr_dict(data) - cut.manifest_origin = manifest_path - cut.tar_origin = tar_path - return cut + return self._attach_supervision_and_metadata(cut, data, manifest_path, tar_path) - def __getitem__(self, token): - if not self.indexed: - raise NotImplementedError( - "LazyNeMoTarredIterator only supports __getitem__ when constructed with indexed=True." - ) - idx = int(normalize_graph_token(token)) + def _decode_cut_at(self, idx: int) -> Cut | None: + """Build the Cut for a global index in indexed mode (AIS or local). + + Returns ``None`` if the audio member is missing and + ``skip_missing_manifest_entries`` is set, or if the entry has + ``_skipme=True`` / undecodable audio. 
+ """ sid, local_idx = self._resolve_global_idx(idx) data = self._cuts_readers[sid][local_idx] manifest_path = self._cuts_readers[sid].path tar_path = self.shard_id_to_tar_path[sid] if self.use_ais_get_batch: - cut = self._build_indexed_url_cut(data, manifest_path, tar_path) - else: - member_name = self._audio_member_name_from_entry(data) + return self._build_indexed_url_cut(data, manifest_path, tar_path) + member_name = self._audio_member_name_from_entry(data) + try: audio_bytes = self._tar_readers[sid].get(member_name) - cut = self._build_indexed_cut(data, audio_bytes, manifest_path, tar_path) + except KeyError: + if self.skip_missing_manifest_entries: + return None + raise + return self._build_indexed_cut(data, audio_bytes, manifest_path, tar_path) + + def __getitem__(self, token): + if not self.indexed: + raise NotImplementedError( + "LazyNeMoTarredIterator only supports __getitem__ when constructed with indexed=True." + ) + idx = int(normalize_graph_token(token)) + cut = self._decode_cut_at(idx) if cut is None: raise RuntimeError( - f"Cut at global index {idx} (shard {sid}, local {local_idx}) is not decodable; " - f"cannot satisfy random-access __getitem__." + f"Cut at global index {idx} is not decodable; cannot satisfy random-access __getitem__." 
) return attach_graph_origin(cut, idx) @@ -877,24 +827,9 @@ def load_state_dict(self, sd: dict) -> None: def _iter_indexed(self) -> Generator[Cut, None, None]: start = self._position if self._restored else 0 self._restored = False - n = self._total_len - for i in range(start, n): + for i in range(start, self._total_len): self._position = i + 1 - sid, local_idx = self._resolve_global_idx(i) - data = self._cuts_readers[sid][local_idx] - manifest_path = self._cuts_readers[sid].path - tar_path = self.shard_id_to_tar_path[sid] - if self.use_ais_get_batch: - cut = self._build_indexed_url_cut(data, manifest_path, tar_path) - else: - member_name = self._audio_member_name_from_entry(data) - try: - audio_bytes = self._tar_readers[sid].get(member_name) - except KeyError: - if self.skip_missing_manifest_entries: - continue - raise - cut = self._build_indexed_cut(data, audio_bytes, manifest_path, tar_path) + cut = self._decode_cut_at(i) if cut is None: continue attach_graph_origin(cut, i) @@ -917,17 +852,15 @@ def __iter__(self) -> Generator[Cut, None, None]: # Propagate the random seed extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()] - # Handle NeMo tarred manifests with offsets. - # They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset. - offset_pattern = re.compile(r'^(?P.+)(?P-sub\d+)(?P\.\w+)?$') - + # NeMo tarred manifests can have multiple JSONL entries pointing at the + # same audio member with -subN audio_filepath suffixes (per-offset cuts). 
for sid in shard_ids: manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0] def basename(d: dict) -> str: return ( m.group("stem") + ifnone(m.group("ext"), "") - if (m := offset_pattern.match(k := d["audio_filepath"])) is not None + if (m := _OFFSET_PATTERN.match(k := d["audio_filepath"])) is not None else k ) From ad6861a02fe075dc27d1bced2353efd8b0cfd048 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 6 May 2026 16:06:42 -0700 Subject: [PATCH 05/30] Documentation update to reflect indexed/checkpointable things + general gap coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- docs/source/asr/datasets.rst | 6 + docs/source/audio/datasets.rst | 6 + docs/source/dataloaders.rst | 797 ++++++++++++++++++++++++++- docs/source/speechlm2/datasets.rst | 64 +++ scripts/dataloading/build_indexes.py | 377 +++++++++++++ 5 files changed, 1229 insertions(+), 21 deletions(-) create mode 100644 scripts/dataloading/build_indexes.py diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index 09ff87ea180c..620194c53727 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -3,6 +3,12 @@ Datasets NeMo ASR models expect data as a set of audio files plus a manifest file describing each utterance. +.. seealso:: + + For Lhotse-based dataloading (the recommended path for new ASR + recipes — dynamic bucketing, multi-source mixing, indexed/resumable + dataloading), see :doc:`/dataloaders`. + .. _section-with-manifest-format-explanation: Manifest Format diff --git a/docs/source/audio/datasets.rst b/docs/source/audio/datasets.rst index 4c023961a29e..781b0a9e99d8 100644 --- a/docs/source/audio/datasets.rst +++ b/docs/source/audio/datasets.rst @@ -3,6 +3,12 @@ Datasets The `audio` collection expect the training, validation and tests datasets in either NeMo format or Lhotse format. +.. 
seealso:: + + For the Lhotse dataloader's full surface — supported ``input_cfg`` + types, bucketing, indexed manifests + resumable dataloading, and the + ``LhotseDataLoadingConfig`` field reference — see :doc:`/dataloaders`. + NeMo Format ----------- diff --git a/docs/source/dataloaders.rst b/docs/source/dataloaders.rst index 8a7ed848b8a8..7ef1ffbc761d 100644 --- a/docs/source/dataloaders.rst +++ b/docs/source/dataloaders.rst @@ -24,26 +24,6 @@ NeMo supports using `Lhotse`_, a speech data handling library, as a dataloading constant in time (i.e., stationary); in fact, each mini-batch will have roughly the same ratio of data coming from each source. Since the multiplexing is done dynamically, it is very easy to tune the sampling weights. -Lhotse dataloading supports the following types of inputs: - -* NeMo manifests - Regular NeMo JSON manifests. -* NeMo tarred data - Tarred NeMo JSON manifests + audio tar files; we also support combination of multiple NeMo - tarred data sources (e.g., multiple buckets of NeMo data or multiple datasets) via dynamic multiplexing. - - We support using a subset of Tarred NeMo JSON manifests along with audio tar files without disrupting the alignment between the tarred files and their corresponding manifests. - This feature is essential because large datasets often consist of numerous tar files and multiple versions of Tarred NeMo JSON manifest subsets, which may contain only a portion of the audio files due to filtering for various reasons. - To skip specific entries in the manifests without repeatedly copying and retarring audio files, the entries must include a ``_skipme`` key. This key should be set to ``True``, ``1``, or a reason for skipping (e.g., ``low character-rate``). - -* Lhotse CutSet manifests - Regular Lhotse CutSet manifests (typically gzipped JSONL). - See `Lhotse Cuts documentation`_ to learn more about Lhotse data formats. 
-* Lhotse Shar data - Lhotse Shar is a data format that also uses tar files for sequential data loading, - but is designed to be modular (i.e., easily extensible with new data sources and with new feature fields). - More details can be found here: |tutorial_shar| - .. caution:: As of now, Lhotse is mainly supported in most ASR model configurations. We aim to gradually extend this support to other speech tasks. .. _Lhotse: https://github.com/lhotse-speech/lhotse @@ -51,6 +31,269 @@ Lhotse dataloading supports the following types of inputs: .. |tutorial_shar| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb +Architecture overview +--------------------- + +The Lhotse dataloader is a pipeline of small components. Each YAML option you +set lands in exactly one of them, so it pays to know which is which:: + + input_cfg entry ──► parser_fn ──► Adapter (IteratorNode) + (registered │ + via @data_type_parser) ▼ + CutSet (lazy iterator graph) + │ + SamplingConstraint ──► CutSampler + │ + ▼ + IterableDatasetWrapper + │ + ▼ + user-defined Dataset + │ + ▼ + DataLoader + (or StatefulDataLoader) + +Components, top to bottom: + +* **input_cfg entry** — one YAML dict identified by ``type:`` (e.g. + ``type: nemo_tarred``). Listed below in :ref:`lhotse-format-reference`. +* **parser_fn** — registered with the ``@data_type_parser`` decorator in + ``nemo/collections/common/data/lhotse/cutset.py``. Reads the entry and + returns ``(CutSet, is_tarred)``. Users can add their own (see + :ref:`lhotse-extension-hooks`). +* **Adapter** — a class that knows how to iterate one specific on-disk + format (e.g. ``LazyNeMoTarredIterator``, ``LazyParquetIterator``, + ``NeMoMultimodalConversationJsonlAdapter``). 
All recent adapters are + Lhotse :class:`~lhotse.lazy.IteratorNode` subclasses and support + ``indexed=True`` for O(1) random access — see + :ref:`indexed-resumable-dataloading`. +* **CutSet** — Lhotse's lazy manifest wrapper. Composing multiple sources + produces a graph of iterator nodes (mux, mix, map, filter, …) underneath. +* **SamplingConstraint** — defines what "length" means for batch packing: + :class:`~lhotse.dataset.sampling.base.TimeConstraint` (audio duration, + default), :class:`~lhotse.dataset.sampling.base.TokenConstraint` (token + count, multimodal), ``MultimodalSamplingConstraint`` / + ``FixedBucketBatchSizeConstraint2D`` (NeMo extensions; see + :ref:`lhotse-sampling-constraints`). +* **CutSampler** — :class:`~lhotse.dataset.sampling.DynamicCutSampler` or + :class:`~lhotse.dataset.sampling.DynamicBucketingSampler`, picked + automatically based on ``use_bucketing``. +* **IterableDatasetWrapper** — Lhotse helper that turns the sampler-produced + ``CutSet`` mini-batches into a stream the PyTorch ``DataLoader`` can + consume. +* **Dataset class** — supplied by the model code; converts a ``CutSet`` + mini-batch into a ``dict[str, Tensor]``. The same dataset class can serve + multiple model architectures because all batching is upstream. + +.. _lhotse-format-reference: + +Supported input formats +----------------------- + +Every entry in ``input_cfg`` is identified by ``type:``. The table below is +the canonical list of every type the dataloader understands today, what it +returns, and the on-disk shape it expects. + +.. 
list-table:: + :header-rows: 1 + :widths: 18 32 14 8 8 10 10 + + * - ``type:`` + - Purpose + - Yields + - Audio + - Tarred + - Indexable + - Adapter / parser + * - ``nemo`` + - NeMo non-tarred JSON manifest (per-file audio) + - ``Cut`` + - yes + - no + - yes + - ``LazyNeMoIterator`` + * - ``nemo_tarred`` + - NeMo tarred manifest + audio tar shards + - ``Cut`` + - yes + - yes + - yes + - ``LazyNeMoTarredIterator`` + * - ``lhotse`` + - Plain Lhotse cuts JSONL + - ``Cut`` + - yes + - no + - yes + - lhotse ``LazyJsonlIterator`` / ``LazyIndexedManifestIterator`` + * - ``lhotse_shar`` + - Lhotse Shar (sharded archive directory) + - ``Cut`` + - yes + - yes + - yes + - lhotse ``LazySharIterator`` + * - ``parquet`` + - Parquet file with audio bytes column + - ``Cut`` + - yes + - no + - yes (row groups) + - ``LazyParquetIterator`` + * - ``txt`` + - One example per line, raw text + - ``TextExample`` + - no + - n/a + - no + - ``LhotseTextAdapter`` + * - ``txt_jsonl`` + - One JSON object per line; configurable text field + - ``TextExample`` + - no + - n/a + - yes + - ``LhotseTextJsonlAdapter`` + * - ``txt_pair`` + - Source + target text files for translation + - ``SourceTargetTextExample`` + - no + - n/a + - no + - ``LhotseTextPairAdapter`` + * - ``multimodal_conversation`` + - Multi-turn chat with mixed text/audio turns (JSONL) + - ``NeMoMultimodalConversation`` + - optional + - optional + - yes + - ``NeMoMultimodalConversationJsonlAdapter`` + * - ``share_gpt`` + - ShareGPT-format JSONL → conversation + - ``NeMoMultimodalConversation`` + - optional + - optional + - yes + - ``NeMoMultimodalConversationShareGPTJsonlAdapter`` + * - ``share_gpt_webdataset`` + - ShareGPT in WebDataset tar shards + - ``NeMoMultimodalConversation`` + - optional + - yes + - yes + - ``NeMoMultimodalConversationShareGPTWebdatasetAdapter`` + * - ``lhotse_as_conversation`` + - Read ASR data and emit it as ASR conversation + - ``NeMoMultimodalConversation`` + - yes + - inherits + - inherits + - transform 
on ``read_cutset_from_config`` + * - ``sqa_as_conversation`` + - Spoken-QA → 3-turn conversation (question / audio / answer) + - ``NeMoMultimodalConversation`` + - yes + - inherits + - inherits + - transform + * - ``s2s_as_conversation`` + - Duplex S2S → conversation + - ``NeMoMultimodalConversation`` + - yes + - inherits + - inherits + - transform + * - ``s2s_duplex_overlap_as_s2s_duplex`` + - Overlapping agent/user segments → unified S2S timeline + - ``Cut`` + - yes + - inherits + - inherits + - transform + * - ``s2s_duplex_reverse_role`` + - Swap user and agent in a duplex cut + - ``Cut`` + - yes + - inherits + - inherits + - transform + * - ``lhotse_magpietts_data_as_continuation`` + - MagpieTTS dataset → S2S duplex continuation + - ``Cut`` + - yes + - inherits + - inherits + - transform + * - ``nemo_tarred_to_duplex`` + - Single-supervision NeMo → duplex (user speech + agent silence) + - ``Cut`` + - yes + - yes + - inherits + - transform + * - ``multi_speaker_simulator`` + - Synthetic multi-speaker mixtures from a manifest + - ``Cut`` + - yes + - n/a + - no + - ``MultiSpeakerMixtureGenerator`` + * - ``group`` + - Wrap a list of entries with a shared ``weight`` and ``tags`` + - (nested) + - n/a + - n/a + - n/a + - n/a + +Notes: + +* "Inherits" means the type is a transform that wraps another underlying + source via ``read_cutset_from_config(config)``. Such entries accept the + underlying source's keys (e.g. ``cuts_path`` and ``manifest_filepath``) + *in addition to* their own. +* Tarred NeMo manifests support a ``_skipme`` key to omit specific manifest + rows without repacking tars (set to ``True``, ``1``, or a reason string). +* Lhotse Shar is documented in the upstream tutorial: |tutorial_shar|. + +Conversation / multimodal types — when to use which +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Six types yield ``NeMoMultimodalConversation`` from very different sources. +Pick by the shape of your input data: + +.. 
list-table:: + :header-rows: 1 + :widths: 35 25 40 + + * - Your data + - ``type:`` + - Notes + * - JSONL of multi-turn chats with mixed text/audio turns + - ``multimodal_conversation`` + - Native chat schema; audio turns reference paths or tar members + * - JSONL in ShareGPT chat schema + - ``share_gpt`` + - Adds ShareGPT-specific role/value parsing + * - ShareGPT data packed in WebDataset tar shards + - ``share_gpt_webdataset`` + - Same parsing as ``share_gpt``, reads tarred shards + * - ASR data in NeMo or Lhotse format + - ``lhotse_as_conversation`` + - Builds a 2-turn (instruction+audio / transcript) conversation per cut + * - Spoken-QA data with ``question`` / ``answer`` fields + - ``sqa_as_conversation`` + - Builds a 3-turn (question / audio / answer) conversation per cut + * - Duplex S2S data with user/agent supervisions + - ``s2s_as_conversation`` + - Maps duplex roles onto chat turns + +The last three (``*_as_conversation``) are *transforms*: they delegate to +``read_cutset_from_config(config)`` for the underlying audio source, so the +nested keys like ``manifest_filepath``, ``cuts_path``, or ``shar_path`` +belong on the same entry. + Enabling Lhotse via configuration ---------------------------------- @@ -128,6 +371,16 @@ Some other Lhotse related arguments we support: When ``batch_duration`` is not set, it acts as a static batch size. * ``seed`` sets a random seed for the shuffle buffer. +* ``indexed`` (default ``False``) opts the dataloader into Lhotse's indexed-manifest + path, giving every adapter O(1) random access and graph-token-based exact restore. + Requires ``.idx`` sidecars next to every JSONL/tar file. See + :ref:`indexed-resumable-dataloading` below. + +* ``use_stateful_dataloader`` (default ``False``) swaps PyTorch's + ``DataLoader`` for ``torchdata.stateful_dataloader.StatefulDataLoader`` so + that per-worker iterator state is captured in checkpoints and restored + exactly on resume. Pair with ``indexed: true`` for full O(1) restore. 
+ The full and always up-to-date list of supported options can be found in ``LhotseDataLoadingConfig`` class. .. _asr-dataset-config-format: @@ -147,6 +400,29 @@ The dataset class which converts these examples to tensors can partition the min different processing to each group. For example, you may want to construct different prompts for the model using metadata in ``tags``. +How ``tags`` is applied +^^^^^^^^^^^^^^^^^^^^^^^ + +Every key/value pair in ``tags`` becomes an attribute on every cut produced +by that entry. The dataloader walks the cuts via ``cuts.map(...)`` and runs:: + + for key, val in tags.items(): + setattr(cut, key, val) + +So in your dataset class you read them back as ordinary attributes:: + + def __getitem__(self, cuts): + for cut in cuts: + lang = cut.lang + task = cut.task + ctx = cut.context + ... + +Tags set on a ``group`` apply to every nested entry; tags set on an inner +entry override the outer ones for that source. Conflicts with built-in cut +fields (``id``, ``duration``, ``supervisions``, …) silently overwrite the +built-in — pick tag names that don't collide. + .. note:: When fine-tuning a model that was trained with ``input_cfg`` option, typically you'd only need to override the following options: ``input_cfg=null`` and ``manifest_filepath=path/to/manifest.json``. @@ -384,6 +660,12 @@ Python dataloader instantiation example:: tokenizer=my_tokenizer, ) +**Indexed mode for text/multimodal sources.** All of the parsers above +(``txt_jsonl``, ``nemo_sft_jsonl``, ``multimodal_conversation``, ``share_gpt``, +``share_gpt_webdataset``) accept ``indexed: true`` and integrate with +``StatefulDataLoader``-based exact resume. ``txt`` and ``txt_pair`` are +intentionally streaming-only. See :ref:`indexed-resumable-dataloading`. 
+ **Dataloading and bucketing of text and multimodal data.** When dataloading text or multimodal data, pay attention to the following config options (we provide example values for convenience): * ``use_multimodal_sampling: true`` tells Lhotse to switch from measuring audio duration to measuring token counts; required for text. @@ -419,6 +701,25 @@ To enable bucketing, set ``batch_size: null`` and use the following options: **Joint dataloading of text/audio/multimodal data.** The key strength of this approach is that we can easily combine audio datasets and text datasets, and benefit from every other technique we described in this doc, such as: dynamic data mixing, data weighting, dynamic bucketing, and so on. +Single-config vs. ``multi_config: true`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default the dataloader builds **one** ``CutSet`` and **one** sampler from +the top-level config. Setting ``multi_config: true`` switches to a +**multi-modality** layout where each named sub-block (typically ``audio:`` +and ``text:``) is parsed as its own dataloader config, with its own +sampling/bucketing options, and the per-modality samplers are fused at the +batch level. + +When ``multi_config: true`` is set: + +* Top-level keys (``num_workers``, ``shuffle``, ``seed``, ``sample_rate``, + …) apply globally and are inherited by every sub-block. +* Per-modality overrides — including the ``input_cfg`` itself — go inside + the named sub-block (``audio: ...`` / ``text: ...``). +* The per-modality samplers are combined into one stream by + ``sampler_fusion``. + This approach is described in the `EMMeTT`_ paper. There's also a notebook tutorial called Multimodal Lhotse Dataloading. We construct a separate sampler (with its own batching settings) for each modality, and specify how the samplers should be fused together via the option ``sampler_fusion``: @@ -481,6 +782,162 @@ Example. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so t .. 
caution:: We strongly recommend to use multiple shards for text files as well so that different nodes and dataloading workers are able to randomize the order of text iteration. Otherwise, multi-GPU training has a high risk of duplication of text examples. +.. _lhotse-sampling-constraints: + +Sampling constraints +-------------------- + +A :class:`~lhotse.dataset.sampling.base.SamplingConstraint` decides what +"length" means when the sampler packs a mini-batch. NeMo uses four: + +* :class:`~lhotse.dataset.sampling.base.TimeConstraint` — default. + Length = audio duration in seconds. Enforces ``max_duration`` / + ``batch_duration`` / ``quadratic_duration``. +* :class:`~lhotse.dataset.sampling.base.TokenConstraint` — activated by + ``use_multimodal_sampling: true`` for text-only flows. Length = token + count after applying the tokenizer (and optionally the prompt format). + Enforces ``max_tokens`` / ``batch_tokens`` / ``quadratic_factor``. +* ``MultimodalSamplingConstraint`` — Lhotse-style mixed-modality + packing. Activated by setting both ``use_multimodal_sampling: true`` + and a ``token_equivalent_duration`` so audio cuts are measured in + equivalent-token units alongside text. Enforces all of the above plus + ``min_tpt``/``max_tpt`` (token-per-token ratio filtering). +* ``FixedBucketBatchSizeConstraint2D`` — activated automatically when + ``bucket_duration_bins`` is given as a list of ``[duration, tokens]`` + pairs **and** ``bucket_batch_size`` is set. Each bucket gets its own + fixed batch size; this is the layout produced by + ``estimate_duration_bins_2d.py`` and the OOMptimizer. + +You usually don't pick a constraint by name — it's inferred from the +combination of YAML options. The names matter when you read NeMo's source, +extend the system with a custom constraint, or interpret error messages. + +.. 
_indexed-resumable-dataloading: + +Resumable / indexed dataloading +------------------------------- + +Setting ``indexed: true`` (per-source or top-level) plus +``use_stateful_dataloader: true`` (top-level) opts NeMo's Lhotse dataloader +into Lhotse's indexed iterator graph and torchdata's +``StatefulDataLoader``. The combination gives you: + +* O(1) checkpoint/restore of the *whole* dataloading pipeline — sampler RNG, + bucketer state, multiplexer choice RNG, per-source iterator cursors, and + per-worker prefetch queues — without any replay from the start of the epoch. +* Random access (``__getitem__``) over every supported adapter. + +When set at the top level, ``indexed: true`` is propagated by +``read_dataset_config`` through the ``propagate_attrs`` cascade, so a single +top-level flag covers every nested ``input_cfg`` group. You can still override +it per-source if needed. + +Per-adapter support +^^^^^^^^^^^^^^^^^^^ + +The following ``input_cfg`` types accept ``indexed: true`` today and require an +``.idx`` sidecar next to each data file: + +* ``nemo`` / ``nemo_tarred`` — JSONL manifest gets ``manifest.json.idx``; + every audio tar in ``tarred_audio_filepaths`` gets ``shard.tar.idx``. +* ``lhotse`` (plain) — ``cuts.jsonl`` gets ``cuts.jsonl.idx``. +* ``lhotse_shar`` — every uncompressed ``cuts.NNNNNN.jsonl`` and field tar + inside the Shar dir. +* ``parquet`` — no sidecar required, but the file must expose row-group + statistics (the default for files written by pyarrow / pandas). +* ``txt_jsonl`` — every file in ``paths``. +* ``multimodal_conversation`` and ``share_gpt`` — JSONL manifest plus optional + audio tars in ``tarred_audio_filepaths``. +* ``share_gpt_webdataset`` — every ``shard-*.tar`` inside ``data_dir``. + +``txt`` and ``txt_pair`` remain streaming-only (no random-access support). 
+ +Two caveats to be aware of: + +* ``indexed: true`` is incompatible with ``extra_fields`` and ``slice_length`` + on ``nemo``/``nemo_tarred``: those features mutate or expand cuts in a way + that has no stable index. Pre-process the manifest offline if you need them + in an indexed pipeline. +* Only **uncompressed** files can be indexed (no ``.jsonl.gz``, + ``.tar.gz``, etc.) and only files on a backend that supports indexed reads + (local FS, S3-compatible object stores, AIStore). + +Building ``.idx`` sidecars +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Two equivalent ways: + +1. Lhotse's CLI per file:: + + lhotse index jsonl path/to/cuts.jsonl + lhotse index tar path/to/shard.tar + lhotse index shar path/to/shar_dir/ + +2. NeMo's batch helper that takes a config and indexes everything it + references in one shot:: + + python scripts/dataloading/build_indexes.py path/to/input_cfg.yaml + + The script walks ``input_cfg`` (including nested ``group`` entries and + per-entry YAML references), dispatches the right tar layout for each + adapter (NeMo one-member-per-sample vs. WebDataset/Shar pair format), and + skips files that already have an up-to-date ``.idx``. Use ``--force`` to + rebuild, ``--workers N`` for parallelism, ``--dry-run`` to preview. + +End-to-end YAML example +^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: yaml + + model: + train_ds: + # Top-level switches enable indexed restore for every source below. + indexed: true + use_stateful_dataloader: true + force_finite: true + force_map_dataset: true + + sample_rate: 16000 + num_workers: 4 + seed: 42 + shard_seed: randomized + + # Bucketing and the rest of the dataloader knobs work exactly as before. 
+ use_bucketing: true + num_buckets: 30 + batch_duration: 1100 + quadratic_duration: 30 + + input_cfg: + - type: nemo_tarred + manifest_filepath: /data/asr/manifest__OP_0..127_CL_.jsonl + tarred_audio_filepaths: /data/asr/audio__OP_0..127_CL_.tar + weight: 0.7 + - type: lhotse + cuts_path: /data/extra/cuts.jsonl + weight: 0.3 + +Resume contract +^^^^^^^^^^^^^^^ + +When ``use_stateful_dataloader: true`` is set, Lightning's checkpoint will +contain the full lhotse iterator graph state under the dataloader key. On +resume: + +* iterator positions advance to where they were at save time (no replay from + position 0); +* ``set_epoch`` is a no-op while restored state is pending, so the resumed run + continues the same epoch instead of starting a new one; +* ``num_workers`` and ``world_size`` must match between save and restore (a + hard requirement of ``StatefulDataLoader``). + +Non-indexed pipelines fall back to Lhotse's ``_fast_forward()`` replay (O(N) +in batches consumed before the checkpoint); this replay-based restore is not +exact and only requires ``num_workers`` to stay consistent across save/restore. + +For the iterator graph contract itself, see Lhotse's +`indexed manifests guide `_. + Pre-computing bucket duration bins ------------------------------------ @@ -594,7 +1051,7 @@ For Canary-1B, we'll also provide the special tokens tokenizer. Example: input_cfg.yaml Pushing GPU utilization to the limits with bucketing and OOMptimizer -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The default approach of specifying a ``batch_duration``, ``bucket_duration_bins`` and ``quadratic_duration`` is quite flexible, but is not maximally efficient. We observed that in practice it often leads to under-utilization @@ -685,3 +1142,301 @@ Other, more exotic configurations: * With ``seed="trng"``, the base random seed itself will be drawn using a TRNG. It will be different on each GPU training process. 
This setting is not recommended. * With ``seed="randomized"``, the base random seed is set to Python's global RNG seed. It might be different on each GPU training process. This setting is not recommended. + +Train vs. validation / test configs +----------------------------------- + +The training and validation/test sections of a NeMo recipe use the same +underlying dataloader builder but have a different shape and a different +default behavior. + +**Training (``train_ds``).** A single config that produces one infinite +``CutSet``. The dataloader is wrapped to never run out of data, so +``trainer.max_steps`` (and ``limit_train_batches`` for tarred sources) +controls the run length: + +.. code-block:: yaml + + model: + train_ds: + sample_rate: 16000 + num_workers: 4 + shuffle: true + use_bucketing: true + num_buckets: 30 + batch_duration: 1100 + input_cfg: + - type: nemo_tarred + manifest_filepath: /data/asr/manifest__OP_0..127_CL_.json + tarred_audio_filepaths: /data/asr/audio__OP_0..127_CL_.tar + +**Validation / test (``validation_ds`` / ``test_ds``).** A *named* dict of +configs — one per evaluation set — that produces finite iteration: + +.. code-block:: yaml + + model: + validation_ds: + sample_rate: 16000 + batch_size: 16 + # Per-set entries; keys become the metric prefixes in logging. + datasets: + dev_clean: + cuts_path: /data/dev-clean/cuts.jsonl + dev_other: + cuts_path: /data/dev-other/cuts.jsonl + +The most common eval-side overrides: + +* ``shuffle: false`` — deterministic order. +* ``force_finite: true`` — break out of the infinite-mux that's safe for + training but would loop forever in eval. +* ``use_bucketing: false`` — bucketing trades padding for randomness; on a + small eval set the savings are negligible and a fixed batch size makes + results easier to interpret. +* ``num_workers: 0`` (or a small number) — eval is short, the worker + startup cost matters more. 
+ +When the model code expects a single eval set, use the plain ``cuts_path`` / +``manifest_filepath`` form at the same level as ``train_ds`` instead of the +``datasets:`` dict. + +Preparing your data +------------------- + +Three minimal recipes covering the main on-disk formats. + +**NeMo manifest** — one JSON object per line, fields read by ``LazyNeMoIterator``:: + + {"audio_filepath": "/data/utt_0001.wav", "duration": 3.42, "text": "hello world", "lang": "en"} + {"audio_filepath": "/data/utt_0002.wav", "duration": 5.10, "text": "another example", "lang": "en"} + +For tarred NeMo manifests, see +``scripts/speech_recognition/convert_to_tarred_audio_dataset.py`` in the NeMo +repo. + +**Lhotse cuts JSONL** — build a ``CutSet`` from raw recordings + supervisions: + +.. code-block:: python + + from lhotse import CutSet, Recording, SupervisionSegment + + cuts = [] + for path, transcript in pairs: + rec = Recording.from_file(path) + sup = SupervisionSegment( + id=rec.id, recording_id=rec.id, + start=0.0, duration=rec.duration, + text=transcript, language="en", + ) + cut = rec.to_cut() + cut.supervisions = [sup] + cuts.append(cut) + + CutSet.from_cuts(cuts).to_file("cuts.jsonl") # uncompressed! + +For Lhotse Shar (sharded archive), see the upstream tutorial: |tutorial_shar|. + +**Parquet** — write a ``pyarrow`` table with the column names the +``LazyParquetIterator`` reads (``audio``, ``text``, ``duration``, +optional ``lang``): + +.. code-block:: python + + import pyarrow as pa, pyarrow.parquet as pq + + table = pa.table({ + "audio": [open(p, "rb").read() for p in paths], + "text": transcripts, + "duration": durations, + "lang": ["en"] * len(paths), + }) + pq.write_table(table, "shard_000.parquet") # row-group stats kept by default + +Once your manifests are written, build the indexed sidecars in one shot:: + + python scripts/dataloading/build_indexes.py path/to/input_cfg.yaml + +See :ref:`indexed-resumable-dataloading` for the resumable side. + +.. 
_lhotse-storage-backends: + +Storage backends: local, object store, AIStore +---------------------------------------------- + +Every input path the dataloader reads goes through Lhotse's ``open_best``, +which routes file paths and URIs to the right backend automatically: + +* **Local files** — paths like ``/data/...`` work out of the box, no + configuration needed. +* **Generic object stores via ``smart_open``** — ``s3://``, ``gs://``, + ``http://``, ``https://`` URIs work after ``pip install smart_open``. + Authentication uses the underlying SDK's defaults (e.g. AWS env vars). +* **AIStore** — ``ais://bucket/key`` URIs work after ``pip install aistore`` + and ``export AIS_ENDPOINT=http://...``. Optional tuning env vars + ``AIS_CONNECT_TIMEOUT`` and ``AIS_READ_TIMEOUT`` are honored by the SDK. + +The same routing applies to ``.idx`` sidecars: they are read and written +next to the data file, so the backend must accept writes at that location +or the indexes need to be pre-built locally and uploaded. + +AIStore GetBatch (separate optimization) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For tarred multimodal-conversation manifests, NeMo also supports AIStore's +batched object-fetch API (``GetBatch``) via ``USE_AIS_GET_BATCH=true``, +which issues one batched fetch per minibatch instead of per-cut tar reads. +This is independent of using AIStore as a generic backend — see +:doc:`speechlm2/datasets` for the speech-LM-specific details, including +how it composes with ``indexed: true``. + +.. _lhotse-extension-hooks: + +Registering a custom format +--------------------------- + +Adding a new ``type:`` to the ``input_cfg`` registry is one decorator and +one function: + +.. 
code-block:: python + + from nemo.collections.common.data.lhotse.cutset import data_type_parser + from lhotse import CutSet + + @data_type_parser("my_format") + def read_my_format(config) -> tuple[CutSet, bool]: + cuts = CutSet(MyAdapter(path=config.path, ...)) + is_tarred = True # True ⇒ IterableDataset path; False ⇒ map-style + return cuts, is_tarred + +The parser must accept arbitrary keys: ``read_dataset_config`` cascades +options like ``indexed``, ``shard_seed``, ``metadata_only``, +``force_finite``, ``audio_locator_tag`` from the top of the YAML down into +every entry via ``propagate_attrs``. Missing keys should fall back to +sensible defaults via ``config.get(...)``. + +To make ``MyAdapter`` participate in the indexed/resumable path +(:ref:`indexed-resumable-dataloading`), implement Lhotse's +:class:`~lhotse.lazy.IteratorNode` contract — see +`indexed manifests guide `_ +for the requirements. + +Common pitfalls +--------------- + +The most common foot-guns when standing up a NeMo Lhotse recipe: + +1. **Forgetting** ``trainer.use_distributed_sampler=false``. NeMo's Lhotse + integration handles distributed sampling itself; leaving Lightning's + default on causes silent batch duplication across DP ranks. + +2. **No** ``max_steps`` **with tarred / Shar data.** Tarred sources are + infinite by design, so without ``trainer.max_steps`` (and + ``limit_train_batches`` for the periodic validation cadence) training + never completes the first "epoch". Always set both. + +3. **Compressed inputs cannot be indexed.** ``.jsonl.gz`` and ``.tar.gz`` + work for streaming, but ``indexed: true`` requires uncompressed, + seekable files. Re-extract or re-write before building ``.idx``. + +4. **Mismatched** ``num_workers`` / ``world_size`` **on resume.** Exact + per-worker resume with ``StatefulDataLoader`` requires both to match + between save and restore. Replay-based restore with the regular + ``DataLoader`` is more lenient. + +5. 
``indexed: true`` **is incompatible with** ``extra_fields`` **and** + ``slice_length`` on ``nemo`` / ``nemo_tarred``. Both expand or rewrite + cuts in a way that has no stable index. Pre-process the manifest + offline if you need them in an indexed pipeline. + +6. ``shard_seed: "trng"`` **deadlocks under TP/PP.** Tensor- and pipeline- + parallel ranks must see the same shard order, but ``"trng"`` draws an + independent seed per worker. Use ``shard_seed: "randomized"`` whenever + you have model parallelism on top of DDP. + +7. **Missing** ``force_finite: true`` **on validation.** Validation configs + that reuse training infrastructure inherit the infinite-mux behavior; + without ``force_finite: true`` the validation loop never terminates. + +.. _lhotse-config-reference: + +``LhotseDataLoadingConfig`` field reference +------------------------------------------- + +The complete option schema lives in ``LhotseDataLoadingConfig`` +(``nemo/collections/common/data/lhotse/dataloader.py``). It carries ~80 +fields; the categorization below mirrors the source order and groups +options by what they control. + +**Inputs.** ``input_cfg``, ``manifest_filepath``, +``tarred_audio_filepaths``, ``cuts_path``, ``shar_path``, +``skip_missing_manifest_entries``. + +**Sampling — basic.** ``batch_size``, ``batch_duration``, +``quadratic_duration``, ``min_duration``, ``max_duration``, ``min_tps``, +``max_tps``. + +**Sampling — bucketing.** ``use_bucketing``, ``num_buckets``, +``bucket_duration_bins``, ``bucket_batch_size``, ``bucket_buffer_size``, +``num_cuts_for_bins_estimate``, ``concurrent_bucketing``. + +**Sampling — multimodal.** ``use_multimodal_sampling``, ``prompt_format``, +``pretokenize``, ``audio_locator_tag``, ``token_equivalent_duration``, +``batch_tokens``, ``quadratic_factor``, ``min_tokens``, ``max_tokens``, +``min_tpt``, ``max_tpt``, ``measure_total_length``. + +**Sampling — fusion (multi-config).** ``multi_config``, ``sampler_fusion``, +``sampler_weights``. 
+ +**Indexed / resumable.** ``indexed``, ``use_stateful_dataloader``. See +:ref:`indexed-resumable-dataloading`. + +**Mixing & weighting.** ``reweight_temperature``, ``max_open_streams``. + +**I/O & distributed.** ``num_workers``, ``pin_memory``, ``shard_seed``, +``seed``, ``shuffle``, ``shuffle_buffer_size``, ``drop_last``, +``force_finite``, ``force_map_dataset``, ``force_iterable_dataset``, +``metadata_only``, ``cuda_expandable_segments``. + +**On-the-fly augmentation.** + +* Speed/RIR — ``perturb_speed``, ``rir_enabled``, ``rir_path``, ``rir_prob``. +* Noise — ``noise_path``, ``noise_snr``, ``noise_mix_prob``. +* Lowpass — ``lowpass_enabled``, ``lowpass_frequencies_interval``, + ``lowpass_prob``. +* Compression — ``compression_enabled``, ``compression_prob``, + ``compression_level_interval``, ``compression_codecs``, + ``compression_codec_weights``, ``compression_enable_for_custom_fields``. +* Clipping — ``clipping_enabled``, ``clipping_gain_db``, + ``clipping_normalize``, ``clipping_oversampling``, ``clipping_prob``, + ``clipping_prob_hard``. +* Concatenation — ``concatenate_samples``, ``concatenate_gap_seconds``, + ``concatenate_duration_factor``, ``concatenate_merge_supervisions``, + ``db_norm``. + +**Cut transforms.** ``truncate_duration``, ``truncate_offset_type``, +``cut_into_windows_duration``, ``cut_into_windows_hop``, +``pad_min_duration``, ``pad_direction``, ``cut_text_into_windows_tokens``, +``keep_excessive_supervisions``. + +**Field-name overrides.** ``text_field``, ``lang_field``, +``channel_selector``, ``sample_rate``. + +**Filtering.** ``max_cer``, ``min_context_speaker_similarity``, ``keep``. + +For exact types and defaults, see the dataclass definition in the source +file — it is the single source of truth. + +See also +-------- + +* :doc:`speechlm2/datasets` — speech-LM-specific data classes, AIStore + GetBatch with indexed mode, and the SpeechLM ``DataModule`` resume + contract. 
+* :doc:`asr/datasets` — ASR-specific data preparation conventions. +* :doc:`audio/datasets` — audio (codec, enhancement) data flows. +* `Lhotse PyTorch Datasets `_ + — upstream sampler API, ``StatefulDataLoader`` integration, custom RNG + state in batch transforms. +* `Lhotse indexed manifests `_ + — the iterator-graph contract that makes O(1) restore work. diff --git a/docs/source/speechlm2/datasets.rst b/docs/source/speechlm2/datasets.rst index 2006fbc59cc6..f408458b86c3 100644 --- a/docs/source/speechlm2/datasets.rst +++ b/docs/source/speechlm2/datasets.rst @@ -4,6 +4,16 @@ Datasets The speechlm2 collection supports datasets that contain both audio and text data for training models that can understand speech and generate appropriate responses. This section describes the dataset format, preparation, and usage with the speechlm2 models. +.. seealso:: + + :doc:`/dataloaders` is the canonical reference for the underlying Lhotse + dataloader: ``input_cfg`` shape, supported formats, sampling/bucketing + options, indexed manifests + resumable dataloading, and + ``LhotseDataLoadingConfig`` field schema. The page below covers what's + speech-LM-specific on top of that — datamodule resume contract, + AIStore GetBatch, conversation type semantics in the SALM/duplex + recipes. + Dataset Format -------------- @@ -228,6 +238,27 @@ When enabled: Leave the env var unset to keep the original tar-iterating loader. +Combining with ``indexed: true`` +"""""""""""""""""""""""""""""""" + +``USE_AIS_GET_BATCH=true`` coexists with ``indexed: true`` on +``LazyNeMoTarredIterator`` (and on the multimodal-conversation adapters). 
+Indexed mode keeps the JSONL-driven O(1) global indexing and graph-token +checkpointing, while AIStore GetBatch handles the actual audio fetch: + +* The audio-tar ``.idx`` sidecar is **not** required when GetBatch is enabled + — the iterator skips opening tar files entirely and emits URL-backed cuts + whose ``AudioSource`` points at ``{tar_path}/{audio_filename}`` + (``type="url"`` for ``ais://...`` paths, ``type="file"`` otherwise). +* Manifest JSONLs still need their ``.idx`` sidecars; they drive the indexed + iterator graph and the ``state_dict`` / ``load_state_dict`` round-trip. +* Audio bytes are fetched lazily by ``AudioSamples(use_batch_loader=True)`` at + collation time, which issues one batched GetBatch request per minibatch. + +Use this combination when shards live on AIStore and you want both the +network efficiency of GetBatch and the exact-resume guarantees of the +indexed/stateful pipeline. + DuplexSTTDataset **************** @@ -264,6 +295,39 @@ The DataModule takes care of: 1. Setting up proper data parallel ranks for dataloaders 2. Instantiating the dataloaders with configuration from YAML 3. Managing multiple datasets for validation/testing +4. Persisting the train dataloader's iterator state across checkpoints + (when ``use_stateful_dataloader: true``) + +Checkpointed / resumable training +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The DataModule caches the train dataloader on first ``train_dataloader()`` +call and exposes ``state_dict()`` / ``load_state_dict()`` that delegate to the +cached dataloader when it supports them. Lightning's trainer wires those into +every checkpoint automatically, so an experiment configured with:: + + data: + train_ds: + indexed: true + use_stateful_dataloader: true + ... + +resumes in O(1) time — sampler RNG, bucketer state, multiplexer choice RNG, +per-source iterator cursors, and per-worker prefetch queues are all restored +exactly without replay. 
+ +With a regular ``DataLoader`` (``use_stateful_dataloader`` unset or +``False``) ``state_dict``/``load_state_dict`` become no-ops and resume falls +back to Lhotse's ``_fast_forward()`` replay path. + +Two constraints to keep in mind across save/restore: + +* ``num_workers`` and ``world_size`` must match between save and restore + (a hard requirement of ``StatefulDataLoader``). +* All data files must be **uncompressed** and accompanied by ``.idx`` + sidecars. Build them in one shot with ``scripts/dataloading/build_indexes.py`` + (see :ref:`indexed-resumable-dataloading` in the main Lhotse dataloading + guide). Bucketing for Efficient Training ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/scripts/dataloading/build_indexes.py b/scripts/dataloading/build_indexes.py new file mode 100644 index 000000000000..9a39d38a0cda --- /dev/null +++ b/scripts/dataloading/build_indexes.py @@ -0,0 +1,377 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Build O(1)-restore index sidecars for an arbitrary NeMo Lhotse ``input_cfg``. + +Walks a NeMo dataloading config (``input_cfg`` YAML, including nested ``group`` +entries and per-entry YAML references), discovers every JSONL/tar file an +indexed dataloader will need, and creates the corresponding ``.idx`` sidecars +next to each data file. 
+ +Two tar layouts are dispatched correctly: + +* NeMo tarred audio (one regular member per sample, name-keyed) — uses + ``nemo.collections.common.data.lhotse.indexed_adapters.create_tar_index`` + which records one offset per *basename group*. +* WebDataset/Shar tars (json + payload pairs) — uses + ``lhotse.indexing.create_tar_index`` which records one offset per *member + pair*. + +Local files and remote URIs are both supported via lhotse's ``open_best`` +(which routes to ``smart_open`` / AIStore SDK when available). The ``.idx`` is +written next to its source path, so the storage backend must accept writes at +that location — for read-only object stores, materialize the data locally +first or pre-build indexes at upload time. + +Examples:: + + # Build indexes for everything referenced by an input_cfg.yaml. + python scripts/dataloading/build_indexes.py path/to/input_cfg.yaml + + # Multiple configs at once. + python scripts/dataloading/build_indexes.py train.yaml validation.yaml + + # Show what would be built without writing anything. + python scripts/dataloading/build_indexes.py --dry-run path/to/input_cfg.yaml + + # Rebuild even when an .idx already exists; parallelize across 16 workers. + python scripts/dataloading/build_indexes.py --force --workers 16 path/to/input_cfg.yaml +""" + +import logging +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Iterable, Iterator + +import click +from omegaconf import DictConfig, ListConfig, OmegaConf + +from nemo.collections.common.data.lhotse.indexed_adapters import ( + create_tar_index as create_nemo_tar_index, +) +from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths + + +# --------------------------------------------------------------------------- # +# Tar layout taxonomy. 
+# --------------------------------------------------------------------------- # +# NEMO_TAR — one regular member per sample, indexed by basename. Used by +# nemo / nemo_tarred / multimodal_conversation / share_gpt audio +# tars (read via IndexedTarMemberReader). +# WDS_TAR — WebDataset-style: each sample is a pair of consecutive members +# (e.g. {N}.json + {N}.