diff --git a/nemo_curator/stages/audio/__init__.py b/nemo_curator/stages/audio/__init__.py index cec7a119f6..0c01d7b9f9 100644 --- a/nemo_curator/stages/audio/__init__.py +++ b/nemo_curator/stages/audio/__init__.py @@ -45,6 +45,13 @@ SpeakerSeparationStage, VADSegmentationStage, ) +from nemo_curator.stages.audio.text_filtering import ( + FastTextLIDStage, + FinalizeFieldsStage, + InitializeFieldsStage, + RegexSubstitutionStage, + WhisperHallucinationStage, +) __all__ = [ "ALMDataBuilderStage", @@ -60,4 +67,9 @@ "TimestampMapperStage", "UTMOSFilterStage", "VADSegmentationStage", + "FastTextLIDStage", + "FinalizeFieldsStage", + "InitializeFieldsStage", + "RegexSubstitutionStage", + "WhisperHallucinationStage", ] diff --git a/nemo_curator/stages/audio/text_filtering/__init__.py b/nemo_curator/stages/audio/text_filtering/__init__.py new file mode 100644 index 0000000000..480c441bbc --- /dev/null +++ b/nemo_curator/stages/audio/text_filtering/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Text filtering stages for ASR postprocessing.""" + +from nemo_curator.stages.audio.text_filtering.fasttext_lid import FastTextLIDStage +from nemo_curator.stages.audio.text_filtering.finalize_fields import FinalizeFieldsStage +from nemo_curator.stages.audio.text_filtering.initialize_fields import InitializeFieldsStage +from nemo_curator.stages.audio.text_filtering.regex_substitution import RegexSubstitutionStage +from nemo_curator.stages.audio.text_filtering.whisper_hallucination import WhisperHallucinationStage + +__all__ = [ + "FastTextLIDStage", + "FinalizeFieldsStage", + "InitializeFieldsStage", + "RegexSubstitutionStage", + "WhisperHallucinationStage", +] diff --git a/nemo_curator/stages/audio/text_filtering/fasttext_lid.py b/nemo_curator/stages/audio/text_filtering/fasttext_lid.py new file mode 100644 index 0000000000..c5d38d3103 --- /dev/null +++ b/nemo_curator/stages/audio/text_filtering/fasttext_lid.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import ast
import os
import tempfile
import urllib.request
from dataclasses import dataclass, field
from typing import Any

from loguru import logger

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import AudioTask

# Download URLs for the model names accepted as a shorthand in ``model_path``.
_FASTTEXT_MODEL_URLS: dict[str, str] = {
    "lid.176.bin": "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
    "lid.176.ftz": "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz",
}
_DEFAULT_CACHE_DIR = os.path.expanduser("~/.cache/nemo_curator/fasttext")


@dataclass
class FastTextLIDStage(ProcessingStage[AudioTask, AudioTask]):
    """Language identification using FastText; flags non-target-language entries with skip_me=1.

    Wraps the existing ``FastTextLangId`` filter for model loading and scoring,
    adding AudioTask field access and optional model download by name.

    ``model_path`` can be:
    - A path to a local model file (checked with ``os.path.isfile``).
    - A known model name (``lid.176.bin`` or ``lid.176.ftz``), which is
      downloaded to ``~/.cache/nemo_curator/fasttext/`` on first use.

    An entry is flagged when the text is empty after stripping, when the
    predicted language differs from ``target_lang`` (case-insensitive), or when
    the predicted probability is below ``min_lang_prob``. An existing
    ``skip_me`` value is never reset to 0 by this stage.
    """

    model_path: str = ""
    target_lang: str = "en"
    min_lang_prob: float = 0.3
    text_key: str = "cleaned_text"
    skip_me_key: str = "skip_me"
    name: str = "FastTextLID"
    resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))

    # Loaded FastTextLangId instance; populated by setup().
    _lid: Any = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        """Reject construction without a model path.

        Raises:
            ValueError: if ``model_path`` is empty.
        """
        if not self.model_path:
            msg = "model_path is required for FastTextLIDStage"
            raise ValueError(msg)

    def _resolve_model_path(self) -> str:
        """Return a local file path for ``model_path``, downloading a known model if needed.

        Raises:
            ValueError: if ``model_path`` is neither an existing file nor a known model name.
        """
        if os.path.isfile(self.model_path):
            return self.model_path

        url = _FASTTEXT_MODEL_URLS.get(self.model_path)
        if url is None:
            msg = (
                f"model_path '{self.model_path}' is not a valid file path and not a known model name. "
                f"Known names: {list(_FASTTEXT_MODEL_URLS)}"
            )
            raise ValueError(msg)

        cache_path = os.path.join(_DEFAULT_CACHE_DIR, self.model_path)
        if os.path.isfile(cache_path):
            return cache_path

        os.makedirs(_DEFAULT_CACHE_DIR, exist_ok=True)
        logger.info(f"FastTextLIDStage: downloading {self.model_path} from {url}")
        # Download to a temporary file in the same directory and rename it into
        # place, so an interrupted download can never leave a truncated file at
        # cache_path that a later run would mistake for a valid cached model.
        fd, tmp_path = tempfile.mkstemp(dir=_DEFAULT_CACHE_DIR, suffix=".part")
        os.close(fd)
        try:
            urllib.request.urlretrieve(url, tmp_path)  # noqa: S310
            os.replace(tmp_path, cache_path)
        finally:
            # Still present only when the download or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        return cache_path

    def setup(self, worker_metadata: Any = None) -> None:
        """Resolve the model path and load the FastText model into ``_lid``."""
        from nemo_curator.stages.text.filters.fasttext.fasttext_filters import FastTextLangId

        resolved = self._resolve_model_path()
        # min_langid_score=0.0: thresholding is done in process() via min_lang_prob.
        self._lid = FastTextLangId(model_path=resolved, min_langid_score=0.0)
        self._lid.load_model()
        logger.info(f"FastTextLIDStage: loaded model from {resolved}")

    def inputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage reads."""
        return [], [self.text_key, self.skip_me_key]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage may write."""
        return [], [self.skip_me_key]

    def process(self, task: AudioTask) -> AudioTask:
        """Score the task's text and set ``skip_me`` to 1 when it fails the language check.

        Non-string text leaves the task untouched. Text that is empty after
        stripping is flagged without calling the model.
        """
        if self._lid is None:
            logger.warning(
                f"FastTextLIDStage ({self.name}): setup() was not called before process(). "
                "Calling setup() now — check that your executor invokes setup() on each worker."
            )
            self.setup()
        text = task.data[self.text_key]
        if not isinstance(text, str):
            return task
        text = text.strip().replace("\n", " ")
        if not text:
            task.data[self.skip_me_key] = 1
            return task
        result_str = self._lid.score_document(text)
        # The result is expected to be the string form of ``[probability, language_code]``.
        # literal_eval parses that literal without executing arbitrary code.
        score_list = ast.literal_eval(result_str)
        prob = float(score_list[0])
        lang = str(score_list[1]).lower()
        if lang != self.target_lang.lower() or prob < self.min_lang_prob:
            task.data[self.skip_me_key] = 1
        return task
from dataclasses import dataclass, field

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import AudioTask


@dataclass
class FinalizeFieldsStage(ProcessingStage[AudioTask, AudioTask]):
    """Produce the final manifest schema by renaming and dropping fields.

    Applied in this order:
    1. ``source_text_key`` (``text``) is moved to ``v1_text_key`` (``v1_text``).
    2. ``cleaned_text_key`` (``cleaned_text``) is moved to ``source_text_key`` (``text``).
    3. Every key in ``drop_keys`` is removed; keys that are absent are ignored.

    A rename is skipped when its source key is not present in the task data.
    """

    source_text_key: str = "text"
    v1_text_key: str = "v1_text"
    cleaned_text_key: str = "cleaned_text"
    drop_keys: list[str] = field(default_factory=lambda: ["pnc", "itn", "timestamp"])
    name: str = "FinalizeFields"
    resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))

    def inputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage reads."""
        required_data_keys = [self.source_text_key, self.cleaned_text_key]
        return [], required_data_keys

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage writes."""
        produced_data_keys = [self.v1_text_key, self.source_text_key]
        return [], produced_data_keys

    def process(self, task: AudioTask) -> AudioTask:
        """Rename the text fields in place, drop ``drop_keys``, and return the same task."""
        record = task.data
        # Order matters: the original text must move out of the way before the
        # cleaned text takes over its key.
        renames = (
            (self.source_text_key, self.v1_text_key),
            (self.cleaned_text_key, self.source_text_key),
        )
        for old_key, new_key in renames:
            if old_key in record:
                record[new_key] = record.pop(old_key)
        for unwanted in self.drop_keys:
            record.pop(unwanted, None)
        return task
from dataclasses import dataclass, field

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import AudioTask


@dataclass
class InitializeFieldsStage(ProcessingStage[AudioTask, AudioTask]):
    """Seed the working fields used by the text-filtering stages.

    Copies the value under ``pred_text_key`` to ``cleaned_text_key`` and sets
    ``skip_me_key`` to 0. The ``pred_text_key`` entry itself is not modified.
    Existing values under the two output keys are overwritten.
    """

    pred_text_key: str = "pred_text"
    cleaned_text_key: str = "cleaned_text"
    skip_me_key: str = "skip_me"
    name: str = "InitializeFields"
    resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))

    def inputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage reads."""
        required_data_keys = [self.pred_text_key]
        return [], required_data_keys

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage writes."""
        produced_data_keys = [self.cleaned_text_key, self.skip_me_key]
        return [], produced_data_keys

    def process(self, task: AudioTask) -> AudioTask:
        """Populate the cleaned-text and skip-flag fields and return the same task."""
        record = task.data
        record[self.cleaned_text_key] = record[self.pred_text_key]
        record[self.skip_me_key] = 0
        return task
import re
from dataclasses import dataclass, field
from typing import Any

import yaml
from loguru import logger

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import AudioTask


@dataclass
class RegexSubstitutionStage(ProcessingStage[AudioTask, AudioTask]):
    """Apply a sequence of regex substitutions to a text field in each AudioTask.

    Rules are loaded from a YAML file containing a list of dicts with
    ``pattern`` and ``repl`` keys (and an optional ``count`` key). An empty
    YAML file is treated as an empty rule list.
    After all substitutions, runs of whitespace are collapsed and the text is
    stripped; if the result is empty the entry is flagged with ``skip_me=1``.
    """

    regex_params_yaml: str = ""
    text_key: str = "cleaned_text"
    skip_me_key: str = "skip_me"
    name: str = "RegexSubstitution"
    resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))

    # Rules loaded by setup(), applied in file order.
    _rules: list[dict[str, Any]] = field(default_factory=list, init=False, repr=False)
    _setup_called: bool = field(default=False, init=False, repr=False)

    def __post_init__(self) -> None:
        """Reject construction without a rules file.

        Raises:
            ValueError: if ``regex_params_yaml`` is empty.
        """
        if not self.regex_params_yaml:
            msg = "regex_params_yaml is required for RegexSubstitutionStage"
            raise ValueError(msg)

    def setup(self, worker_metadata: Any = None) -> None:
        """Load and validate the substitution rules from ``regex_params_yaml``.

        Raises:
            ValueError: if the YAML document is not a list, or a rule is not a
                mapping containing both ``pattern`` and ``repl``.
        """
        with open(self.regex_params_yaml, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        # safe_load() yields None for an empty document; treat that as "no rules".
        rules = [] if loaded is None else loaded
        if not isinstance(rules, list):
            msg = (
                f"{self.regex_params_yaml} must contain a YAML list of rules, "
                f"got {type(rules).__name__}"
            )
            raise ValueError(msg)
        # Fail here with a clear message instead of a KeyError/TypeError on the first task.
        for index, rule in enumerate(rules):
            if not isinstance(rule, dict) or "pattern" not in rule or "repl" not in rule:
                msg = (
                    f"{self.regex_params_yaml}: rule #{index} must be a mapping with "
                    f"'pattern' and 'repl' keys, got {rule!r}"
                )
                raise ValueError(msg)
        self._rules = rules
        self._setup_called = True
        logger.info(f"RegexSubstitutionStage: loaded {len(self._rules)} rules from {self.regex_params_yaml}")

    def inputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage reads."""
        return [], [self.text_key]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage writes."""
        return [], [self.text_key, self.skip_me_key]

    def process(self, task: AudioTask) -> AudioTask:
        """Rewrite the task's text with every rule in order; flag it if nothing is left.

        Non-string text leaves the task untouched.
        """
        if not self._setup_called:
            logger.warning(
                f"RegexSubstitutionStage ({self.name}): setup() was not called before process(). "
                "Calling setup() now — check that your executor invokes setup() on each worker."
            )
            self.setup()
        text = task.data[self.text_key]
        if not isinstance(text, str):
            return task
        # Pad with spaces before applying rules — presumably so patterns anchored
        # on surrounding spaces also match at the start and end of the text.
        text = " " + text + " "
        for rule in self._rules:
            # A missing or null count means "replace all occurrences" (count=0).
            text = re.sub(rule["pattern"], rule["repl"], text, count=rule.get("count") or 0)
        text = re.sub(r"\s+", " ", text).strip()
        task.data[self.text_key] = text
        if not text:
            task.data[self.skip_me_key] = 1
        return task
from dataclasses import dataclass, field
from typing import Any, ClassVar

from loguru import logger

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import AudioTask


@dataclass
class WhisperHallucinationStage(ProcessingStage[AudioTask, AudioTask]):
    """Detect common Whisper hallucination patterns and flag entries with skip_me=1.

    Five checks are applied:
    - Repeated n-grams: low lexical diversity (unique-word ratio <= threshold).
    - Long word: an abnormally long word or a word much longer than its neighbours.
    - Frequent single phrase: the full transcript matches a known hallucination phrase.
    - Low char rate: word-chars / duration <= char_rate_threshold (sparse text over long audio).
    - High char rate: word-chars / duration > max_char_rate (impossible speech rate; short audio
      with dense confabulated text, e.g. Whisper generating a full sentence over 0.1 s).

    If any check triggers, ``skip_me`` is set to 1 (existing value of 1 is preserved).
    No intermediate flag fields are added to the task.
    """

    common_hall_file: str = ""
    unique_words_threshold: float = 0.4
    long_word_threshold: int = 25
    long_word_rel_threshold: float = 3.0
    char_rate_threshold: float = 4.0
    max_char_rate: float = 40.0
    duration_key: str = "duration"
    text_key: str = "cleaned_text"
    skip_me_key: str = "skip_me"
    name: str = "WhisperHallucination"
    resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))

    # Known hallucination phrases, one per non-blank line of common_hall_file.
    _phrases: set[str] = field(default_factory=set, init=False, repr=False)
    _setup_called: bool = field(default=False, init=False, repr=False)
    # Counters reported by teardown().
    _n_processed: int = field(default=0, init=False, repr=False)
    _n_flagged: int = field(default=0, init=False, repr=False)

    # Phrases shorter than this are matched exactly; longer ones also match as prefixes.
    # Declared as ClassVar so the dataclass machinery does not turn this constant
    # into a constructor parameter and per-instance field.
    _PREFIX_MATCH_MIN_LEN: ClassVar[int] = 8

    def __post_init__(self) -> None:
        """Reject construction without a phrase file.

        Raises:
            ValueError: if ``common_hall_file`` is empty.
        """
        if not self.common_hall_file:
            msg = "common_hall_file is required for WhisperHallucinationStage"
            raise ValueError(msg)

    def setup(self, worker_metadata: Any = None) -> None:
        """Load the known hallucination phrases (stripped, blank lines skipped)."""
        with open(self.common_hall_file, encoding="utf-8") as f:
            phrases = {line.strip() for line in f if line.strip()}
        self._phrases = phrases
        self._setup_called = True
        logger.info(f"WhisperHallucinationStage: loaded {len(phrases)} phrases from {self.common_hall_file}")

    def inputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage reads."""
        return [], [self.text_key, self.skip_me_key, self.duration_key]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare the data keys this stage may write."""
        return [], [self.skip_me_key]

    def _repeated_ngrams(self, words: list[str]) -> bool:
        """Return True when the unique-word ratio is at or below ``unique_words_threshold``."""
        if not words:
            return False
        return len(set(words)) / len(words) <= self.unique_words_threshold

    def _long_word(self, words: list[str]) -> bool:
        """Return True when the longest word is absolutely or relatively too long.

        Absolute: its length is >= ``long_word_threshold``.
        Relative: it exceeds the second-longest word by a factor of at least
        ``long_word_rel_threshold`` (measured as ``(longest - second) / second``).
        """
        if not words:
            return False
        lengths = sorted(len(w) for w in words)
        if lengths[-1] >= self.long_word_threshold:
            return True
        if len(lengths) > 1 and lengths[-2] > 0:
            return (lengths[-1] - lengths[-2]) / lengths[-2] >= self.long_word_rel_threshold
        return False

    def _frequent_single_word(self, text: str) -> bool:
        """Return True when the text, minus ``.``/``?``/``!``, matches a known phrase.

        Any phrase matches exactly; phrases of at least ``_PREFIX_MATCH_MIN_LEN``
        characters also match as a prefix of the text.
        """
        cleaned = text.strip().replace(".", "").replace("?", "").replace("!", "")
        if cleaned in self._phrases:
            return True
        return any(
            len(phrase) >= self._PREFIX_MATCH_MIN_LEN and cleaned.startswith(phrase)
            for phrase in self._phrases
        )

    def _low_char_rate(self, words: list[str], duration: float) -> bool:
        """Return True when characters-per-second is at or below ``char_rate_threshold``.

        Always False when ``duration`` is not positive.
        """
        if duration <= 0:
            return False
        chars = sum(len(w) for w in words)
        return chars / duration <= self.char_rate_threshold

    def _high_char_rate(self, words: list[str], duration: float) -> bool:
        """Return True when characters-per-second exceeds ``max_char_rate``.

        Always False when ``duration`` is not positive.
        """
        if duration <= 0:
            return False
        chars = sum(len(w) for w in words)
        return chars / duration > self.max_char_rate

    def process(self, task: AudioTask) -> AudioTask:
        """Run all checks on the task's text and set ``skip_me`` to 1 if any triggers.

        Non-string text leaves the task untouched and is not counted.
        """
        if not self._setup_called:
            logger.warning(
                f"WhisperHallucinationStage ({self.name}): setup() was not called before process(). "
                "Calling setup() now — check that your executor invokes setup() on each worker."
            )
            self.setup()
        text = task.data[self.text_key]
        if not isinstance(text, str):
            return task
        words = text.split()
        # A missing or falsy duration becomes 0.0, which disables both char-rate checks.
        duration = task.data.get(self.duration_key, 0.0) or 0.0

        repeated = self._repeated_ngrams(words)
        long_w = self._long_word(words)
        phrase = self._frequent_single_word(text)
        low_rate = self._low_char_rate(words, duration)
        high_rate = self._high_char_rate(words, duration)

        self._n_processed += 1
        if repeated or long_w or phrase or low_rate or high_rate:
            self._n_flagged += 1
            reasons = [
                name
                for name, hit in [
                    ("repeated_ngrams", repeated),
                    ("long_word", long_w),
                    ("phrase_match", phrase),
                    ("low_char_rate", low_rate),
                    ("high_char_rate", high_rate),
                ]
                if hit
            ]
            logger.debug(
                f"[{self.name}] flagged ({','.join(reasons)}) dur={duration:.2f}s: {text[:80]!r}"
            )
            task.data[self.skip_me_key] = 1
        return task

    def teardown(self) -> None:
        """Log how many tasks this instance processed and flagged."""
        logger.info(
            f"[{self.name}] done — processed={self._n_processed}, flagged={self._n_flagged}"
        )
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from nemo_curator.stages.audio.text_filtering.fasttext_lid import FastTextLIDStage
from nemo_curator.tasks import AudioTask


def _make_stage(label: str, prob: float) -> FastTextLIDStage:
    """Create a stage with a mocked FastTextLangId that returns the given label/prob."""
    stage = FastTextLIDStage(model_path="/fake/model.bin")
    mock_lid = MagicMock()
    mock_lid.score_document.return_value = str([prob, label])
    stage._lid = mock_lid
    return stage


def test_correct_lang_and_high_prob_passes() -> None:
    stage = _make_stage("EN", 0.95)
    task = AudioTask(data={"cleaned_text": "hello world", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 0


def test_wrong_lang_sets_skip_me() -> None:
    stage = _make_stage("FR", 0.95)
    task = AudioTask(data={"cleaned_text": "bonjour monde", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 1


def test_low_prob_sets_skip_me() -> None:
    stage = _make_stage("EN", 0.1)
    task = AudioTask(data={"cleaned_text": "hello", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 1


def test_correct_lang_exactly_at_threshold_passes() -> None:
    stage = FastTextLIDStage(model_path="/fake/model.bin", min_lang_prob=0.3)
    mock_lid = MagicMock()
    mock_lid.score_document.return_value = str([0.3, "EN"])
    stage._lid = mock_lid
    task = AudioTask(data={"cleaned_text": "hello", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 0


def test_empty_text_sets_skip_me_without_calling_model() -> None:
    stage = FastTextLIDStage(model_path="/fake/model.bin")
    mock_lid = MagicMock()
    stage._lid = mock_lid
    # Whitespace-only text is empty after stripping.
    task = AudioTask(data={"cleaned_text": " ", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 1
    mock_lid.score_document.assert_not_called()


def test_preserves_existing_skip_me_one() -> None:
    stage = _make_stage("EN", 0.95)
    task = AudioTask(data={"cleaned_text": "hello", "skip_me": 1})
    result = stage.process(task)
    assert result.data["skip_me"] == 1


def test_invalid_model_path_raises() -> None:
    stage = FastTextLIDStage(model_path="/does/not/exist.bin")
    with pytest.raises(ValueError, match="not a valid file path"):
        stage._resolve_model_path()


def test_non_string_text_returns_task_unchanged() -> None:
    stage = _make_stage("EN", 0.95)
    task = AudioTask(data={"cleaned_text": None, "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 0


def test_requires_model_path() -> None:
    with pytest.raises(ValueError, match="model_path is required"):
        FastTextLIDStage(model_path="")


def test_known_model_name_checks_cache(tmp_path: Path) -> None:
    stage = FastTextLIDStage(model_path="lid.176.ftz")

    def fake_download(url: str, path: str) -> None:
        # Stand-in for urlretrieve: create the destination file, no network access.
        Path(path).touch()

    # Redirect the cache directory to tmp_path and replace the real download.
    with (
        patch("nemo_curator.stages.audio.text_filtering.fasttext_lid._DEFAULT_CACHE_DIR", str(tmp_path)),
        patch("urllib.request.urlretrieve", side_effect=fake_download) as mock_dl,
    ):
        first = stage._resolve_model_path()
        # The second resolution must be served from the cache.
        second = stage._resolve_model_path()

    assert first.endswith("lid.176.ftz")
    assert Path(first).is_file()
    assert Path(first).parent == tmp_path
    assert second == first
    mock_dl.assert_called_once()
from nemo_curator.stages.audio.text_filtering.finalize_fields import FinalizeFieldsStage
from nemo_curator.tasks import AudioTask


def test_happy_path() -> None:
    """Both renames happen, default drop keys vanish, unrelated fields survive."""
    entry = {
        "text": "original text",
        "cleaned_text": "cleaned version",
        "pnc": "pnc",
        "itn": "noitn",
        "timestamp": "notimestamp",
        "audio_filepath": "/a.wav",
        "duration": 3.5,
    }
    out = FinalizeFieldsStage().process(AudioTask(data=entry)).data
    assert out["v1_text"] == "original text"
    assert out["text"] == "cleaned version"
    assert "cleaned_text" not in out
    assert "pnc" not in out
    assert "itn" not in out
    assert "timestamp" not in out
    assert out["audio_filepath"] == "/a.wav"
    assert out["duration"] == 3.5


def test_missing_source_text_key_is_ignored() -> None:
    """Without an original ``text`` field no ``v1_text`` is created."""
    out = FinalizeFieldsStage().process(AudioTask(data={"cleaned_text": "cleaned"})).data
    assert out["text"] == "cleaned"
    assert "v1_text" not in out


def test_missing_drop_keys_are_ignored() -> None:
    """Absent pnc/itn/timestamp keys must not raise."""
    out = FinalizeFieldsStage().process(AudioTask(data={"text": "t", "cleaned_text": "c"})).data
    assert out["text"] == "c"


def test_custom_drop_keys() -> None:
    """A caller-supplied ``drop_keys`` list is honoured."""
    finalizer = FinalizeFieldsStage(drop_keys=["custom_field", "another"])
    entry = {"text": "t", "cleaned_text": "c", "custom_field": "drop_me", "another": "also_drop"}
    out = finalizer.process(AudioTask(data=entry)).data
    assert "custom_field" not in out
    assert "another" not in out


def test_other_fields_preserved() -> None:
    """Fields that are neither renamed nor dropped pass through untouched."""
    entry = {
        "text": "t",
        "cleaned_text": "c",
        "pred_text": "raw",
        "skip_me": 0,
        "shard_id": 42,
    }
    out = FinalizeFieldsStage().process(AudioTask(data=entry)).data
    assert out["pred_text"] == "raw"
    assert out["skip_me"] == 0
    assert out["shard_id"] == 42


def test_cleaned_text_removed_after_rename() -> None:
    """The cleaned-text key is moved, not copied."""
    out = FinalizeFieldsStage().process(AudioTask(data={"text": "original", "cleaned_text": "clean"})).data
    assert "cleaned_text" not in out
    assert out["text"] == "clean"
import pytest

from nemo_curator.stages.audio.text_filtering.initialize_fields import InitializeFieldsStage
from nemo_curator.tasks import AudioTask


def test_happy_path() -> None:
    """pred_text is copied to cleaned_text and skip_me starts at 0."""
    out = InitializeFieldsStage().process(AudioTask(data={"pred_text": "hello world"})).data
    assert out["cleaned_text"] == "hello world"
    assert out["skip_me"] == 0


def test_original_pred_text_preserved() -> None:
    """The source field is copied, not moved."""
    out = InitializeFieldsStage().process(AudioTask(data={"pred_text": "original"})).data
    assert out["pred_text"] == "original"
    assert out["cleaned_text"] == "original"


def test_overwrites_existing_cleaned_text() -> None:
    """A stale cleaned_text value is replaced by pred_text."""
    entry = {"pred_text": "new", "cleaned_text": "old"}
    out = InitializeFieldsStage().process(AudioTask(data=entry)).data
    assert out["cleaned_text"] == "new"


def test_custom_keys() -> None:
    """All three key names are configurable."""
    initializer = InitializeFieldsStage(
        pred_text_key="asr_out", cleaned_text_key="norm_text", skip_me_key="drop"
    )
    out = initializer.process(AudioTask(data={"asr_out": "test text"})).data
    assert out["norm_text"] == "test text"
    assert out["drop"] == 0


def test_missing_pred_text_fails_validation() -> None:
    """validate_input rejects a task that lacks pred_text."""
    candidate = AudioTask(data={"text": "has text but not pred_text"})
    assert InitializeFieldsStage().validate_input(candidate) is False


def test_validate_input_passes_with_pred_text() -> None:
    """validate_input accepts a task that carries pred_text."""
    candidate = AudioTask(data={"pred_text": "something"})
    assert InitializeFieldsStage().validate_input(candidate) is True


def test_process_batch_raises_on_missing_pred_text() -> None:
    """process_batch surfaces the validation failure as a ValueError."""
    initializer = InitializeFieldsStage()
    bad_batch = [AudioTask(data={"text": "no pred_text here"})]
    with pytest.raises(ValueError, match="failed validation"):
        initializer.process_batch(bad_batch)
from pathlib import Path

import pytest
import yaml

from nemo_curator.stages.audio.text_filtering.regex_substitution import RegexSubstitutionStage
from nemo_curator.tasks import AudioTask


def _write_rules(tmp_path: Path, rules: list[dict]) -> str:
    """Dump ``rules`` to a YAML file under ``tmp_path`` and return its path."""
    p = tmp_path / "rules.yaml"
    p.write_text(yaml.dump(rules), encoding="utf-8")
    return str(p)


def _ready_stage(tmp_path: Path, rules: list[dict]) -> RegexSubstitutionStage:
    """Build a stage for ``rules`` with setup() already called."""
    stage = RegexSubstitutionStage(regex_params_yaml=_write_rules(tmp_path, rules))
    stage.setup()
    return stage


def test_applies_substitution(tmp_path: Path) -> None:
    stage = _ready_stage(tmp_path, [{"pattern": "\u2019", "repl": "'"}])
    task = AudioTask(data={"cleaned_text": "it\u2019s fine", "skip_me": 0})
    result = stage.process(task)
    assert "'" in result.data["cleaned_text"]
    assert result.data["skip_me"] == 0


def test_empty_text_after_rules_sets_skip_me(tmp_path: Path) -> None:
    stage = _ready_stage(tmp_path, [{"pattern": r"\w+", "repl": ""}])
    task = AudioTask(data={"cleaned_text": "hello", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 1


def test_whitespace_only_sets_skip_me(tmp_path: Path) -> None:
    stage = _ready_stage(tmp_path, [{"pattern": r"\S+", "repl": ""}])
    task = AudioTask(data={"cleaned_text": "hello world", "skip_me": 0})
    result = stage.process(task)
    assert result.data["skip_me"] == 1


def test_non_empty_text_preserves_skip_me_zero(tmp_path: Path) -> None:
    stage = _ready_stage(tmp_path, [{"pattern": r"bad", "repl": "good"}])
    task = AudioTask(data={"cleaned_text": "bad word", "skip_me": 0})
    result = stage.process(task)
    assert result.data["cleaned_text"] == "good word"
    assert result.data["skip_me"] == 0


def test_strips_extra_whitespace(tmp_path: Path) -> None:
    stage = _ready_stage(tmp_path, [])
    # Leading, trailing and repeated inner whitespace so the collapse-and-strip
    # step actually has something to do.
    task = AudioTask(data={"cleaned_text": "  hello \t  world  ", "skip_me": 0})
    result = stage.process(task)
    assert result.data["cleaned_text"] == "hello world"


def test_multiple_rules_applied_in_order(tmp_path: Path) -> None:
    stage = _ready_stage(
        tmp_path,
        [
            {"pattern": "\u2014", "repl": "-"},  # em-dash → hyphen
            {"pattern": r"\s+", "repl": " "},  # collapse spaces (no-op after strip)
        ],
    )
    task = AudioTask(data={"cleaned_text": "word\u2014word", "skip_me": 0})
    result = stage.process(task)
    assert result.data["cleaned_text"] == "word-word"


def test_setup_called_lazily_when_skipped(tmp_path: Path) -> None:
    rules_path = _write_rules(tmp_path, [{"pattern": "\u2019", "repl": "'"}])
    stage = RegexSubstitutionStage(regex_params_yaml=rules_path)
    # Intentionally do NOT call stage.setup() — process() must call it lazily.
    task = AudioTask(data={"cleaned_text": "it\u2019s fine", "skip_me": 0})
    result = stage.process(task)
    assert result.data["cleaned_text"] == "it's fine"


def test_non_string_text_returns_task_unchanged(tmp_path: Path) -> None:
    stage = _ready_stage(tmp_path, [{"pattern": r"\w+", "repl": ""}])
    task = AudioTask(data={"cleaned_text": None, "skip_me": 0})
    result = stage.process(task)
    assert result.data["cleaned_text"] is None
    assert result.data["skip_me"] == 0


def test_requires_regex_params_yaml() -> None:
    with pytest.raises(ValueError, match="regex_params_yaml is required"):
        RegexSubstitutionStage(regex_params_yaml="")
+ +from pathlib import Path + +import pytest + +from nemo_curator.stages.audio.text_filtering.whisper_hallucination import WhisperHallucinationStage +from nemo_curator.tasks import AudioTask + + +def _make_stage(tmp_path: Path, phrases: list[str]) -> WhisperHallucinationStage: + p = tmp_path / "phrases.txt" + p.write_text("\n".join(phrases), encoding="utf-8") + stage = WhisperHallucinationStage(common_hall_file=str(p)) + stage.setup() + return stage + + +def test_clean_text_passes(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, ["lorem ipsum"]) + task = AudioTask(data={"cleaned_text": "the cat sat on the mat today", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 0 + + +def test_repeated_ngrams_sets_skip_me(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, []) + task = AudioTask(data={"cleaned_text": "yes yes yes yes yes yes", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_long_word_absolute_threshold_sets_skip_me(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, []) + long_word = "a" * 30 + task = AudioTask(data={"cleaned_text": f"the {long_word} here", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_long_word_relative_threshold_sets_skip_me(tmp_path: Path) -> None: + # "cat" (3) vs "verylongwordindeed" (18) — ratio (18-3)/3 = 5.0 >= 3.0 + stage = _make_stage(tmp_path, []) + task = AudioTask(data={"cleaned_text": "cat verylongwordindeed", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_frequent_phrase_sets_skip_me(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, ["Thank you 1297"]) + task = AudioTask(data={"cleaned_text": "Thank you", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_frequent_phrase_strips_punctuation(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, ["Thank you"]) + task = 
AudioTask(data={"cleaned_text": "Thank you.", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_frequent_phrase_strips_trailing_comma(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, ["Thank you"]) + task = AudioTask(data={"cleaned_text": "Thank you,", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_setup_called_lazily_when_skipped(tmp_path: Path) -> None: + p = tmp_path / "phrases.txt" + p.write_text("Thank you\n", encoding="utf-8") + stage = WhisperHallucinationStage(common_hall_file=str(p)) + # Intentionally do NOT call stage.setup() — process() must call it lazily. + task = AudioTask(data={"cleaned_text": "Thank you", "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_non_string_text_returns_task_unchanged(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, []) + task = AudioTask(data={"cleaned_text": None, "skip_me": 0}) + result = stage.process(task) + assert result.data["skip_me"] == 0 + + +def test_preserves_existing_skip_me_one(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, []) + task = AudioTask(data={"cleaned_text": "the cat sat on the mat", "skip_me": 1}) + result = stage.process(task) + assert result.data["skip_me"] == 1 + + +def test_empty_words_not_flagged_by_ngram(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, []) + assert stage._repeated_ngrams([]) is False + + +def test_empty_words_not_flagged_by_long_word(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, []) + assert stage._long_word([]) is False + + +def test_phrases_file_strips_frequency_count(tmp_path: Path) -> None: + stage = _make_stage(tmp_path, ["Thank you 1297", "Amen -1", "Yeah 217"]) + assert "Thank you" in stage._phrases + assert "Amen" in stage._phrases + assert "Yeah" in stage._phrases + + +def test_requires_common_hall_file() -> None: + with pytest.raises(ValueError, match="common_hall_file is 
required"): + WhisperHallucinationStage(common_hall_file="") diff --git a/tutorials/audio/granary_v2_postprocessing/README.md b/tutorials/audio/granary_v2_postprocessing/README.md new file mode 100644 index 0000000000..ae626fd832 --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/README.md @@ -0,0 +1,179 @@ +# Granary v2 ASR Postprocessing Pipeline + +Postprocessing pipeline for Granary v2 ASR manifests. Reads JSONL manifests produced by ASR inference, cleans and filters transcriptions, and writes output manifests with a `skip_me` flag marking low-quality entries. + +## Pipeline stages + +| # | Stage | What it does | +|---|---|---| +| 1 | `ALMManifestReader` | Reads JSONL — one `AudioTask` per line | +| 2 | `InitializeFieldsStage` | Copies `pred_text` → `cleaned_text`; sets `skip_me = 0` | +| 3 | `RegexSubstitutionStage` | Normalizes `cleaned_text` (quotes, dashes, brackets, whitespace) | +| 4 | `WhisperHallucinationStage` | Sets `skip_me = 1` for repeated n-grams, long words, known hallucination phrases, or abnormal char/duration rates | +| 5 | `FastTextLIDStage` | Sets `skip_me = 1` for non-English or low-confidence language ID | +| 6 | `FinalizeFieldsStage` | Renames `text` → `v1_text`, promotes `cleaned_text` → `text`, drops `pnc`/`itn`/`timestamp` | +| 7 | `ALMManifestWriterStage` | Writes **all** entries to output — both clean (`skip_me=0`) and flagged (`skip_me=1`) | + +All entries are written to the output. Use `skip_me` downstream to filter or inspect flagged entries. 
+ +## Output schema + +| Field | Description | +|---|---| +| `text` | Cleaned and normalized transcription | +| `v1_text` | Original reference text from the input manifest | +| `pred_text` | Raw ASR prediction (unchanged) | +| `skip_me` | `0` = clean, `1` = flagged (text empty after regex cleanup, hallucination, or LID filter) | +| `audio_filepath` | Path to audio file | +| `duration` | Audio duration in seconds | +| All other original fields | Preserved as-is (except `pnc`, `itn`, `timestamp` which are dropped) | + +## Bundled config files + +| File | Purpose | +|---|---| +| `common.yaml` | Regex substitution rules applied to `cleaned_text` | +| `en.txt` | Known Whisper hallucination phrases (one per line) | + +Both are used by default — no need to pass them as arguments. + +## Running on Slurm + +### Quick start + +```bash +bash tutorials/audio/granary_v2_postprocessing/submit.sh \ + /path/to/output_root \ + /path/to/input_dir +``` + +`submit.sh` finds every `*.jsonl` under `input_dir` recursively, groups them into chunks of `MANIFESTS_PER_JOB` (default 8), and submits one Slurm job per chunk. All jobs run in parallel. The output directory structure mirrors the input: + +``` +input: input_dir/ytc/en2/manifest_0.jsonl +output: output_root/ytc/en2/manifest_0.jsonl +``` + +### Tuning chunk size + +The default is 8 manifests per job. Override with the `MANIFESTS_PER_JOB` environment variable: + +```bash +# Fewer, heavier jobs (large manifests) +MANIFESTS_PER_JOB=256 bash submit.sh /path/to/output /path/to/input + +# More, lighter jobs (small manifests, want more parallelism) +MANIFESTS_PER_JOB=4 bash submit.sh /path/to/output /path/to/input +``` + +For 6552 manifests: the default `MANIFESTS_PER_JOB=8` → 819 jobs, `MANIFESTS_PER_JOB=128` → 52 jobs, `MANIFESTS_PER_JOB=32` → 205 jobs. + +### Resuming interrupted runs + +Just resubmit the same command. Any manifest whose output file already exists and is non-empty is skipped automatically. Partially written files (from preempted jobs) are ignored and reprocessed. 
+ +Check progress before resubmitting: + +```bash +INPUT=/path/to/input_dir +OUTPUT=/path/to/output_root + +TOTAL=$(find "$INPUT" -name "*.jsonl" | wc -l) +DONE=$(find "$OUTPUT" -name "*.jsonl" ! -name "*.tmp" | wc -l) +echo "Done: $DONE / $TOTAL (remaining: $((TOTAL - DONE)))" +``` + +### Sequential waves (dependent jobs) + +Pass multiple input directories — each wave starts after the previous one finishes: + +```bash +bash tutorials/audio/granary_v2_postprocessing/submit.sh \ + /path/to/output_root \ + /path/to/input_dir_batch_1 \ + /path/to/input_dir_batch_2 +``` + +Wave 2 waits for all Wave 1 jobs to finish (`afterany` dependency) before starting. + +### Single job (one directory) + +```bash +sbatch tutorials/audio/granary_v2_postprocessing/run.sh \ + /path/to/input_dir \ + /path/to/output_root +``` + +Processes all `*.jsonl` files under `input_dir` sequentially within one job. + +## Running locally / interactively + +```bash +export PYTHONPATH="/path/to/Curator:${PYTHONPATH:-}" + +python tutorials/audio/granary_v2_postprocessing/pipeline.py \ + --input_dir /path/to/input_dir \ + --output_dir /path/to/output_root \ + --fasttext_model /path/to/lid.176.ftz +``` + +To process specific manifests only: + +```bash +python tutorials/audio/granary_v2_postprocessing/pipeline.py \ + --input_dir /path/to/input_dir \ + --manifests /path/to/input_dir/corpus/manifest_0.jsonl \ + /path/to/input_dir/corpus/manifest_1.jsonl \ + --output_dir /path/to/output_root \ + --fasttext_model /path/to/lid.176.ftz +``` + +`--input_dir` is always the root used to compute relative output paths. All `--manifests` paths must be under `--input_dir`. 
+ +## All arguments + +| Argument | Default | Description | +|---|---|---| +| `--input_dir` | required | Root input directory; also used as the anchor for mirroring output paths | +| `--output_dir` | required | Root output directory | +| `--manifests` | — | Process specific manifests instead of scanning all of `input_dir` (one or more paths, all must be under `--input_dir`) | +| `--fasttext_model` | `lid.176.ftz` | FastText LID model path (`lid.176.bin` or `lid.176.ftz`) | +| `--regex_yaml` | `common.yaml` | Regex substitution rules YAML | +| `--hall_phrases` | `en.txt` | Hallucination phrases file (one phrase per line; a trailing frequency count such as `Thank you 1297` is stripped) | +| `--target_lang` | `en` | Expected language code for LID | +| `--min_lang_prob` | `0.3` | Minimum FastText confidence to keep an entry | +| `--unique_words_threshold` | `0.4` | Unique-word ratio below which repeated n-grams are flagged | +| `--long_word_threshold` | `25` | Character length above which a word is flagged as abnormally long | +| `--long_word_rel_threshold` | `3.0` | Longest/second-longest word ratio for long-word detection | +| `--char_rate_threshold` | `4.0` | chars/s below which text is considered too sparse (long silence + few words) | +| `--max_char_rate` | `40.0` | chars/s above which text is considered impossibly dense (hallucinated sentence over short audio) | +| `--verbose` | off | Enable DEBUG logging (shows per-entry flagging reasons) | + +## Hallucination detection details + +`WhisperHallucinationStage` applies six checks to `cleaned_text`: + +| Check | Triggers when | +|---|---| +| Repeated n-grams | Unique-word ratio ≤ `unique_words_threshold` | +| Long word (absolute) | Any word ≥ `long_word_threshold` characters | +| Long word (relative) | Longest word is ≥ `long_word_rel_threshold` × second-longest | +| Phrase match | Text matches or starts with a phrase from `en.txt` (prefix match for phrases ≥ 8 chars) | +| Low char rate | `sum(word lengths) / duration ≤ char_rate_threshold` | +| High char rate | `sum(word 
lengths) / duration > max_char_rate` | + +Add new hallucination phrases to `en.txt`, one per line. + +## Stage implementation + +The filtering stages live in `nemo_curator/stages/audio/text_filtering/` and can be used in any custom pipeline: + +```python +from nemo_curator.stages.audio.text_filtering import ( + InitializeFieldsStage, + RegexSubstitutionStage, + WhisperHallucinationStage, + FastTextLIDStage, + FinalizeFieldsStage, +) +``` diff --git a/tutorials/audio/granary_v2_postprocessing/common.yaml b/tutorials/audio/granary_v2_postprocessing/common.yaml new file mode 100644 index 0000000000..950d932fe0 --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/common.yaml @@ -0,0 +1,31 @@ +- {"pattern": "’", "repl": "'"} +- {"pattern": "‘", "repl": "'"} +- {"pattern": "—", "repl": "-"} +- {"pattern": "–", "repl": "-"} +- {"pattern": "-", "repl": "-"} +- {"pattern": "_", "repl": " "} +- {"pattern": "——", "repl": "-"} +- {"pattern": "Ё", "repl": "Е"} +- {"pattern": "ё", "repl": "е"} + +- {"pattern": "♫", "repl": " "} +- {"pattern": "♪", "repl": " "} +- {"pattern": "♬", "repl": " "} +- {"pattern": "♩", "repl": " "} +- {"pattern": "♭", "repl": " "} +- {"pattern": '\|', "repl": " "} # : -> : +- {"pattern": ";", "repl": ","} + +- {"pattern": '\[[^\]]*\]', "repl": ""} # delete content inside [] +- {"pattern": ' ?\([^\)]+\)', "repl": ""} # delete content inside () +- {"pattern": ' ?{[^}]+}', "repl": ""} # delete content inside {} + +- {"pattern": "[^ !$%',-.0123456789;?ABCDEFGHIJKLMNOPQRSßTUVWXYŸZabcdefghijklmnopqrsẞtuvwxyÿz¡£¿ÀÁÂÃÄÅÆÇÈÉÊÌÍÎÑÒÓÔÕÖØÙÚÜÝàáâãäåæçèéêëìíîïñòóôõöøùúûüýĀāĂ㥹ĆćĊċČčĎďĐđĒēĖėĘęĚěĠġĢģĦħĪīĮįĶķĹĺĻļĽľŁłŃńŅņŇňŐőŒœŔŕŘřŚśŠšŤťŪūŮůŰűŲųŹźŻżŽžȘșȚțΆΈΉΌΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩάέήίαβγδεζηθικλμνξοπρστυφχψωϊόύώЁЄІЇАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїҐґ€₴₽/:]", "repl": " "} + +# keep capital letters, lowercase letters, and spaces, ?, !, ., ,, and ' only +- {"pattern": '\s+\.', "repl": "."} +- {"pattern": '\?+', "repl": "?"} +- 
{"pattern": '\.+', "repl": "."} +- {"pattern": ',+', "repl": ","} +- {"pattern": '!+', "repl": "!"} +- {"pattern": '\s+', "repl": " "} diff --git a/tutorials/audio/granary_v2_postprocessing/en.txt b/tutorials/audio/granary_v2_postprocessing/en.txt new file mode 100644 index 0000000000..f86cae3faa --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/en.txt @@ -0,0 +1,10 @@ +Thank you +you +Yeah +Check +Sayemashka +I I think +methods +Fifty-four +Amen +Thank you very much diff --git a/tutorials/audio/granary_v2_postprocessing/pipeline.py b/tutorials/audio/granary_v2_postprocessing/pipeline.py new file mode 100644 index 0000000000..801db59b7a --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/pipeline.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Granary v2 ASR postprocessing pipeline. + +Recursively finds all *.jsonl manifests under an input directory, applies text +cleaning and filtering, and writes output manifests mirroring the same +subdirectory structure under output_dir. + +Pipeline stages (per manifest): + 1. ALMManifestReader — read JSONL manifest → one AudioTask per line + 2. InitializeFieldsStage — copy pred_text → cleaned_text; skip_me = 0 + 3. RegexSubstitutionStage — apply regex normalization rules to cleaned_text + 4. WhisperHallucinationStage — flag Whisper hallucination patterns (sets skip_me=1) + 5. 
FastTextLIDStage — flag non-English or low-confidence transcriptions (sets skip_me=1) + 6. FinalizeFieldsStage — text → v1_text; cleaned_text → text; drop pnc/itn/timestamp + 7. ALMManifestWriterStage — write all entries (including flagged) to mirrored output path + +Usage:: + + python tutorials/audio/granary_v2_postprocessing/pipeline.py \\ + --input_dir /path/to/results_dir \\ + --output_dir /path/to/output_root \\ + --fasttext_model lid.176.ftz +""" + +import argparse +import os +import sys +from pathlib import Path + +from loguru import logger + +from nemo_curator.backends.xenna import XennaExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader +from nemo_curator.stages.audio.alm.alm_manifest_writer import ALMManifestWriterStage +from nemo_curator.stages.audio.text_filtering import ( + FastTextLIDStage, + FinalizeFieldsStage, + InitializeFieldsStage, + RegexSubstitutionStage, + WhisperHallucinationStage, +) + +_TUTORIAL_DIR = Path(__file__).parent +_DEFAULT_REGEX_YAML = str(_TUTORIAL_DIR / "common.yaml") +_DEFAULT_HALL_PHRASES = str(_TUTORIAL_DIR / "en.txt") + + +def _find_manifests(input_dir: str) -> list[str]: + """Return all *.jsonl files found recursively under input_dir, sorted.""" + return sorted(str(p) for p in Path(input_dir).rglob("*.jsonl")) + + +def _compute_output_paths(manifest_paths: list[str], input_dir: str, output_dir: str) -> dict[str, str]: + """Mirror each manifest path from input_dir into output_dir, preserving relative structure. 
+ + Example:: + + input_dir: /data/results_large_scale_6 + input: /data/results_large_scale_6/corpus_a/manifest_0.jsonl + output_dir: /out + → /out/corpus_a/manifest_0.jsonl + """ + root = Path(input_dir) + return {str(p): str(Path(output_dir) / Path(p).relative_to(root)) for p in manifest_paths} + + +def _create_pipeline(manifest_path: str, output_path: str, args: argparse.Namespace) -> Pipeline: + pipeline = Pipeline( + name="Granary_v2_postprocessing", + description=( + "Text cleaning, hallucination detection, and language ID filtering " + "for Granary v2 ASR manifests." + ), + ) + pipeline.add_stage(ALMManifestReader(manifest_path=manifest_path)) + pipeline.add_stage(InitializeFieldsStage()) + pipeline.add_stage(RegexSubstitutionStage(regex_params_yaml=args.regex_yaml)) + pipeline.add_stage( + WhisperHallucinationStage( + common_hall_file=args.hall_phrases, + unique_words_threshold=args.unique_words_threshold, + long_word_threshold=args.long_word_threshold, + long_word_rel_threshold=args.long_word_rel_threshold, + char_rate_threshold=args.char_rate_threshold, + max_char_rate=args.max_char_rate, + ) + ) + pipeline.add_stage( + FastTextLIDStage( + model_path=args.fasttext_model, + target_lang=args.target_lang, + min_lang_prob=args.min_lang_prob, + ) + ) + pipeline.add_stage(FinalizeFieldsStage()) + pipeline.add_stage(ALMManifestWriterStage(output_path=output_path)) + return pipeline + + +def main(args: argparse.Namespace) -> None: + logger.remove() + logger.add(sys.stderr, level="DEBUG" if args.verbose else "INFO") + + if args.manifests: + manifest_paths = args.manifests + logger.info(f"Processing {len(manifest_paths)} specified manifest(s)") + else: + manifest_paths = _find_manifests(args.input_dir) + if not manifest_paths: + logger.error(f"No *.jsonl files found under {args.input_dir}") + sys.exit(1) + logger.info(f"Found {len(manifest_paths)} manifest(s) under {args.input_dir}") + + output_map = _compute_output_paths(manifest_paths, args.input_dir, 
args.output_dir) + for src, dst in output_map.items(): + logger.info(f" {src}") + logger.info(f" → {dst}") + + executor = XennaExecutor() + + n_done = n_skipped = 0 + for i, (manifest_path, output_path) in enumerate(output_map.items(), 1): + logger.info(f"\n[{i}/{len(output_map)}] {manifest_path}") + + # Skip manifests whose output already exists and is non-empty. + # This makes reruns safe: preempted or partially-run jobs can be + # resubmitted and only the missing manifests will be processed. + if os.path.exists(output_path) and os.path.getsize(output_path) > 0: + logger.info(f" Already done, skipping → {output_path}") + n_skipped += 1 + continue + + # Write to a .tmp file first, then rename atomically on success. + # A preempted run leaves only the .tmp file, which is ignored on + # the next run (not a valid .jsonl), so the manifest is reprocessed. + tmp_path = output_path + ".tmp" + if os.path.exists(tmp_path): + os.remove(tmp_path) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + pipeline = _create_pipeline(manifest_path, tmp_path, args) + if args.verbose: + logger.debug(pipeline.describe()) + pipeline.run(executor) + os.rename(tmp_path, output_path) + logger.info(f" Written → {output_path}") + n_done += 1 + + logger.info( + f"\nDone. processed={n_done}, skipped={n_skipped} " + f"(total={len(output_map)}) → {args.output_dir}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Granary v2 ASR postprocessing pipeline.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Root input directory used to compute mirrored output paths.", + ) + parser.add_argument( + "--manifests", + type=str, + nargs="+", + default=None, + help="Process specific manifests instead of scanning all of input_dir. 
" + "All paths must be under input_dir so output paths can be computed correctly.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Root output directory. Input manifest paths are mirrored here.", + ) + parser.add_argument( + "--fasttext_model", + type=str, + default="lid.176.ftz", + help="FastText LID model: local path or known name (lid.176.bin / lid.176.ftz).", + ) + parser.add_argument( + "--regex_yaml", + type=str, + default=_DEFAULT_REGEX_YAML, + help="Path to regex substitution rules YAML.", + ) + parser.add_argument( + "--hall_phrases", + type=str, + default=_DEFAULT_HALL_PHRASES, + help="Path to hallucination phrases text file.", + ) + parser.add_argument( + "--target_lang", + type=str, + default="en", + help="Expected language code for LID filtering.", + ) + parser.add_argument( + "--min_lang_prob", + type=float, + default=0.3, + help="Minimum FastText language probability to keep an entry.", + ) + parser.add_argument( + "--unique_words_threshold", + type=float, + default=0.4, + help="Unique-word ratio threshold for repeated n-gram hallucination detection.", + ) + parser.add_argument( + "--long_word_threshold", + type=int, + default=25, + help="Absolute character length above which a word is flagged as abnormally long.", + ) + parser.add_argument( + "--long_word_rel_threshold", + type=float, + default=3.0, + help="Relative length ratio (longest/second-longest) for long-word hallucination detection.", + ) + parser.add_argument( + "--char_rate_threshold", + type=float, + default=4.0, + help="Max chars/s below which text is considered too sparse (low char-rate hallucination).", + ) + parser.add_argument( + "--max_char_rate", + type=float, + default=40.0, + help="Min chars/s above which text is considered impossibly dense (high char-rate hallucination).", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable DEBUG-level logging.", + ) + main(parser.parse_args()) diff --git 
a/tutorials/audio/granary_v2_postprocessing/run.sh b/tutorials/audio/granary_v2_postprocessing/run.sh new file mode 100755 index 0000000000..b7ffd3da0d --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/run.sh @@ -0,0 +1,44 @@ +#!/bin/bash +#SBATCH -A llmservice_nemo_speechlm +#SBATCH -p batch +#SBATCH --job-name=granary-v2-postprocess +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=0 +#SBATCH -t 4:00:00 +#SBATCH --output=/lustre/fsw/convai_convaird_nemo-speech/users/ntadevosyan/projects/granary-v2-asr/Curator/new_logs/%j_postprocess.out +#SBATCH --error=/lustre/fsw/convai_convaird_nemo-speech/users/ntadevosyan/projects/granary-v2-asr/Curator/new_logs/%j_postprocess.err +#SBATCH --container-image=/lustre/fsw/llmservice_nemo_speechlm/users/nkoluguri/containers/curator-nightly-lhotse.sqsh +#SBATCH --container-mounts=/lustre/fsw/convai_convaird_nemo-speech:/lustre/fsw/convai_convaird_nemo-speech,/lustre/fsw/llmservice_nemo_speechlm:/lustre/fsw/llmservice_nemo_speechlm + +set -euo pipefail + +CURATOR_DIR="/lustre/fsw/convai_convaird_nemo-speech/users/ntadevosyan/projects/granary-v2-asr/Curator" +FASTTEXT_MODEL="/lustre/fsw/convai_convaird_nemo-speech/users/ntadevosyan/projects/granary-v2-asr/postprocess/fleurs/cache/lid.176.ftz" +INPUT_DIR="${1:?Usage: sbatch run.sh [extra pipeline args]}" +OUTPUT_DIR="${2:?}" +shift 2 +EXTRA_ARGS=("$@") # e.g. --manifests /path/to/shard_0.jsonl + +# INPUT_ROOT, if set, is used as --input_dir (path anchor for output mirroring). +# This lets submit_benchmarks.sh scan a benchmark subdir while still mirroring +# the full path hierarchy (e.g. ytc/en9/manifest.jsonl) into output_dir. 
+INPUT_ROOT="${INPUT_ROOT:-${INPUT_DIR}}" + +echo "Input dir : ${INPUT_DIR}" +echo "Input root : ${INPUT_ROOT}" +echo "Output dir : ${OUTPUT_DIR}" +echo "Node : $(hostname)" +echo "Started : $(date)" + +export PYTHONPATH="${CURATOR_DIR}:${PYTHONPATH:-}" + +cd "${CURATOR_DIR}" +python tutorials/audio/granary_v2_postprocessing/pipeline.py \ + --input_dir "${INPUT_ROOT}" \ + --output_dir "${OUTPUT_DIR}" \ + --fasttext_model "${FASTTEXT_MODEL}" \ + "${EXTRA_ARGS[@]}" + +echo "Finished : $(date)" diff --git a/tutorials/audio/granary_v2_postprocessing/submit.sh b/tutorials/audio/granary_v2_postprocessing/submit.sh new file mode 100755 index 0000000000..0cc5ea1bee --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/submit.sh @@ -0,0 +1,139 @@ +#!/bin/bash +# Submit Slurm jobs for the postprocessing pipeline. +# Manifests are grouped into chunks of MANIFESTS_PER_JOB so the number of +# submitted jobs stays manageable even for large datasets. +# +# Chunks where every output already exists are skipped — no job submitted. +# This makes resuming efficient: only incomplete work is requeued. +# +# Usage: +# bash submit.sh [input_dir_2 ...] +# +# Tune chunk size and CPUs via environment variables: +# MANIFESTS_PER_JOB=32 CPUS_PER_JOB=64 bash submit.sh +# +# Multiple input dirs are submitted as sequential waves: +# wave N+1 starts only after every job in wave N finishes (afterany). + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUN_SCRIPT="${SCRIPT_DIR}/run.sh" + +MANIFESTS_PER_JOB="${MANIFESTS_PER_JOB:-8}" +CPUS_PER_JOB="${CPUS_PER_JOB:-8}" + +# INPUT_ROOT overrides the path-anchor used for output mirroring. +# Set by submit_benchmarks.sh so that benchmark subdirs (e.g. .../ytc) still +# produce correctly nested output paths (e.g. output/ytc/en9/manifest.jsonl). +# When calling submit.sh directly, leave unset — INPUT_DIR is used as the anchor. 
+INPUT_ROOT="${INPUT_ROOT:-}" + +OUTPUT_DIR="${1:?Usage: bash submit.sh [input_dir_2 ...]}" +shift +INPUT_DIRS=("$@") + +if [[ ${#INPUT_DIRS[@]} -eq 0 ]]; then + echo "Error: at least one input_dir is required." >&2 + exit 1 +fi + +mkdir -p "${OUTPUT_DIR}" + +PREV_WAVE_IDS=() + +for INPUT_DIR in "${INPUT_DIRS[@]}"; do + mapfile -t MANIFESTS < <(find "${INPUT_DIR}" -name "*.jsonl" | sort) + + if [[ ${#MANIFESTS[@]} -eq 0 ]]; then + echo "Warning: no *.jsonl found under ${INPUT_DIR}, skipping." >&2 + continue + fi + + DEPEND_FLAG="" + if [[ ${#PREV_WAVE_IDS[@]} -gt 0 ]]; then + DEP_LIST=$(IFS=:; echo "${PREV_WAVE_IDS[*]}") + DEPEND_FLAG="--dependency=afterany:${DEP_LIST}" + fi + + N_JOBS=$(( (${#MANIFESTS[@]} + MANIFESTS_PER_JOB - 1) / MANIFESTS_PER_JOB )) + echo "Wave: ${INPUT_DIR}" + echo " Manifests : ${#MANIFESTS[@]} | per job : ${MANIFESTS_PER_JOB} | max jobs : ${N_JOBS}" + [[ -n "${DEPEND_FLAG}" ]] && echo " Depends on: ${PREV_WAVE_IDS[*]}" + + CURRENT_WAVE_IDS=() + n_submitted=0 + n_skipped=0 + chunk=() + # Use INPUT_ROOT as the path anchor if set, otherwise fall back to INPUT_DIR + ROOT_DIR="${INPUT_ROOT:-${INPUT_DIR}}" + + for MANIFEST in "${MANIFESTS[@]}"; do + chunk+=("${MANIFEST}") + + if [[ ${#chunk[@]} -eq ${MANIFESTS_PER_JOB} ]]; then + # Check if every output in this chunk already exists and is non-empty + all_done=$(python3 - "${ROOT_DIR}" "${OUTPUT_DIR}" "${chunk[@]}" <<'PYEOF' +import sys, os +input_dir, output_dir = sys.argv[1], sys.argv[2] +for m in sys.argv[3:]: + out = os.path.join(output_dir, os.path.relpath(m, input_dir)) + if not os.path.isfile(out) or os.path.getsize(out) == 0: + print("no"); sys.exit(0) +print("yes") +PYEOF +) + if [[ "${all_done}" == "yes" ]]; then + (( n_skipped++ )) || true + else + JOB_ID=$(sbatch \ + ${DEPEND_FLAG} \ + --cpus-per-task="${CPUS_PER_JOB}" \ + --parsable \ + "${RUN_SCRIPT}" "${INPUT_DIR}" "${OUTPUT_DIR}" --manifests "${chunk[@]}") + CURRENT_WAVE_IDS+=("${JOB_ID}") + echo " ${JOB_ID} ← ${#chunk[@]} 
manifests" + (( n_submitted++ )) || true + fi + chunk=() + fi + done + + # Handle remaining manifests + if [[ ${#chunk[@]} -gt 0 ]]; then + all_done=$(python3 - "${ROOT_DIR}" "${OUTPUT_DIR}" "${chunk[@]}" <<'PYEOF' +import sys, os +input_dir, output_dir = sys.argv[1], sys.argv[2] +for m in sys.argv[3:]: + out = os.path.join(output_dir, os.path.relpath(m, input_dir)) + if not os.path.isfile(out) or os.path.getsize(out) == 0: + print("no"); sys.exit(0) +print("yes") +PYEOF +) + if [[ "${all_done}" == "yes" ]]; then + (( n_skipped++ )) || true + else + JOB_ID=$(sbatch \ + ${DEPEND_FLAG} \ + --cpus-per-task="${CPUS_PER_JOB}" \ + --parsable \ + "${RUN_SCRIPT}" "${INPUT_DIR}" "${OUTPUT_DIR}" --manifests "${chunk[@]}") + CURRENT_WAVE_IDS+=("${JOB_ID}") + echo " ${JOB_ID} ← ${#chunk[@]} manifests (last chunk)" + (( n_submitted++ )) || true + fi + fi + + echo " Submitted : ${n_submitted} | Already done (skipped) : ${n_skipped}" + PREV_WAVE_IDS=("${CURRENT_WAVE_IDS[@]}") + echo "" +done + +if [[ ${#PREV_WAVE_IDS[@]} -gt 0 ]]; then + echo "All waves submitted." + echo "Monitor : squeue -u ${USER}" + echo "Job IDs : ${PREV_WAVE_IDS[*]}" +else + echo "Nothing to submit — all manifests already done." +fi diff --git a/tutorials/audio/granary_v2_postprocessing/submit_benchmarks.sh b/tutorials/audio/granary_v2_postprocessing/submit_benchmarks.sh new file mode 100755 index 0000000000..607ca08a37 --- /dev/null +++ b/tutorials/audio/granary_v2_postprocessing/submit_benchmarks.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# Submit postprocessing jobs per benchmark with configurable per-benchmark chunk sizes. +# +# Scans top-level subdirectories of <input_dir> and submits each benchmark +# independently via submit.sh. Per-benchmark MANIFESTS_PER_JOB overrides are +# defined in BENCHMARK_CHUNKS below — all others get DEFAULT_MANIFESTS_PER_JOB (8 unless overridden). 
+# +# Usage: +# bash submit_benchmarks.sh +# +# Override defaults via env: +# DEFAULT_MANIFESTS_PER_JOB=64 CPUS_PER_JOB=32 bash submit_benchmarks.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SUBMIT_SCRIPT="${SCRIPT_DIR}/submit.sh" + +OUTPUT_DIR="${1:?Usage: bash submit_benchmarks.sh }" +INPUT_DIR="${2:?}" + +# -------------------------------------------------------------------------- +# Per-benchmark chunk size overrides. +# Add/edit entries here: ["benchmark_name"]=N +# -------------------------------------------------------------------------- +declare -A BENCHMARK_CHUNKS=( + ["ytc"]=8 +) + +DEFAULT_MANIFESTS_PER_JOB="${DEFAULT_MANIFESTS_PER_JOB:-8}" +export CPUS_PER_JOB="${CPUS_PER_JOB:-16}" + +# -------------------------------------------------------------------------- + +mapfile -t BENCH_DIRS < <(find "${INPUT_DIR}" -mindepth 1 -maxdepth 1 -type d | sort) + +if [[ ${#BENCH_DIRS[@]} -eq 0 ]]; then + echo "Error: no subdirectories found under ${INPUT_DIR}" >&2 + exit 1 +fi + +echo "Output dir : ${OUTPUT_DIR}" +echo "Input dir : ${INPUT_DIR}" +echo "Benchmarks : ${#BENCH_DIRS[@]}" +echo "CPUs/job : ${CPUS_PER_JOB}" +echo "" + +for BENCH_DIR in "${BENCH_DIRS[@]}"; do + BENCH_NAME=$(basename "${BENCH_DIR}") + + if [[ -v BENCHMARK_CHUNKS["${BENCH_NAME}"] ]]; then + CHUNKS="${BENCHMARK_CHUNKS[${BENCH_NAME}]}" + else + CHUNKS="${DEFAULT_MANIFESTS_PER_JOB}" + fi + + # Check if every manifest in this benchmark already has a non-empty output. + # If so, skip the benchmark entirely — no jobs submitted. 
+ bench_done=$(python3 - "${INPUT_DIR}" "${OUTPUT_DIR}" "${BENCH_DIR}" <<'PYEOF' +import sys, os +from pathlib import Path +input_dir, output_dir, bench_dir = sys.argv[1], sys.argv[2], sys.argv[3] +manifests = list(Path(bench_dir).rglob("*.jsonl")) +if not manifests: + print("no"); sys.exit(0) +for m in manifests: + out = os.path.join(output_dir, os.path.relpath(str(m), input_dir)) + if not os.path.isfile(out) or os.path.getsize(out) == 0: + print("no"); sys.exit(0) +print("yes") +PYEOF +) + + if [[ "${bench_done}" == "yes" ]]; then + echo ">>> ${BENCH_NAME} — already done, skipping" + continue + fi + + echo ">>> ${BENCH_NAME} (MANIFESTS_PER_JOB=${CHUNKS})" + # INPUT_ROOT tells submit.sh (and the Slurm job) to use the original root + # dir as the path anchor, so output mirrors full hierarchy: ytc/en9/manifest.jsonl + MANIFESTS_PER_JOB="${CHUNKS}" INPUT_ROOT="${INPUT_DIR}" bash "${SUBMIT_SCRIPT}" "${OUTPUT_DIR}" "${BENCH_DIR}" +done