diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 90f97fdedfa8..dabdf4dff469 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -135,6 +135,8 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met metrics.get('eou_silence_rate', ''), metrics.get('eou_noise_rate', ''), metrics.get('eou_error_rate', ''), + metrics.get('katakana_cer_filewise_avg', ''), + metrics.get('katakana_cer_cumulative', ''), ] with open(csv_path, "a") as f: f.write(",".join(str(v) for v in values) + "\n") @@ -151,10 +153,7 @@ def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: return metrics_mean_ci -def filter_datasets( - dataset_meta_info: dict, - datasets: Optional[List[str]], -) -> List[str]: +def filter_datasets(dataset_meta_info: dict, datasets: Optional[List[str]]) -> List[str]: """Select datasets from the dataset meta info.""" if datasets is None: # Dataset filtering not specified, return all datasets. @@ -233,7 +232,8 @@ def run_inference_and_evaluation( "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," - "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" + "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate," + "katakana_cer_filewise_avg,katakana_cer_cumulative" ) for dataset in datasets: @@ -241,11 +241,13 @@ def run_inference_and_evaluation( meta = dataset_meta_info[dataset] manifest_records = read_manifest(meta['manifest_path']) - language = meta.get('whisper_language', 'en') + # `language` drives all language-specific eval logic (ASR target_lang, Whisper prompt, + # text normalization). `whisper_language` is kept as a fallback for legacy evalsets. + language = meta.get('language', meta.get('whisper_language', 'en')) # Prepare dataset metadata (remove evaluation-specific keys) dataset_meta_for_dl = copy.deepcopy(meta) - for key in ["whisper_language", "load_cached_codes_if_available"]: + for key in ["language", "whisper_language", "load_cached_codes_if_available"]: dataset_meta_for_dl.pop(key, None) # Setup output directories @@ -492,10 +494,14 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: help='Path to dataset configuration JSON file', ) data_group.add_argument( - '--datasets_base_path', - type=Path, - default=None, - help='Optional base path that paths in the "datasets_json_path" file are relative to', + '--root', + type=str, + default='', + help=( + 'Root directory the evalset relative paths are resolved against. ' + 'Each entry\'s "manifest_path" and "audio_dir" are joined onto this root. ' + 'Defaults to empty (treat paths as absolute / cwd-relative).' + ), ) data_group.add_argument( '--datasets', @@ -601,11 +607,6 @@ def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: default=None, help='Override path to the phoneme tokenizer file (overrides the path stored in the checkpoint config)', ) - group.add_argument( - '--disable_cas_for_context_text', - action='store_true', - help='Skip CAS embeddings for context text when loading legacy EasyMagpieTTS models', - ) def create_argument_parser() -> argparse.ArgumentParser: @@ -667,9 +668,13 @@ def main(argv=None): if args.deterministic: seed_all(seed=9) - dataset_meta_info = load_evalset_config( - config_path=args.datasets_json_path, dataset_base_path=args.datasets_base_path - ) + dataset_meta_info = load_evalset_config(args.datasets_json_path) + # Resolve relative evalset paths against --root so the checked-in config stays portable. + if args.root: + for _meta in dataset_meta_info.values(): + for _key in ("manifest_path", "audio_dir"): + if _meta.get(_key): + _meta[_key] = os.path.join(args.root, _meta[_key]) datasets = filter_datasets(dataset_meta_info, args.datasets) logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}") @@ -721,7 +726,6 @@ def main(argv=None): legacy_text_conditioning=args.legacy_text_conditioning, hparams_from_wandb=args.hparams_file_from_wandb, phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), - disable_cas_for_context_text=args.disable_cas_for_context_text, ) # Load model @@ -764,7 +768,6 @@ def main(argv=None): legacy_codebooks=args.legacy_codebooks, legacy_text_conditioning=args.legacy_text_conditioning, phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), - disable_cas_for_context_text=args.disable_cas_for_context_text, ) # Load model diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 9d2540a694ff..f9d44b20700d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -18,12 +18,13 @@ import json import os import pprint +import re +import string import tempfile import time from collections import Counter from functools import partial -from pathlib import Path -from typing import Optional, Union +from typing import Union import librosa import numpy as np @@ -33,9 +34,21 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate_detail + +try: + from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models_prompt import ( + EncDecHybridRNNTCTCBPEModelWithPrompt, + HybridRNNTCTCPromptTranscribeConfig, + ) + + _PARAKEET_PROMPT_ASR_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + EncDecHybridRNNTCTCBPEModelWithPrompt = None # type: ignore + HybridRNNTCTCPromptTranscribeConfig = None # type: ignore + _PARAKEET_PROMPT_ASR_AVAILABLE = False + from nemo.collections.tts.metrics.eou_classifier import EoUClassification, EoUClassifier, EoUType from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance -from nemo.collections.tts.parts.utils.tts_dataset_utils import get_text_processor from nemo.utils import logging # Optional import for UTMOSv2 (audio quality metric) @@ -55,8 +68,12 @@ FILEWISE_METRICS_TO_SAVE = [ 'cer', 'wer', + 'katakana_cer', + 'gt_katakana', + 'pred_katakana', 'pred_context_ssim', 'pred_text', + 'gt_audio_text', 'gt_text', 'gt_audio_filepath', 'pred_audio_filepath', @@ -68,37 +85,13 @@ ] -def load_evalset_config(config_path: Optional[str] = None, dataset_base_path: Optional[Path] = None) -> dict: +def load_evalset_config(config_path: str = None) -> dict: """Load dataset meta info from JSON config file.""" if config_path is None or not os.path.exists(config_path): raise ValueError("No dataset_json_path provided, please provide a valid path to the evalset config file.") - logging.info(f"Loading evalset config from {config_path}") with open(config_path, 'r') as f: - dataset_meta_info = json.load(f) - - # Validate that all evaluation datasets exist - for dataset_name, info in dataset_meta_info.items(): - manifest_path = Path(info["manifest_path"]) - audio_dir = Path(info["audio_dir"]) - - if dataset_base_path: - # Replace relative paths with absolute paths where appropriate - if not manifest_path.is_absolute(): - manifest_path = dataset_base_path / manifest_path - info["manifest_path"] = str(manifest_path) - - if not audio_dir.is_absolute(): - audio_dir = dataset_base_path / audio_dir - info["audio_dir"] = str(audio_dir) - - if not manifest_path.exists(): - raise ValueError(f"Manifest does not exist for dataset {dataset_name}: {manifest_path}") - - if not audio_dir.exists(): - raise ValueError(f"Audio directory does not exist for dataset {dataset_name}: {audio_dir}") - - return dataset_meta_info + return json.load(f) def _resolve_path(audio_dir, path): @@ -150,22 +143,124 @@ def read_manifest(manifest_path): return records -def transcribe_with_nemo_asr_batched(asr_model, audio_paths, batch_size=8, label=""): - """Transcribe multiple audio files with a NeMo ASR model in batches. Returns list of transcriptions (one per path).""" +def process_text(input_text): + # Remove Arabic tashkeel (diacritics/harakat) + input_text = re.sub(r'[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06DC\u06DF-\u06E4\u06E7\u06E8\u06EA-\u06ED]', '', input_text) + # Remove Arabic punctuation + input_text = re.sub(r'[،؟؛«»٪٫٬]', '', input_text) + # Remove Hindi-specific punctuation (danda, double danda) + input_text = re.sub(r'[।॥॰]', '', input_text) + # Remove Mandarin-specific punctuation + input_text = re.sub(r'[,。!?;:""''()【】《》〈〉「」『』、…·~—–\u3000]', '', input_text) + # Remove Japanese-specific punctuation + input_text = re.sub(r'[。、!?「」『』()【】〔〕・…‥〜ー\u3000\u30FB]', '', input_text) + + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + return single_space_text + + +_PYOPENJTALK = None + + +def text_to_katakana(text: str) -> str: + """Convert Japanese text to its Katakana reading via pyopenjtalk (lazy-imported). + + Used for an additional, reading-based Japanese CER metric that is robust to + kanji/kana spelling variation between the reference and the ASR hypothesis. + Returns "" on empty input or if pyopenjtalk is unavailable/fails. + """ + global _PYOPENJTALK + if not text: + return "" + if _PYOPENJTALK is None: + try: + import pyopenjtalk + + _PYOPENJTALK = pyopenjtalk + except Exception as e: # noqa: BLE001 + logging.warning(f"pyopenjtalk not available; skipping katakana CER: {e}") + _PYOPENJTALK = False + if _PYOPENJTALK is False: + return "" + try: + return _PYOPENJTALK.g2p(text, kana=True).strip() + except Exception as e: # noqa: BLE001 + logging.warning(f"pyopenjtalk failed for '{text[:40]}': {e}") + return "" + + +def eval_language_to_parakeet_target_lang(lang: str) -> str: + """Map evalset ``whisper_language`` (or HF-style codes) to Parakeet prompt ``target_lang`` IDs.""" + if not lang: + return "en-US" + lang = lang.strip() + # BCP-47 style already (e.g. pt-BR, zh-CN, en-US) + if "-" in lang and len(lang) >= 4: + return lang + return { + "en": "en-US", + "ar": "ar", + "ko": "ko-KR", + "hi": "hi-IN", + "zh": "zh-CN", + "it": "it-IT", + "es": "es-ES", + "de": "de-DE", + "fr": "fr-FR", + "ja": "ja-JP", + }.get(lang, lang) + + +def transcribe_with_nemo_asr_batched(asr_model, audio_paths, batch_size=8, label="", eval_language="en"): + """Transcribe with a NeMo ASR model. + + Parakeet multilingual **prompt** checkpoints (``EncDecHybridRNNTCTCBPEModelWithPrompt``) require + ``HybridRNNTCTCPromptTranscribeConfig`` with a valid ``target_lang``; plain ``transcribe()`` is not enough. + """ + use_prompt = ( + _PARAKEET_PROMPT_ASR_AVAILABLE + and EncDecHybridRNNTCTCBPEModelWithPrompt is not None + and isinstance(asr_model, EncDecHybridRNNTCTCBPEModelWithPrompt) + ) + target_lang = eval_language_to_parakeet_target_lang(eval_language) + all_transcriptions = [] for start in range(0, len(audio_paths), batch_size): batch_paths = audio_paths[start : start + batch_size] try: with torch.inference_mode(): - batch_results = asr_model.transcribe(batch_paths, batch_size=len(batch_paths), use_lhotse=False) + if use_prompt and HybridRNNTCTCPromptTranscribeConfig is not None: + cfg = HybridRNNTCTCPromptTranscribeConfig( + batch_size=len(batch_paths), + use_lhotse=False, + target_lang=target_lang, + ) + batch_results = asr_model.transcribe(batch_paths, override_config=cfg) + else: + batch_results = asr_model.transcribe(batch_paths, batch_size=len(batch_paths)) for r in batch_results: - all_transcriptions.append(r.text) + hyp_text = getattr(r, "text", None) + if hyp_text is None: + hyp_text = str(r) + all_transcriptions.append(process_text(hyp_text)) except Exception as e: logging.info("Error during batched ASR ({} audio): {}".format(label, e)) all_transcriptions.extend([""] * len(batch_paths)) return all_transcriptions - def transcribe_with_whisper_batched( whisper_model, whisper_processor, audio_paths, language, device, batch_size=8, label="" ): @@ -185,7 +280,7 @@ def transcribe_with_whisper_batched( with torch.inference_mode(): predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids) transcriptions = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) - all_transcriptions.extend(transcriptions) + all_transcriptions.extend(process_text(t) for t in transcriptions) except Exception as e: logging.info("Error during batched Whisper ASR ({} audio): {}".format(label, e)) all_transcriptions.extend([""] * len(batch_paths)) @@ -253,26 +348,39 @@ def transcribed_batched( asr_batch_size, label="", ): - """Transcribe a list of audio files using NeMo ASR (English) or Whisper (other languages).""" - if language == "en": - texts = transcribe_with_nemo_asr_batched(asr_model, audio_paths, batch_size=asr_batch_size, label=label) - else: + """Transcribe a list of audio files using NeMo ASR (incl. Parakeet prompt) or Whisper.""" + if asr_model is not None: + texts = transcribe_with_nemo_asr_batched( + asr_model, + audio_paths, + batch_size=asr_batch_size, + label=label, + eval_language=language, + ) + elif whisper_model is not None: texts = transcribe_with_whisper_batched( - whisper_model, whisper_processor, audio_paths, language, device, batch_size=asr_batch_size, label=label + whisper_model, + whisper_processor, + audio_paths, + language, + device, + batch_size=asr_batch_size, + label=label, ) - + else: + raise ValueError("No ASR model loaded for evaluation (asr_model and whisper_model are both None)") return texts - def load_evaluation_models( language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", device="cuda" ): """Load ASR and speaker verification models used for evaluation. Args: - language: Language code. "en" uses a NeMo ASR model; other languages use Whisper. + language: Language / whisper hint for transcription (Parakeet prompt ``target_lang`` mapping, Whisper prompts). sv_model_type: Speaker verification model type ("wavlm" or "titanet"). - asr_model_name: Name of the NeMo ASR model (used only when language is "en"). + asr_model_name: NeMo ASR: local ``.nemo`` path (any language), or Hub id when ``language=="en"``; + otherwise Whisper is used unless a ``.nemo`` path is given. device: Device to place models on. Returns: @@ -286,10 +394,12 @@ def load_evaluation_models( 'feature_extractor': None, } - if language == "en": - if os.path.isfile(asr_model_name) and asr_model_name.endswith('.nemo'): - models['asr_model'] = nemo_asr.models.ASRModel.restore_from(restore_path=asr_model_name).to(device).eval() - elif asr_model_name.startswith("nvidia/") or asr_model_name in ["stt_en_conformer_transducer_large"]: + if asr_model_name.endswith(".nemo"): + models['asr_model'] = nemo_asr.models.ASRModel.restore_from( + restore_path=asr_model_name, + ).to(device).eval() + elif language == "en": + if asr_model_name.startswith("nvidia/") or asr_model_name in ["stt_en_conformer_transducer_large"]: models['asr_model'] = nemo_asr.models.ASRModel.from_pretrained(model_name=asr_model_name).to(device).eval() else: raise ValueError(f"ASR model {asr_model_name} not supported") @@ -407,9 +517,7 @@ def evaluate_dir( # 5. ASR transcription in batches logging.info(f"Doing batched ASR transcription with batch size {asr_batch_size}...") - # Transcribe predicted audios - text_processor = get_text_processor(language) pred_texts = transcribed_batched( audio_file_lists, language, @@ -420,7 +528,6 @@ def evaluate_dir( asr_batch_size, label="predicted", ) - pred_texts = [text_processor.process_text_for_wer(text) for text in pred_texts] # Transcribe ground truth audios if len(gt_audio_paths) > 0: gt_audio_texts = transcribed_batched( @@ -433,7 +540,6 @@ def evaluate_dir( asr_batch_size, label="ground truth", ) - gt_audio_texts = [text_processor.process_text_for_wer(text) for text in gt_audio_texts] else: gt_audio_texts = [None] * len(records) @@ -441,13 +547,11 @@ def evaluate_dir( gt_texts_processed = [] for record in records: if "original_text" in record: - text_field = 'original_text' + gt_texts_processed.append(process_text(record['original_text'])) elif 'normalized_text' in record: - text_field = 'normalized_text' + gt_texts_processed.append(process_text(record['normalized_text'])) else: - text_field = 'text' - processed_text = text_processor.process_text_for_wer(record[text_field]) - gt_texts_processed.append(processed_text) + gt_texts_processed.append(process_text(record['text'])) # 7. Batched EoU classification eou_results = None @@ -473,14 +577,36 @@ def evaluate_dir( gt_text = gt_texts_processed[ridx] + if language in ("zh", "zh-CN", "zh-TW"): + pred_text = pred_text.replace(" ", "") + gt_text = gt_text.replace(" ", "") + if gt_audio_text is not None: + gt_audio_text = gt_audio_text.replace(" ", "") + else: + pred_text = pred_text + gt_text = gt_text + detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True) detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False) + # Japanese: additional reading-based CER on Katakana (pyopenjtalk g2p), robust to + # kanji/kana spelling differences between reference and ASR hypothesis. + gt_katakana = pred_katakana = None + katakana_cer = None + if language == "ja": + gt_katakana = text_to_katakana(gt_text) + pred_katakana = text_to_katakana(pred_text) + katakana_cer = word_error_rate_detail( + hypotheses=[pred_katakana], references=[gt_katakana], use_cer=True + )[0] + logging.info(f"{ridx} GT Text: {gt_text}") logging.info(f"{ridx} Pr Text: {pred_text}") # Format cer and wer to 2 decimal places logging.info(f"CER: {detailed_cer[0]:.4f} | WER: {detailed_wer[0]:.4f}") + pred_context_ssim = 0.0 + gt_context_ssim = 0.0 with torch.inference_mode(): extract_embedding_fn = partial( extract_embedding, @@ -570,6 +696,9 @@ def evaluate_dir( 'detailed_wer': detailed_wer, 'cer': detailed_cer[0], 'wer': detailed_wer[0], + 'katakana_cer': katakana_cer, + 'gt_katakana': gt_katakana, + 'pred_katakana': pred_katakana, 'pred_gt_ssim': pred_gt_ssim, 'pred_context_ssim': pred_context_ssim, 'gt_context_ssim': gt_context_ssim, @@ -717,6 +846,15 @@ def compute_global_metrics( avg_metrics['wer_cumulative'] = word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[ 0 ] + # Japanese reading-based CER (Katakana via pyopenjtalk); only present for ja datasets. + kata = [m for m in filewise_metrics if m.get('katakana_cer') is not None] + if kata: + avg_metrics['katakana_cer_filewise_avg'] = sum(m['katakana_cer'] for m in kata) / len(kata) + avg_metrics['katakana_cer_cumulative'] = word_error_rate_detail( + hypotheses=[m['pred_katakana'] for m in kata], + references=[m['gt_katakana'] for m in kata], + use_cer=True, + )[0] avg_metrics['ssim_pred_gt_avg'] = sum(m['pred_gt_ssim'] for m in filewise_metrics) / n avg_metrics['ssim_pred_context_avg'] = sum(m['pred_context_ssim'] for m in filewise_metrics) / n avg_metrics['ssim_gt_context_avg'] = sum(m['gt_context_ssim'] for m in filewise_metrics) / n @@ -777,12 +915,26 @@ def main(): parser.add_argument('--audio_dir', type=str, default=None) parser.add_argument('--generated_audio_dir', type=str, default=None) parser.add_argument('--whisper_language', type=str, default="en") - parser.add_argument('--evalset', type=str, default=None) + parser.add_argument( + '--datasets_json_path', + type=str, + default=None, + help='Path to evalset JSON (use with --evalset to fill manifest_path and audio_dir)', + ) + parser.add_argument( + '--evalset', + type=str, + default=None, + help='Dataset key inside --datasets_json_path', + ) args = parser.parse_args() if args.evalset is not None: - dataset_meta_info = load_evalset_config() - assert args.evalset in dataset_meta_info, f"Dataset '{args.evalset}' not found in evalset_config.json" + if not args.datasets_json_path: + parser.error("--datasets_json_path is required when using --evalset") + dataset_meta_info = load_evalset_config(args.datasets_json_path) + if args.evalset not in dataset_meta_info: + parser.error(f"Dataset '{args.evalset}' not found in {args.datasets_json_path}") args.manifest_path = dataset_meta_info[args.evalset]['manifest_path'] args.audio_dir = dataset_meta_info[args.evalset]['audio_dir']